JAX compatibility#
2025-01-09
JAX provides a NumPy-like API that combines automatic differentiation with the Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine learning at scale.
JAX uses composable transformations of Python and NumPy programs, such as just-in-time (JIT) compilation, automatic vectorization, and parallelization. To learn about JAX, including profiling and optimizations, see the official JAX documentation.
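For example, a minimal sketch (assuming a working JAX installation with a ROCm or CPU backend) that combines `jax.jit` compilation with `jax.vmap` automatic vectorization:

```python
import jax
import jax.numpy as jnp

# jax.jit compiles the function with XLA; jax.vmap vectorizes it over a batch axis.
@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones((4, 2))
xs = jnp.ones((8, 4))
print(batched_predict(w, xs).shape)  # (8, 2)
```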
ROCm support for JAX is upstreamed, and you can build the official JAX source code with ROCm support:
ROCm JAX release:
- Offers AMD-validated and community Docker images with ROCm and JAX pre-installed.
- ROCm JAX repository: ROCm/jax
- See the ROCm JAX installation guide to get started.
Official JAX release:
- Official JAX repository: jax-ml/jax
- See the AMD GPU (Linux) installation section in the JAX documentation.
Note
AMD releases official ROCm JAX Docker images quarterly alongside new ROCm releases. These images undergo full AMD testing. Community ROCm JAX Docker images follow upstream JAX releases and use the latest available ROCm version.
Docker image compatibility#
AMD validates and publishes ready-made JAX images with ROCm backends on Docker Hub. The AMD-validated Docker image tags and associated inventories are validated for ROCm 6.3.1.
AMD also publishes community JAX images with ROCm backends on Docker Hub. The following community Docker image tags and associated inventories are tested for ROCm 6.2.4.
| Docker image | JAX | Linux | Python |
|---|---|---|---|
| rocm/jax-community |  | Ubuntu 22.04 |  |
| rocm/jax-community |  | Ubuntu 22.04 |  |
| rocm/jax-community |  | Ubuntu 22.04 |  |
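After pulling one of these images (or installing JAX for ROCm locally), a quick sanity check is to confirm that JAX can see the AMD GPUs. This is only a sketch; the reported device names and counts depend on your system.

```python
import jax

# On a working ROCm install, the devices report a GPU platform rather than CPU.
print(jax.devices())
print(jax.default_backend())
```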
Critical ROCm libraries for JAX#
The functionality of JAX with ROCm is determined by its underlying library dependencies. These critical ROCm components affect the capabilities, performance, and feature set available to developers.
| ROCm library | Version | Purpose | Used in |
|---|---|---|---|
| hipBLAS | 2.3.0 | Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for matrix and vector operations. | Matrix multiplication and other dense linear algebra operations. |
| hipBLASLt | 0.10.0 | hipBLASLt is an extension of hipBLAS, providing additional features like epilogues fused into the matrix multiplication kernel or use of integer tensor cores. | Matrix multiplication with fused epilogues. |
| hipCUB | 3.3.0 | Provides a C++ template library for parallel algorithms for reduction, scan, sort, and select. | Reduction functions such as sums, minimums, and maximums. |
| hipFFT | 1.0.17 | Provides GPU-accelerated Fast Fourier Transform (FFT) operations. | FFT functions. |
| hipRAND | 2.11.0 | Provides fast random number generation for GPUs. | The `jax.random` module for pseudorandom number generation. |
| hipSOLVER | 2.3.0 | Provides GPU-accelerated solvers for linear systems, eigenvalues, and singular value decompositions (SVD). | Solving linear systems, eigenvalue problems, and SVDs. |
| hipSPARSE | 3.1.2 | Accelerates operations on sparse matrices, such as sparse matrix-vector or matrix-matrix products. | Sparse matrix multiplication. |
| hipSPARSELt | 0.2.2 | Accelerates operations on sparse matrices, such as sparse matrix-vector or matrix-matrix products. | Sparse matrix multiplication. |
| MIOpen | 3.3.0 | Optimized for deep learning primitives such as convolutions, pooling, normalization, and activation functions. | Speeds up convolutional neural networks (CNNs), recurrent neural networks (RNNs), and other layers through operations like convolution and pooling. |
| RCCL | 2.21.5 | Optimized for multi-GPU communication for operations like all-reduce, broadcast, and scatter. | Distributing computations across multiple GPUs. |
| rocThrust | 3.3.0 | Provides a C++ template library for parallel algorithms like sorting, reduction, and scanning. | Reduction operations such as sums and cumulative sums. |
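The following sketch runs ordinary JAX operations that exercise several of the libraries above on the ROCm backend; exactly which library services each call is an implementation detail of XLA and the ROCm stack.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)                 # random number generation
x = jax.random.normal(key, (1024, 1024))

y = x @ x.T                                  # dense matrix multiplication
spectrum = jnp.fft.fft(x[0])                 # fast Fourier transform
total = jnp.sum(y)                           # reduction

print(total, spectrum.shape)
```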
Supported and unsupported features#
The following table maps GPU-accelerated JAX modules to their supported ROCm and JAX versions.
| Module | Description | Since JAX | Since ROCm |
|---|---|---|---|
| jax.numpy | Implements the NumPy API, using the primitives in `jax.lax`. | 0.1.56 | 5.0.0 |
| jax.scipy | Provides GPU-accelerated and differentiable implementations of many functions from the SciPy library, leveraging JAX's transformations (for example, `grad`, `jit`, and `vmap`). | 0.1.56 | 5.0.0 |
| jax.lax | A library of primitive operations that underpins libraries such as `jax.numpy`. | 0.1.57 | 5.0.0 |
| jax.random | Provides a number of routines for deterministic generation of sequences of pseudorandom numbers. | 0.1.58 | 5.0.0 |
| jax.sharding | Allows defining how arrays are partitioned and distributed across multiple devices. | 0.3.20 | 5.1.0 |
| jax.dlpack | For exchanging tensor data between JAX and other libraries that support the DLPack standard. | 0.1.57 | 5.0.0 |
| jax.distributed | Enables the scaling of computations across multiple devices on a single machine or across multiple machines. | 0.1.74 | 5.0.0 |
| jax.dtypes | Provides utilities for working with and managing data types in JAX arrays and computations. | 0.1.66 | 5.0.0 |
| jax.image | Contains image manipulation functions like resize, scale, and translation. | 0.1.57 | 5.0.0 |
| jax.nn | Contains common functions for neural network libraries. | 0.1.56 | 5.0.0 |
| jax.ops | Computes the minimum, maximum, sum, or product within segments of an array. | 0.1.57 | 5.0.0 |
| jax.profiler | Contains JAX's tracing and time profiling features. | 0.1.57 | 5.0.0 |
| jax.stages | Contains interfaces to stages of the compiled execution process. | 0.3.4 | 5.0.0 |
| jax.tree | Provides utilities for working with tree-like container data structures. | 0.4.26 | 5.6.0 |
| jax.tree_util | Provides utilities for working with nested data structures, or pytrees. | 0.1.65 | 5.0.0 |
| jax.typing | Provides JAX-specific static type annotations. | 0.3.18 | 5.1.0 |
| jax.extend | Provides modules for access to JAX internal machinery. | 0.4.15 | 5.5.0 |
| jax.example_libraries | Serves as a collection of example code and libraries that demonstrate various capabilities of JAX. | 0.1.74 | 5.0.0 |
| jax.experimental | Namespace for experimental features and APIs that are in development or are not yet fully stable for production use. | 0.1.56 | 5.0.0 |
| jax.lib | Set of internal tools and types for bridging between JAX's Python frontend and its XLA backend. | 0.4.6 | 5.3.0 |
| jax-triton | Library that integrates the Triton deep learning compiler with JAX. | jax_triton 0.2.0 | 6.2.4 |
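As a small illustration of a few of the modules above, this hedged sketch combines `jax.random`, `jax.tree_util`, and `jax.nn`; the parameter names are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
params = {
    "w": jax.random.normal(key, (3, 3)),
    "b": jnp.zeros((3,)),
}

# jax.tree_util.tree_map applies a function to every leaf of a pytree.
print(jax.tree_util.tree_map(jnp.shape, params))

# jax.nn provides common neural-network activation functions.
x = jnp.linspace(-2.0, 2.0, 5)
print(jax.nn.relu(x))
```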
jax.scipy module#
A SciPy-like API for scientific computing.
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.3.11 | 5.1.0 |
|  | 0.1.71 | 5.0.0 |
|  | 0.4.15 | 5.5.0 |
|  | 0.1.76 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.57 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.4.12 | 5.4.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
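For example, a minimal sketch using `jax.scipy.linalg` to solve a small linear system (typically backed by the ROCm solver libraries on AMD GPUs):

```python
import jax.numpy as jnp
from jax.scipy import linalg

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

# Solve A x = b on the default JAX device.
x = linalg.solve(A, b)
print(x)
```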
jax.scipy.stats module#
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.61 | 5.0.0 |
|  | 0.4.14 | 5.4.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.61 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.3.15 | 5.2.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.3.18 | 5.1.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.72 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.4.0 | 5.3.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.4.2 | 5.3.0 |
|  | 0.4.20 | 5.6.0 |
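As an example of the statistical routines, the sketch below evaluates normal densities with `jax.scipy.stats.norm`:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.linspace(-3.0, 3.0, 7)
print(norm.pdf(x))                         # standard normal density
print(norm.logpdf(x, loc=1.0, scale=2.0))  # log-density with shifted mean and scale
```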
jax.extend module#
Modules for JAX extensions.
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.4.30 | 6.0.0 |
|  | 0.4.17 | 5.6.0 |
|  | 0.4.26 | 5.6.0 |
|  | 0.4.15 | 5.5.0 |
jax.experimental module#
Experimental modules and APIs.
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.1.75 | 5.0.0 |
|  | 0.1.68 | 5.0.0 |
|  | 0.4.0 | 5.3.0 |
|  | 0.1.56 | 5.0.0 |
|  | 0.4.26 | 5.6.0 |
|  | 0.1.76 | 5.0.0 |
|  | 0.3.2 | 5.0.0 |
|  | 0.4.15 | 5.5.0 |
|  | 0.1.61 | 5.0.0 |
|  | 0.4.0 | 5.3.0 |
|  | 0.4.3 | 5.3.0 |
|  | 0.1.75 | 5.0.0 |
| API | Since JAX | Since ROCm |
|---|---|---|
|  | 0.1.60 | 5.0.0 |
|  | 0.1.60 | 5.0.0 |
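As one illustration of the experimental namespace, the following sketch uses `jax.experimental.mesh_utils` together with `jax.sharding` to lay an array out across the visible devices; it assumes at least one JAX device and is not specific to any row in the tables above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over all visible devices (GPUs on a ROCm system).
n = jax.device_count()
devices = mesh_utils.create_device_mesh((n,))
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading axis of an array across the "data" mesh axis.
x = jnp.arange(4.0 * n).reshape(n, 4)
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))
print(x_sharded.sharding)
```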
jax.experimental.pallas module#
Module for Pallas, a JAX extension for custom kernels.
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.4.31 | 6.1.3 |
|  | 0.4.15 | 5.5.0 |
|  | 0.4.32 | 6.1.3 |
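A minimal Pallas kernel, adapted from the upstream JAX Pallas quickstart, looks like the following; whether it runs on a given ROCm release depends on the Pallas backend support listed above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are read and written with array-style indexing inside the kernel.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))
```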
jax.experimental.sparse module#
Experimental support for sparse matrix operations.
| Module | Since JAX | Since ROCm |
|---|---|---|
|  | 0.3.15 | 5.2.0 |
|  | 0.3.25 | ❌ |
|  | Since JAX | Since ROCm |
|---|---|---|
|  | 0.1.72 | 5.0.0 |
|  | 0.3.20 | 5.1.0 |
|  | 0.1.75 | 5.0.0 |
|  | 0.4.27 | 5.6.0 |
|  | 0.1.75 | 5.0.0 |
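For example, a short sketch using the BCOO format from `jax.experimental.sparse`:

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0.0, 1.0, 0.0],
                   [2.0, 0.0, 3.0]])

# Convert to batched-COO (BCOO) format and multiply by a dense vector.
m = sparse.BCOO.fromdense(dense)
v = jnp.array([1.0, 2.0, 3.0])
print(m @ v)         # sparse matrix-vector product
print(m.todense())   # back to dense
```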
Unsupported JAX features#
The following are GPU-accelerated JAX features not currently supported by ROCm.
| Feature | Description | Since JAX |
|---|---|---|
| Mixed Precision with TF32 | Mixed precision with TF32 is used for matrix multiplications, convolutions, and other linear algebra operations, particularly in deep learning workloads like CNNs and transformers. | 0.2.25 |
| RNN support | Currently only LSTM with double bias is supported with float32 input and weight. | 0.3.25 |
| XLA int4 support | 4-bit integer (int4) precision in the XLA compiler. | 0.4.0 |
|  | Converts a dense matrix to a sparse matrix representation. | Experimental |
Use cases and recommendations#
The nanoGPT in JAX blog post explores the implementation and training of a Generative Pre-trained Transformer (GPT) model in JAX, inspired by Andrej Karpathy's PyTorch-based nanoGPT. It compares how essential GPT components, such as self-attention mechanisms and optimizers, are realized in PyTorch and JAX, and highlights JAX's unique features.
The Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs blog post provides a comprehensive guide on enhancing the training efficiency of GPT models by implementing mixed precision techniques in JAX, specifically tailored for AMD GPUs utilizing the ROCm platform.
The Supercharging JAX with Triton Kernels on AMD GPUs blog demonstrates how to develop a custom fused dropout-activation kernel for matrices using Triton, integrate it with JAX, and benchmark its performance using ROCm.
The Distributed fine-tuning with JAX on AMD GPUs blog post outlines the process of fine-tuning a Bidirectional Encoder Representations from Transformers (BERT)-based large language model (LLM) using JAX for a text classification task. It discusses techniques for parallelizing the fine-tuning across multiple AMD GPUs and assesses the model's performance on a holdout dataset. The fine-tuning uses a BERT-base-cased transformer model and the General Language Understanding Evaluation (GLUE) benchmark dataset on a multi-GPU setup.
The MI300X workload optimization guide provides detailed guidance on optimizing workloads for the AMD Instinct MI300X accelerator using ROCm. The page is aimed at helping users achieve optimal performance for deep learning and other high-performance computing tasks on the MI300X GPU.
For more use cases and recommendations, see ROCm JAX blog posts.