JAX on ROCm#

2024-12-20

Applies to Linux

This directory provides setup instructions and necessary files to build, test, and run JAX with ROCm support in a Docker environment, suitable for both runtime and CI workflows. Explore the following methods to use or build JAX on ROCm.

Using a prebuilt Docker image#

The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX configured for ROCm.

  1. To pull the latest ROCm JAX Docker image, run:

    docker pull rocm/jax-community:latest
    

    Note

    For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub.

    Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.

  2. Once the image is downloaded, launch a container using the following command:

    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
    
    docker attach rocm_jax
    

    Tip

    • The --shm-size parameter allocates shared memory for the container. Adjust it based on your system’s resources if needed.

    • Replace $(pwd) with the absolute path to the directory you want to mount inside the container.

    • If you prefer to use rocm/jax, remember to replace rocm/jax-community with rocm/jax.

  3. Verify the installation of ROCm JAX. See Testing your JAX installation with ROCm.
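The `docker run` command in step 2 maps the /dev/kfd and /dev/dri device nodes into the container. If either node is missing on the host, the GPUs cannot be exposed to the container. The following is a minimal preflight sketch, assuming a standard Linux host; the helper name `missing_device_nodes` is illustrative and not part of any ROCm tooling.

```python
import os

# Device nodes that the `docker run` command above maps into the container.
REQUIRED_NODES = ("/dev/kfd", "/dev/dri")


def missing_device_nodes(paths=REQUIRED_NODES):
    """Return the subset of `paths` that does not exist on this host."""
    return [p for p in paths if not os.path.exists(p)]


if __name__ == "__main__":
    missing = missing_device_nodes()
    if missing:
        print("Missing device nodes:", ", ".join(missing))
    else:
        print("All required device nodes are present.")
```

Run the script on the host before launching the container. An empty result only means the nodes exist; it does not check group permissions.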

Using a ROCm base Docker image and installing JAX#

If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.

  1. Pull the ROCm Ubuntu Docker image. For example:

    docker pull rocm/dev-ubuntu-22.04:6.3-complete
    
  2. Launch the Docker container. After pulling the image, run:

    docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/dev-ubuntu-22.04:6.3-complete /bin/bash
    docker attach rocm_jax
    
  3. Install JAX with ROCm support. Inside the running container, install the latest release using pip:

    pip3 install "jax[rocm]"
    
  4. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.

    pip3 freeze | grep jax
    

    Example output (the version numbers depend on the release you installed):

    jax==0.4.35
    jax-rocm60-pjrt==0.4.35
    jax-rocm60-plugin==0.4.35
    jaxlib==0.4.35
    
  5. Explicitly set the LLVM_PATH environment variable. This helps XLA find ld.lld in the PATH at runtime.

    export LLVM_PATH=/opt/rocm/llvm
    
  6. Verify the installation of ROCm JAX. See Testing your JAX installation with ROCm.
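The check in step 4 can be automated. The sketch below parses `pip3 freeze` output and reports whether all JAX-related packages share one version, as they do in the example output above. That the versions should match is an assumption drawn from that example, and the function names are illustrative.

```python
def jax_package_versions(freeze_text):
    """Map each JAX-related package name to its pinned version."""
    versions = {}
    for line in freeze_text.splitlines():
        name, sep, version = line.strip().partition("==")
        # Keep only `name==version` lines for packages starting with "jax".
        if sep and name.lower().startswith("jax"):
            versions[name] = version
    return versions


def versions_match(versions):
    """True when at least one package is present and all share one version."""
    return len(versions) > 0 and len(set(versions.values())) == 1


# Sample input copied from the example output in step 4.
sample = """\
jax==0.4.35
jax-rocm60-pjrt==0.4.35
jax-rocm60-plugin==0.4.35
jaxlib==0.4.35
"""
print(versions_match(jax_package_versions(sample)))
```

To check a live environment, pass in the text produced by `pip3 freeze` instead of the sample string.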

Install JAX on bare-metal or a custom container#

Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.

  1. Install ROCm. Follow the ROCm installation guide to install ROCm on your system.

    Once installed, verify your ROCm installation using:

    rocm-smi
    
    ========================================== ROCm System Management Interface ==========================================
    ==================================================== Concise Info ====================================================
    Device  [Model : Revision]    Temp        Power     Partitions      SCLK     MCLK     Fan  Perf  PwrCap  VRAM%  GPU%
            Name (20 chars)       (Junction)  (Socket)  (Mem, Compute)
    ======================================================================================================================
    0       [0x74a1 : 0x00]       50.0°C      170.0W    NPS1, SPX       131Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    1       [0x74a1 : 0x00]       51.0°C      176.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    2       [0x74a1 : 0x00]       50.0°C      177.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    3       [0x74a1 : 0x00]       53.0°C      176.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    ======================================================================================================================
    ================================================ End of ROCm SMI Log =================================================
    
  2. Install JAX with ROCm support using pip:

    pip3 install "jax[rocm]"
    
  3. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.

    pip3 freeze | grep jax
    
  4. Explicitly set the LLVM_PATH environment variable. This helps XLA find ld.lld at runtime.

    export LLVM_PATH=/opt/rocm/llvm
    
  5. Verify the installation of ROCm JAX.

    Run the following commands to verify that ROCm JAX is installed correctly:

    python3 -c "import jax; print(jax.devices())"
    python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
    

    Expected output:

    [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
    
    [0 1 2 3 4]
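The LLVM_PATH variable helps XLA find ld.lld at runtime, so it can be useful to confirm that the linker is actually present under that path. The following sketch assumes that ld.lld lives in the `bin` subdirectory of the LLVM tree; both that layout and the helper name `find_lld` are assumptions for illustration.

```python
import os


def find_lld(llvm_path=None):
    """Return the path to ld.lld under LLVM_PATH, or None if it is not found.

    Assumes the linker lives in the `bin` subdirectory of the LLVM tree.
    """
    llvm_path = llvm_path or os.environ.get("LLVM_PATH")
    if not llvm_path:
        return None
    candidate = os.path.join(llvm_path, "bin", "ld.lld")
    return candidate if os.path.isfile(candidate) else None


if __name__ == "__main__":
    found = find_lld()
    print(found or "ld.lld not found; check that LLVM_PATH is set correctly")
```

Run the script in the same shell where you exported LLVM_PATH, because the variable is only visible to processes started from that shell.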
    

Build ROCm JAX from source#

Follow these steps to build JAX with ROCm support from source.

  1. Clone the ROCm-specific fork of JAX with the desired branch:

    git clone https://github.com/ROCm/jax -b <branch_name>
    cd jax
    
  2. Run the following command to build the necessary wheels:

    python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
        --rocm_version=60 --rocm_path=/opt/rocm-[version]
    

    This generates three wheels in the dist/ directory:

    • jaxlib (generic, device-agnostic library)

    • jax-rocm-plugin (ROCm-specific plugin)

    • jax-rocm-pjrt (ROCm-specific runtime)

  3. Install the custom JAX wheels:

    python3 setup.py develop --user && python3 -m pip install dist/*.whl
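Before installing, you may want to confirm that the build produced all three wheels. The sketch below looks in dist/ for files matching each expected wheel. The filename patterns are assumptions (the real names depend on the build and may embed the ROCm version, for example `jax_rocm60_plugin`), so adjust them to match what you see in dist/.

```python
import glob
import os

# Assumed wheel filename patterns; the exact names depend on the build.
EXPECTED_PATTERNS = {
    "jaxlib": "jaxlib-*.whl",
    "jax-rocm-plugin": "jax_rocm*_plugin-*.whl",
    "jax-rocm-pjrt": "jax_rocm*_pjrt-*.whl",
}


def find_built_wheels(dist_dir="dist"):
    """Map each expected wheel to the matching files found in `dist_dir`."""
    return {
        name: sorted(glob.glob(os.path.join(dist_dir, pattern)))
        for name, pattern in EXPECTED_PATTERNS.items()
    }


if __name__ == "__main__":
    for name, files in find_built_wheels().items():
        status = ", ".join(os.path.basename(f) for f in files) or "MISSING"
        print(f"{name}: {status}")
```

Run the script from the root of the cloned repository so that the relative dist/ path resolves correctly.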
    

Simplified build script#

For a streamlined build process, consider using the jax/build/rocm/dev_build_rocm.py script. See rocm/jax for more information.

Testing your JAX installation with ROCm#

After launching the container, test whether JAX detects ROCm devices as expected:

python3 -c "import jax; print(jax.devices())"
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"

If the setup is successful, the first command lists all available ROCm devices and the second prints the array.

Example output on a system with four GPUs:

[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
[0 1 2 3 4]
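The one-liners above confirm that JAX imports and sees devices. A slightly fuller smoke test also compiles and runs a small computation. The sketch below is one way to do that; the function name `smoke_test` and the matrix size are arbitrary choices for illustration. It multiplies two all-ones matrices under `jax.jit`, so the sum of the result should equal n cubed.

```python
def smoke_test(n=128):
    """Run a small jitted matmul; return its sum, or None if JAX is missing."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX is not importable in this environment.")
        return None

    # Show what JAX detected, beyond the bare device list.
    print("Default backend:", jax.default_backend())
    for device in jax.devices():
        print(f"  id={device.id} platform={device.platform} kind={device.device_kind}")

    @jax.jit
    def matmul_sum(a, b):
        return jnp.sum(a @ b)

    ones = jnp.ones((n, n), dtype=jnp.float32)
    # float() waits for the computation and copies the scalar to the host.
    return float(matmul_sum(ones, ones))


if __name__ == "__main__":
    result = smoke_test()
    if result is not None:
        print("matmul sum:", result, "expected:", float(128 ** 3))
```

If the printed backend is `cpu`, JAX is working but has fallen back to the host, which usually indicates that the ROCm plugin was not picked up.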