# Training a model with PyTorch for ROCm
2025-02-21
PyTorch is an open-source machine learning framework that is widely used for model training, with GPU-optimized components for transformer-based models.

The PyTorch for ROCm training Docker image (`rocm/pytorch-training:v25.3`) provides a prebuilt, optimized environment for fine-tuning and pretraining a model on AMD Instinct MI325X and MI300X accelerators. It includes the following software components to accelerate training workloads:
| Software component | Version |
|---|---|
| ROCm | 6.3.0 |
| PyTorch | 2.7.0a0+git637433 |
| Python | 3.10 |
| Transformer Engine | 1.11 |
| Flash Attention | 3.0.0 |
| hipBLASLt | git258a2162 |
| Triton | 3.1 |
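If you want to confirm these versions in the container you pull, a quick check like the following can help. This is an optional sketch, not part of the documented workflow, and it assumes the container's default Python environment and standard pip metadata.

```bash
# Optional: print the PyTorch build and a few key libraries from inside the container.
python3 -c "import torch; print('PyTorch:', torch.__version__, '| HIP:', torch.version.hip)"
pip list 2>/dev/null | grep -iE 'transformer[-_]engine|flash[-_]attn|triton'
cat /opt/rocm/.info/version   # ROCm version file, if present in the image
```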
## Supported models
The following models are pre-optimized for performance on the AMD Instinct MI300X accelerator.
- Llama 3.1 8B
- Llama 3.1 70B
- FLUX.1-dev
> **Note**
> Only these models are supported in the following steps. Some models, such as Llama 3, require an external license agreement through a third party (for example, Meta).
## System validation
If you have already validated your system settings, skip this step. Otherwise, complete the system validation and optimization steps to set up your system before starting training.
### Disable NUMA auto-balancing
Generally, application performance can benefit from disabling NUMA auto-balancing. However, it might be detrimental to performance with certain types of workloads.
Run `cat /proc/sys/kernel/numa_balancing` to check your current NUMA (Non-Uniform Memory Access) setting. An output of `0` indicates that NUMA auto-balancing is disabled. If there is no output or the output is `1`, run the following command to disable NUMA auto-balancing.

```bash
sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'
```
See Disable NUMA auto-balancing for more information.
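If you script your system setup, the check and the disable step can be combined. The following is a minimal sketch; it assumes root privileges, and the setting does not persist across reboots.

```bash
# Disable NUMA auto-balancing only if it is currently enabled.
# This change does not persist across reboots.
if [ "$(cat /proc/sys/kernel/numa_balancing)" != "0" ]; then
  sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'
fi
cat /proc/sys/kernel/numa_balancing   # expected output: 0
```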
## Environment setup
This Docker image is optimized for specific model configurations outlined below. Performance can vary for other training workloads, as AMD doesn’t validate configurations and run conditions outside those described.
### Download the Docker image
Use the following command to pull the Docker image from Docker Hub.

```bash
docker pull rocm/pytorch-training:v25.3
```
Run the Docker container.

```bash
docker run -it \
  --device /dev/dri \
  --device /dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  --shm-size 64G \
  --name training_env \
  rocm/pytorch-training:v25.3
```
Use these commands if you exit the `training_env` container and need to return to it.

```bash
docker start training_env
docker exec -it training_env bash
```
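Before running any benchmarks, you can optionally confirm that the container sees the accelerators. This is a quick sanity-check sketch using standard ROCm and PyTorch tools, not a required step.

```bash
# Optional: verify that the GPUs are visible from inside the container.
rocm-smi                                   # lists the Instinct devices and their status
python3 -c "import torch; print('GPUs visible to PyTorch:', torch.cuda.device_count())"
```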
In the Docker container, clone the ROCm/MAD repository and navigate to the benchmark scripts directory.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/pytorch-train
```
### Prepare training datasets and dependencies
The following benchmarking examples may require downloading models and datasets from Hugging Face. To ensure successful access to gated repos, set your `HF_TOKEN`.
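For example, you can export the token and log in before running the setup script. The token value below is a placeholder; replace it with your own Hugging Face access token.

```bash
# Make your Hugging Face token available for gated model and dataset downloads.
export HF_TOKEN=<your_hugging_face_token>     # placeholder, use your own token
huggingface-cli login --token "$HF_TOKEN"     # caches the credential for subsequent downloads
```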
Run the setup script to install the libraries and datasets needed for benchmarking.

```bash
./pytorch_benchmark_setup.sh
```

`pytorch_benchmark_setup.sh` installs the following libraries:
| Library | Benchmark model |
|---|---|
| | Llama 3.1 8B, FLUX |
| Hugging Face Datasets 3.2.0 | Llama 3.1 8B, 70B, FLUX |
| | Llama 3.1 70B |
| | Llama 3.1 70B |
| | Llama 3.1 70B |
| | Llama 3.1 70B |
| | Llama 3.1 70B |
| | Llama 3.1 70B |
| SentencePiece 0.2.0 | Llama 3.1 70B, FLUX |
| TensorBoard 2.18.0 | Llama 3.1 70B, FLUX |
| csvkit 2.0.1 | FLUX |
| DeepSpeed 0.16.2 | FLUX |
| Hugging Face Diffusers 0.31.0 | FLUX |
| GitPython 3.1.44 | FLUX |
| opencv-python-headless 4.10.0.84 | FLUX |
| PEFT 0.14.0 | FLUX |
| Protocol Buffers 5.29.2 | FLUX |
| PyTest 8.3.4 | FLUX |
| python-dotenv 1.0.1 | FLUX |
| Seaborn 0.13.2 | FLUX |
| Transformers 4.47.0 | FLUX |
`pytorch_benchmark_setup.sh` also downloads the models and datasets required for these benchmarks from Hugging Face.
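To spot-check that the setup completed, you can query a few of the packages listed above. This optional sketch only uses standard pip commands.

```bash
# Optional: confirm a few of the benchmark dependencies installed by the setup script.
pip show transformers deepspeed peft diffusers | grep -E '^(Name|Version):'
```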
## Start training on AMD Instinct accelerators
The prebuilt PyTorch with ROCm training environment lets you quickly validate system performance, run training benchmarks, and achieve optimized performance for supported models such as Llama 3.1 8B, Llama 3.1 70B, and FLUX.1-dev. This container should not be expected to provide generalized performance across all training workloads. You can expect the container to perform well in the model configurations described in the following section, but other configurations are not validated by AMD.
Use the following instructions to set up the environment, configure the script to train models, and reproduce the benchmark results on MI300X series accelerators with the AMD PyTorch training Docker image.
Once your environment is set up, use the following commands and examples to start benchmarking.
### Pretraining
To start the pretraining benchmark, use the following command with the appropriate options. See the table below for the available options and their descriptions.

```bash
./pytorch_benchmark_report.sh -t $training_mode -m $model_repo -p $datatype -s $sequence_length
```
#### Options and available models
| Name | Options | Description |
|---|---|---|
| `$training_mode` | `pretrain` | Benchmark pretraining |
| | `finetune_fw` | Benchmark full weight fine-tuning (Llama 3.1 70B with BF16) |
| | `finetune_lora` | Benchmark LoRA fine-tuning (Llama 3.1 70B with BF16) |
| `$datatype` | `FP8` or `BF16` | Only Llama 3.1 8B supports FP8 precision. |
| `$model_repo` | `Llama-3.1-8B` | Llama 3.1 8B |
| | `Llama-3.1-70B` | Llama 3.1 70B |
| | `Flux` | FLUX.1-dev |
### Fine-tuning
To start the fine-tuning benchmark, use the following command. It runs the benchmarking example for Llama 3.1 70B with the WikiText dataset using the AMD fork of torchtune.

```bash
./pytorch_benchmark_report.sh -t {finetune_fw, finetune_lora} -p BF16 -m Llama-3.1-70B
```
### Benchmarking examples
Here are some examples of how to use the command.
Example 1: Llama 3.1 70B with BF16 precision using torchtitan.

```bash
./pytorch_benchmark_report.sh -t pretrain -p BF16 -m Llama-3.1-70B -s 8192
```

Example 2: Llama 3.1 8B with FP8 precision using Transformer Engine (TE) and Hugging Face Accelerate.

```bash
./pytorch_benchmark_report.sh -t pretrain -p FP8 -m Llama-3.1-8B -s 8192
```

Example 3: FLUX.1-dev with BF16 precision using FluxBenchmark.

```bash
./pytorch_benchmark_report.sh -t pretrain -p BF16 -m Flux
```

Example 4: Torchtune full weight fine-tuning with Llama 3.1 70B.

```bash
./pytorch_benchmark_report.sh -t finetune_fw -p BF16 -m Llama-3.1-70B
```

Example 5: Torchtune LoRA fine-tuning with Llama 3.1 70B.

```bash
./pytorch_benchmark_report.sh -t finetune_lora -p BF16 -m Llama-3.1-70B
```
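To run several of these configurations back to back, a simple wrapper script works. This is a sketch that only reuses the documented script and flags; where the results are written depends on how `pytorch_benchmark_report.sh` reports them.

```bash
#!/usr/bin/env bash
# Sketch: run a set of the documented benchmark configurations sequentially.
set -e
./pytorch_benchmark_report.sh -t pretrain      -p BF16 -m Llama-3.1-70B -s 8192
./pytorch_benchmark_report.sh -t pretrain      -p FP8  -m Llama-3.1-8B  -s 8192
./pytorch_benchmark_report.sh -t finetune_fw   -p BF16 -m Llama-3.1-70B
./pytorch_benchmark_report.sh -t finetune_lora -p BF16 -m Llama-3.1-70B
```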