Model acceleration libraries#

Applies to Linux


This section discusses model acceleration techniques and libraries to improve memory efficiency and performance.

Flash Attention 2#

Flash Attention is a technique designed to reduce memory movement between GPU SRAM and high-bandwidth memory (HBM). By using a tiling approach, Flash Attention 2 improves memory locality in the nested loops of query, key, and value computations within the attention modules of LLMs. These modules include Multi-Head Attention (MHA), Grouped-Query Attention (GQA), and Multi-Query Attention (MQA). The reduced memory movement significantly decreases time-to-first-token (TTFT) latency for large batch sizes and long prompt sequences, improving overall performance.

Attention module of a large language model utilizing tiling
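The tiling idea can be illustrated with a minimal, pure-Python sketch for a single query vector. It is not the CK or Triton kernel; it only shows how keys and values can be processed one tile at a time while a running maximum and running sum keep the softmax exact. The function names and the `tile_size` parameter are hypothetical.

```python
import math

def naive_attention(q, k, v):
    """Reference: softmax(q . k^T / sqrt(d)) . v, using all scores at once."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d) for key in k]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * val[j] for w, val in zip(weights, v)) / total
            for j in range(len(v[0]))]

def tiled_attention(q, k, v, tile_size=2):
    """Visit keys/values tile by tile, keeping only running statistics."""
    d = len(q)
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(v[0])
    for start in range(0, len(k), tile_size):
        k_tile = k[start:start + tile_size]
        v_tile = v[start:start + tile_size]
        scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d)
                  for key in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale the partial results accumulated so far to the new maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, val in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, val)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.1, 0.2, 0.3, 0.4]
k = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
     [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, k, v))
print(tiled_attention(q, k, v, tile_size=2))
```

Both functions should agree up to floating-point rounding; the tiled version never holds the full score vector, which is the property a GPU kernel exploits to keep working data in SRAM.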

Installing Flash Attention 2#

ROCm provides two implementations of Flash Attention 2: a Composable Kernel (CK) implementation and a Triton implementation. They can be deployed interchangeably.

To install CK Flash Attention 2, use the following commands.

# Install from source
git clone
cd flash-attention/
GPU_ARCHS=gfx942 python setup.py install #MI300 series

Hugging Face Transformers can deploy the CK Flash Attention 2 module when you pass the argument attn_implementation="flash_attention_2" to the from_pretrained method.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "NousResearch/Meta-Llama-3-8B"

# The tokenizer does not take a torch_dtype; only the model does
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
inputs = tokenizer('Today is', return_tensors='pt').to(device)

model_eager = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager").cuda(device)
model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda(device)

print("eager GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
print("ckFAv2 GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))

#  eager GQA:  Today is the day of the Lord, and we are the
# ckFAv2 GQA: Today is the day of the Lord, and we are the
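If a script may run on machines where Flash Attention 2 is not installed, the implementation can be chosen at runtime. The sketch below assumes the build installs a Python package named `flash_attn`; the helper name `pick_attn_implementation` is hypothetical.

```python
import importlib.util

def pick_attn_implementation(preferred="flash_attention_2", fallback="eager"):
    """Return `preferred` only when the flash_attn package can be found."""
    if importlib.util.find_spec("flash_attn") is not None:
        return preferred
    return fallback

attn_implementation = pick_attn_implementation()
print("attn_implementation =", attn_implementation)
# The value can then be passed on, for example:
# AutoModelForCausalLM.from_pretrained(model_name, attn_implementation=attn_implementation)
```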

The Triton Flash Attention 2 module is implemented in Python and uses the OpenAI Triton JIT compiler. This module has been upstreamed into the vLLM serving toolkit, discussed in LLM inference frameworks.

  1. To install Triton Flash Attention 2 and run the benchmark, use the following commands.

    # Install from the source
    pip uninstall pytorch-triton-rocm triton -y
    git clone
    cd triton/python
    GPU_ARCHS=gfx942 python setup.py install #MI300 series
    pip install matplotlib pandas
  2. To test, run the Triton Flash Attention 2 performance benchmark.

    # Test the triton FA v2 kernel
    # Results
        BATCH    HQ    HK  N_CTX_Q  N_CTX_K      TFLOPS
    0    16.0  16.0  16.0   1024.0   1024.0  287.528411
    1     8.0  16.0  16.0   2048.0   2048.0  287.490806
    2     4.0  16.0  16.0   4096.0   4096.0  345.966031
    3     2.0  16.0  16.0   8192.0   8192.0  361.369510
    4     1.0  16.0  16.0  16384.0  16384.0  356.873720
    5     2.0  48.0  48.0   1024.0   1024.0  216.916235
    6     2.0  48.0  48.0   2048.0   1024.0  271.027578
    7     2.0  48.0  48.0   4096.0   8192.0  337.367372
    8     2.0  48.0  48.0   8192.0   4096.0  363.481649
    9     2.0  48.0  48.0  16384.0   8192.0  375.013622
    10    8.0  16.0  16.0   1989.0  15344.0  321.791333
    11    4.0  16.0  16.0   4097.0    163.0  122.104888
    12    2.0  16.0  16.0   8122.0   2159.0  337.060283
    13    1.0  16.0  16.0  16281.0      7.0    5.234012
    14    2.0  48.0  48.0   1021.0   1020.0  214.657425
    15    2.0  48.0  48.0   2001.0   2048.0  314.429118
    16    2.0  48.0  48.0   3996.0   9639.0  330.411368
    17    2.0  48.0  48.0   8181.0   1021.0  324.614980
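To relate the TFLOPS column to the problem sizes, one common way of counting attention FLOPs can be sketched as follows. This is an illustration under stated assumptions: the head dimension (128 here) is not listed in the table, the function names are hypothetical, and the benchmark script may count operations differently, for example when masking is applied.

```python
def attention_forward_flops(batch, heads, n_ctx_q, n_ctx_k, head_dim):
    """FLOPs for Q.K^T plus P.V, counting one multiply-add as 2 FLOPs."""
    flops_per_matmul = 2 * batch * heads * n_ctx_q * n_ctx_k * head_dim
    return 2 * flops_per_matmul

def tflops(flops, seconds):
    """Convert a FLOP count and a measured runtime into TFLOPS."""
    return flops / seconds / 1e12

# Shape of the first row of the table, with an assumed head dimension of 128
flops = attention_forward_flops(batch=16, heads=16, n_ctx_q=1024, n_ctx_k=1024, head_dim=128)
print(flops)
```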


xFormers#

xFormers also improves the performance of attention modules. xFormers attention performs very similarly to Flash Attention 2 because it also tiles the query, key, and value computations. It is widely used for LLMs and for Stable Diffusion models with the Hugging Face Diffusers library.

Installing CK xFormers#

Use the following commands to install CK xFormers.

# Install from source
git clone
cd xformers/
git submodule update --init --recursive
PYTORCH_ROCM_ARCH=gfx942 python setup.py install #Instinct MI300-series

PyTorch built-in acceleration#

PyTorch compilation mode synthesizes the model into a graph and then lowers it to primitive operators. These operators are compiled using TorchInductor, which uses OpenAI Triton as a building block for GPU acceleration. One advantage of PyTorch compilation mode is that its GPU kernels are written in Python, which makes them easier to modify and extend. PyTorch compilation mode often delivers higher performance because model operations are fused before runtime, which also makes high-performance kernels easy to deploy.
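Operator fusion can be pictured with a toy, pure-Python comparison. This is only a conceptual sketch, not what TorchInductor generates: the unfused version writes an intermediate result after every operation, while the fused version computes the same values in a single pass. The function names are hypothetical.

```python
import math

def unfused(xs):
    """Three separate passes, each materializing an intermediate list."""
    scaled = [2.0 * x for x in xs]
    shifted = [s + 1.0 for s in scaled]
    return [math.tanh(t) for t in shifted]

def fused(xs):
    """One pass over the data; no intermediate lists are written out."""
    return [math.tanh(2.0 * x + 1.0) for x in xs]

data = [-1.0, -0.5, 0.0, 0.5, 1.0]
print(unfused(data) == fused(data))
```

On a GPU, each intermediate list corresponds to a round trip through device memory, which is why fusing elementwise operations tends to help memory-bound workloads.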

PyTorch compilation#

To use PyTorch compilation mode, specific layers of the model must be explicitly assigned as compilation targets. For LLMs, autoregressive token decoding produces key/value tensors whose sizes change at every step, so the key/value size must be limited to a static dimension, max_cache_length, to obtain the performance benefits of PyTorch compilation.

# Sample script to run LLM with the static key-value cache and PyTorch compilation
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
import os
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_name = "NousResearch/Meta-Llama-3-8B"
prompts = []

for b in range(1):
    prompts.append("New york city is where ")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device).eval()
inputs = tokenizer(prompts, return_tensors="pt").to(model.device)

def decode_one_tokens(model, cur_token, input_pos, cache_position):
    logits = model(cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

batch_size, seq_length = inputs["input_ids"].shape

# Static key-value cache
max_cache_length = 1024
max_new_tokens = 10
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
cache_position = torch.arange(seq_length, device=device)
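# Token IDs are integers, so allocate the output buffer with an integer dtype
generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=device)
generated_ids[:, cache_position] = inputs["input_ids"].to(device).to(torch.int)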

logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

# torch compilation
decode_one_tokens = torch.compile(decode_one_tokens, mode="max-autotune-no-cudagraphs",fullgraph=True)

generated_ids[:, seq_length] = next_token[:, 0]
cache_position = torch.tensor([seq_length + 1], device=device)

with torch.no_grad():
    for _ in range(1, max_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            generated_ids[:, cache_position] = next_token.int()
        cache_position += 1
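The reason a static cache helps compilation can be shown with a toy, pure-Python sketch. It does not use the Transformers `StaticCache` class; `ToyStaticCache` is a hypothetical stand-in that preallocates fixed-size buffers and writes each new entry in place, so the buffer shape never changes between decoding steps.

```python
class ToyStaticCache:
    """Fixed-size key/value buffers that are written in place by position."""

    def __init__(self, max_cache_len, head_dim):
        self.max_cache_len = max_cache_len
        self.keys = [[0.0] * head_dim for _ in range(max_cache_len)]
        self.values = [[0.0] * head_dim for _ in range(max_cache_len)]

    def update(self, cache_position, key, value):
        if not 0 <= cache_position < self.max_cache_len:
            raise IndexError("cache_position exceeds max_cache_len")
        self.keys[cache_position] = list(key)
        self.values[cache_position] = list(value)

    def shape(self):
        return (len(self.keys), len(self.keys[0]))

cache = ToyStaticCache(max_cache_len=8, head_dim=4)
shapes = set()
for position in range(5):
    cache.update(position, [float(position)] * 4, [1.0] * 4)
    shapes.add(cache.shape())
print(shapes)  # a single shape is observed across all decoding steps
```

With a dynamically growing cache, each step would present a new tensor shape to the compiler; with a fixed shape, a compiled graph can be reused across steps.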

PyTorch TunableOp#

ROCm PyTorch (2.2.0 and later) allows users to use high-performance ROCm GEMM kernel libraries through PyTorch’s built-in TunableOp options. This enables users to automatically pick up the best-performing GEMM kernels from rocBLAS and hipBLASLt libraries during runtime.

During warm-up runs or offline profiling steps, users can create a GEMM table that enumerates the kernel information. During the model’s run, calls to torch.nn.functional.linear(input, weight, bias=None) are served by the best-performing kernel specified in the GEMM table. The TunableOp GitHub page describes the options.

# To turn on TunableOp, simply set this environment variable
export PYTORCH_TUNABLEOP_ENABLED=1

# Python
import torch
import torch.nn as nn
import torch.nn.functional as F
A = torch.rand(100, 20, device="cuda")
W = torch.rand(200, 20, device="cuda")
Out = F.linear(A, W)

# Tuning results are recorded in tunableop_results0.csv
GEMM and TunableOp

Learn more about optimizing kernels with TunableOp in Optimizing Triton kernels.
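The GEMM table idea can be illustrated with a toy, pure-Python sketch. It is not how TunableOp is implemented and does not call rocBLAS or hipBLASLt; it only shows the pattern of timing candidate kernels once per problem shape and reusing the winner afterwards. All names here are hypothetical.

```python
import time

def matmul_rowwise(a, b):
    """Candidate 1: dot products of rows of a with columns of b."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def matmul_indexed(a, b):
    """Candidate 2: explicit triple loop with an accumulation buffer."""
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for p in range(k):
            a_ip = a[i][p]
            for j in range(m):
                out[i][j] += a_ip * b[p][j]
    return out

CANDIDATES = {"rowwise": matmul_rowwise, "indexed": matmul_indexed}
gemm_table = {}  # (n, k, m) -> name of the fastest candidate seen during tuning

def tuned_matmul(a, b):
    shape = (len(a), len(b), len(b[0]))
    if shape not in gemm_table:
        # First call for this shape: time every candidate and record the best
        timings = {}
        for name, kernel in CANDIDATES.items():
            start = time.perf_counter()
            kernel(a, b)
            timings[name] = time.perf_counter() - start
        gemm_table[shape] = min(timings, key=timings.get)
    return CANDIDATES[gemm_table[shape]](a, b)

result = tuned_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(result, gemm_table)
```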


FBGEMM and FBGEMM_GPU#

FBGEMM (Facebook General Matrix Multiplication) is a low-precision, high-performance CPU kernel library for matrix-matrix multiplications and convolutions. It is used for server-side inference and as a back end for PyTorch quantized operators. FBGEMM offers optimized on-CPU performance for reduced precision calculations, strong performance on native tensor formats, and the ability to generate high-performance shape- and size-specific kernels at runtime.
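What a reduced-precision calculation means can be sketched in plain Python. The example below uses symmetric 8-bit quantization with one scale per vector and integer accumulation. It is an assumption chosen for illustration, not FBGEMM's actual scheme or API, and the function names are hypothetical.

```python
def quantize(values, num_bits=8):
    """Symmetric quantization: map floats to signed integers with one scale."""
    qmax = 2 ** (num_bits - 1) - 1
    max_abs = max(abs(v) for v in values) or 1.0
    scale = max_abs / qmax
    return [round(v / scale) for v in values], scale

def quantized_dot(x, w):
    """Dot product computed on integers, then rescaled back to floating point."""
    xq, x_scale = quantize(x)
    wq, w_scale = quantize(w)
    acc = sum(a * b for a, b in zip(xq, wq))  # integer accumulation
    return acc * x_scale * w_scale

x = [0.5, -1.0, 0.25]
w = [1.0, 0.5, -0.75]
exact = sum(a * b for a, b in zip(x, w))
print(exact, quantized_dot(x, w))
```

The quantized result is close to, but generally not identical to, the floating-point result; the benefit is that the inner loop runs on small integers.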

FBGEMM_GPU collects several high-performance PyTorch GPU operator libraries for use in training and inference. It provides efficient table-batched embedding functionality, data layout transformation, and quantization support.
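Table-batched embedding lookups can be pictured with a small, pure-Python sketch. It is not the FBGEMM_GPU operator or its API; it only shows the idea of serving several embedding tables in one call, with each table's indices grouped into bags by an offsets list and the rows of each bag sum-pooled. The function name and argument layout are assumptions.

```python
def table_batched_embedding_sum(tables, indices, offsets):
    """Sum-pool embedding rows for every (table, bag) pair in one call.

    tables:  list of embedding tables, each a list of row vectors
    indices: per table, a flat list of row ids
    offsets: per table, bag boundaries into `indices` (length = bags + 1)
    """
    output = []
    for table, idx, off in zip(tables, indices, offsets):
        dim = len(table[0])
        pooled_bags = []
        for bag in range(len(off) - 1):
            pooled = [0.0] * dim
            for row_id in idx[off[bag]:off[bag + 1]]:
                pooled = [p + x for p, x in zip(pooled, table[row_id])]
            pooled_bags.append(pooled)
        output.append(pooled_bags)
    return output

tables = [
    [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],  # table 0: 3 rows, dimension 2
    [[10.0], [20.0]],                      # table 1: 2 rows, dimension 1
]
indices = [[0, 2, 1], [1, 1, 0]]
offsets = [[0, 2, 3], [0, 1, 3]]
pooled = table_batched_embedding_sum(tables, indices, offsets)
print(pooled)
```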

For more information about FBGEMM and FBGEMM_GPU, see the PyTorch FBGEMM GitHub and the PyTorch FBGEMM documentation. The Meta blog post about FBGEMM provides additional background about the library.

Installing FBGEMM_GPU#

Installing FBGEMM_GPU consists of the following steps:

  • Set up an isolated Miniconda environment

  • Install ROCm using Docker or the package manager

  • Install the nightly PyTorch build

  • Complete the pre-build and build tasks


FBGEMM_GPU doesn’t require the installation of FBGEMM. To optionally install FBGEMM, see the FBGEMM install instructions.

Set up the Miniconda environment#

To install Miniconda, use the following commands.

  1. Install a Miniconda environment for reproducible builds. All subsequent commands run inside this environment.

    export PLATFORM_NAME="$(uname -s)-$(uname -m)"
    # Set the Miniconda prefix directory
    # Download the Miniconda installer
    wget -q "${PLATFORM_NAME}.sh" -O
    # Run the installer
    bash -b -p "$miniconda_prefix" -u
    # Load the shortcuts
    . ~/.bashrc
    # Run updates
    conda update -n base -c defaults -y conda
  2. Create a Miniconda environment with Python 3.12:

    env_name=<ENV NAME>
    python_version=3.12
    # Create the environment
    conda create -y --name ${env_name} python="${python_version}"
    # Upgrade PIP and pyOpenSSL package
    conda run -n ${env_name} pip install --upgrade pip
    # Quote the requirement so the shell does not treat ">" as a redirection
    conda run -n ${env_name} python -m pip install "pyOpenSSL>22.1.0"
  3. Install additional build tools:

    conda install -n ${env_name} -y \
       click \
       cmake \
       hypothesis \
       jinja2 \
       make \
       ncurses \
       ninja \
       numpy \
       scikit-build

Install the ROCm components#

FBGEMM_GPU can run in a ROCm Docker container or in conjunction with the full ROCm installation. The Docker method is recommended because it requires fewer steps and provides a stable environment.

To run FBGEMM_GPU in the Docker container, pull the Minimal Docker image for ROCm. This image includes all preinstalled ROCm packages required to integrate FBGEMM. To pull and run the ROCm Docker image, use this command:

# Run for ROCm 6.2.0
docker run -it --network=host --shm-size 16G --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host rocm/rocm-terminal:6.2 /bin/bash


The Full Docker image for ROCm, which includes all ROCm packages, can also be used. However, it results in a very large container, so the minimal Docker image is recommended.

You can also install ROCm using the package manager. FBGEMM_GPU requires the installation of the full ROCm package. For more information, see the ROCm installation guide. The ROCm package also requires the MIOpen component as a dependency. To install MIOpen, use the apt install command.

apt install hipify-clang miopen-hip miopen-hip-dev

Install PyTorch#

Install PyTorch using pip for the most reliable and consistent results.

  1. Install the nightly PyTorch build using pip.

    # Install the latest nightly, ROCm variant
    conda run -n ${env_name} pip install --pre torch --index-url
  2. Ensure PyTorch loads correctly. Verify the version and variant of the installation using an import test.

    # Ensure that the package loads properly
    conda run -n ${env_name} python -c "import torch.distributed"
    # Verify the version and variant of the installation
    conda run -n ${env_name} python -c "import torch; print(torch.__version__)"
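The same check can be written as a small script that also reports whether the installed build is a ROCm variant. This is a sketch: the helper name is hypothetical, and it relies on `torch.version.hip` being a version string on ROCm builds and `None` on other builds.

```python
import importlib.util

def describe_torch_install():
    """Report whether torch is importable and whether it is a ROCm (HIP) build."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False, "version": None, "hip": None}
    import torch
    return {
        "installed": True,
        "version": torch.__version__,
        # A version string on ROCm builds, None on CPU-only or CUDA builds
        "hip": getattr(torch.version, "hip", None),
    }

info = describe_torch_install()
print(info)
```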

Perform the prebuild and build#

  1. Clone the FBGEMM repository and the relevant submodules. Use pip to install the components in requirements.txt. Run the following commands inside the Miniconda environment.

    # Select a version tag
    FBGEMM_VERSION=v0.8.0
    # Clone the repo along with its submodules
    git clone --branch=v0.8.0 --recursive fbgemm_${FBGEMM_VERSION}
    # Install additional required packages for building and testing
    cd fbgemm_${FBGEMM_VERSION}/fbgemm_gpu
    pip install -r requirements.txt
  2. Clear the build cache to remove stale build information.

    # !! Run in fbgemm_gpu/ directory inside the Conda environment !!
    python setup.py clean
  3. Set the wheel build variables, including the package name, Python version tag, and Python platform name.

    # Set the package name depending on the build variant
    export package_name=fbgemm_gpu_rocm
    # Set the Python version tag.  It should follow the convention `py<major><minor>`,
    # for example, Python 3.12 --> py312
    export python_tag=py312
    # Determine the processor architecture
    export ARCH=$(uname -m)
    # Set the Python platform name for the Linux case
    export python_plat_name="manylinux2014_${ARCH}"
  4. Build FBGEMM_GPU for the ROCm platform. Set ROCM_PATH to the path to your ROCm installation. Run these commands from the fbgemm_gpu/ directory inside the Miniconda environment.

    # !! Run in the fbgemm_gpu/ directory inside the Conda environment !!
    export ROCM_PATH=</path/to/rocm>
    # Build for the target architecture of the ROCm device installed on the machine (for example, 'gfx942;gfx90a')
    # See the Linux system requirements for a list of supported GPUs.
    export PYTORCH_ROCM_ARCH=$(${ROCM_PATH}/bin/rocminfo | grep -o -m 1 'gfx.*')
    # Build the wheel artifact only
    python setup.py bdist_wheel \
       --package_variant=rocm \
       --python-tag="${python_tag}" \
       --plat-name="${python_plat_name}" \
       -DHIP_ROOT_DIR="${ROCM_PATH}"
    # Build and install the library into the Conda environment
    python setup.py install \
       --package_variant=rocm \
       -DHIP_ROOT_DIR="${ROCM_PATH}"
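The build variables used in the steps above can also be derived programmatically. The following sketch mirrors the shell logic in Python; the helper names are hypothetical, and the `sample` text is an invented stand-in for `rocminfo` output, not a captured log.

```python
import platform
import re
import sys

def wheel_build_variables():
    """Mirror the shell variables from the wheel-build step."""
    arch = platform.machine()
    return {
        "package_name": "fbgemm_gpu_rocm",
        "python_tag": f"py{sys.version_info.major}{sys.version_info.minor}",
        "python_plat_name": f"manylinux2014_{arch}",
    }

def gfx_targets(rocminfo_output):
    """Collect unique gfx architecture names in first-seen order, joined by ';'."""
    seen = []
    for name in re.findall(r"\bgfx[0-9a-f]+\b", rocminfo_output):
        if name not in seen:
            seen.append(name)
    return ";".join(seen)

# Hypothetical excerpt in the style of rocminfo output
sample = """
  Name:                    gfx942
  Name:                    amdgcn-amd-amdhsa--gfx942:sramecc+:xnack-
  Name:                    gfx90a
"""
print(wheel_build_variables())
print(gfx_targets(sample))
```

Unlike the `grep -o -m 1` one-liner, this helper returns every distinct architecture it finds, which may or may not be what a given build needs.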

Post-build validation#

After building FBGEMM_GPU, run some verification checks to ensure the build is correct. Continue to run all commands inside the fbgemm_gpu/ directory inside the Miniconda environment.

  1. The build process generates many build artifacts and C++ templates, so it is important to confirm no undefined symbols remain.

    # !! Run in fbgemm_gpu/ directory inside the Conda environment !!
    # Locate the built .SO file
    fbgemm_gpu_lib_path=$(find . -name
    # Check that the undefined symbols don't include fbgemm_gpu-defined functions
    nm -gDCu "${fbgemm_gpu_lib_path}" | sort
  2. Verify the referenced version number of GLIBCXX and the presence of certain function symbols:

    # !! Run in fbgemm_gpu/ directory inside the Conda environment !!
    # Locate the built .SO file
    fbgemm_gpu_lib_path=$(find . -name
    # Note the versions of GLIBCXX referenced by the .SO
    # The libstdc++ available on the install target must support these versions
    objdump -TC "${fbgemm_gpu_lib_path}" | grep GLIBCXX | sed 's/.*GLIBCXX_\([.0-9]*\).*/GLIBCXX_\1/g' | sort -Vu | cat
    # Test for the existence of a given function symbol in the .SO
    nm -gDC "${fbgemm_gpu_lib_path}" | grep " fbgemm_gpu::merge_pooled_embeddings("
    nm -gDC "${fbgemm_gpu_lib_path}" | grep " fbgemm_gpu::jagged_2d_to_dense("
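The two checks above can be automated by post-processing the tool output. The sketch below works on text that has already been captured from `nm` and `objdump`; the helper names and the sample strings are hypothetical and only illustrate the filtering logic.

```python
import re

def leaked_symbols(nm_undefined_output, namespace="fbgemm_gpu::"):
    """Return undefined-symbol lines that mention the library's own namespace."""
    return [line.strip() for line in nm_undefined_output.splitlines()
            if namespace in line]

def glibcxx_versions(objdump_output):
    """Return the unique GLIBCXX versions referenced, sorted numerically."""
    found = set(re.findall(r"GLIBCXX_(\d+(?:\.\d+)*)", objdump_output))
    return sorted(found, key=lambda v: tuple(int(part) for part in v.split(".")))

# Invented sample text in the style of `nm -gDCu` and `objdump -TC` output
nm_sample = """
                 U fbgemm_gpu::missing_fn(int)
                 U std::terminate()
"""
objdump_sample = "GLIBCXX_3.4.9 foo\nGLIBCXX_3.4.21 bar\nGLIBCXX_3.4 baz\nGLIBCXX_3.4.9 qux"

print(leaked_symbols(nm_sample))
print(glibcxx_versions(objdump_sample))
```

An empty list from `leaked_symbols` corresponds to the desired outcome of the first check: no undefined symbols defined by the library itself.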

Testing FBGEMM#

FBGEMM includes tests and benchmarks to validate performance. To run these tests, you must use ROCm 5.7 or a more recent version on the host and container. To run FBGEMM tests, follow these instructions:

# !! Run inside the Conda environment !!

# From the /fbgemm_gpu/ directory
cd test

# Enable for debugging failed kernel executions

# Run the test
python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning

To run the FBGEMM_GPU UVM tests, use these commands. These tests support only the AMD MI210 and more recent accelerators.

# Run this inside the Conda environment from the /fbgemm_gpu/ directory
export HSA_XNACK=1
cd test

python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning ./uvm/