A comprehensive, extensible framework for profiling and benchmarking JAX operations on TPUs and other hardware accelerators.
The accelerator_microbenchmarks package provides a structured way to measure
the performance (latency, throughput, memory bandwidth) of various JAX
primitives and composite operations. It includes built-in benchmarks for:
- Compute Operations: Generalized GEMMs, Matrix Multiplications, Attention mechanisms.
- Collective Communications: psum, all_gather, all_to_all, reduce_scatter (using shard_map).
- Memory Bandwidth: HBM bandwidth profiling.
The framework is highly configurable via YAML files, allowing users to define parameter sweeps, warm-up iterations, and matrix shapes without modifying Python code.
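To make the reported metrics concrete, here is a minimal sketch of how throughput and bandwidth can be derived from a measured latency. The helper names `gemm_tflops` and `bandwidth_gbps` are hypothetical and are not taken from the package. The only facts assumed are the standard 2·M·N·K FLOP count for a dense matmul and bandwidth as bytes moved divided by elapsed time.

```python
def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    """Achieved TFLOP/s for an (m x k) @ (k x n) matmul, counted as 2*m*n*k FLOPs."""
    return 2 * m * n * k / seconds / 1e12


def bandwidth_gbps(bytes_moved: int, seconds: float) -> float:
    """Achieved bandwidth in GB/s, using 1 GB = 1e9 bytes."""
    return bytes_moved / seconds / 1e9


# Illustrative numbers only, not measurements:
# a 4096 x 4096 x 4096 matmul that takes 1 ms.
tflops = gemm_tflops(4096, 4096, 4096, 1e-3)

# Copying a 1 GiB buffer (read + write = 2 GiB of traffic) in 2 ms.
gbps = bandwidth_gbps(2 * 2**30, 2e-3)

print(f"{tflops:.1f} TFLOP/s, {gbps:.1f} GB/s")
```

Whether a copy counts as one or two buffers of traffic is a convention, so state it next to any bandwidth figure.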
accelerator_microbenchmarks/
├── configs/                # YAML configuration files (e.g., sample.yaml, hbm_sweep.yaml)
├── docs/                   # Documentation (DESIGN, DEVELOPERS, RATIONALE)
│   ├── DESIGN.md
│   ├── DEVELOPERS.md
│   └── RATIONALE.md
├── pyproject.toml
├── results/                # Output directory for benchmark metrics (JSON, CSV); can be created if missing
├── src/
│   └── accelerator_microbenchmarks/
│       ├── benchmarks/     # Concrete benchmark implementations (collectives, matmul, etc.)
│       ├── core/           # Framework core (BaseBenchmark, registry, config parsing)
│       └── main.py         # Entry point for running benchmarks
└── README.md
- Configuration: A YAML file (e.g., configs/sample.yaml) defines global settings (number of runs, warmup tries) and a list of benchmarks to execute. It supports parameter "sweeps" to automatically test a range of dimensions or mesh shapes.
- Registry: The main.py runner parses the YAML and looks up the requested benchmark names in a central registry.
- Execution: For each configuration permutation, the framework instantiates the benchmark, calls its setup(), runs warmup_tries iterations, and then executes num_runs iterations while capturing precise timing metrics.
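The registry lookup and the warm-up/timed-run loop described above can be sketched in plain Python. This is an illustration under stated assumptions, not the package's actual code: the `register` decorator, the `run()` method, the `run_benchmark` function, and the toy `sum_of_squares` benchmark are all hypothetical. Only the names `BaseBenchmark`, `setup()`, `warmup_tries`, and `num_runs` come from the description.

```python
import statistics
import time
from typing import Callable, Dict, List

# Hypothetical central registry: benchmark name (as used in YAML) -> class.
REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Class decorator that records a benchmark under `name`."""
    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls
    return decorator


class BaseBenchmark:
    """Assumed interface: setup() once, then run() per iteration."""

    def __init__(self, **params):
        self.params = params

    def setup(self) -> None:
        pass

    def run(self) -> None:
        raise NotImplementedError


@register("sum_of_squares")  # toy CPU workload standing in for a JAX op
class SumOfSquares(BaseBenchmark):
    def setup(self) -> None:
        self.data = list(range(self.params["size"]))

    def run(self) -> None:
        self.result = sum(x * x for x in self.data)


def run_benchmark(name: str, params: dict, warmup_tries: int, num_runs: int) -> dict:
    bench = REGISTRY[name](**params)  # registry lookup + instantiation
    bench.setup()
    for _ in range(warmup_tries):     # untimed warm-up iterations
        bench.run()
    times: List[float] = []
    for _ in range(num_runs):         # timed iterations
        start = time.perf_counter()
        bench.run()
        # With JAX, the output should be synchronized (block_until_ready)
        # before stopping the clock, because dispatch is asynchronous.
        times.append(time.perf_counter() - start)
    return {
        "name": name,
        "params": params,
        "num_runs": len(times),
        "mean_s": statistics.mean(times),
        "min_s": min(times),
    }


result = run_benchmark("sum_of_squares", {"size": 1024}, warmup_tries=2, num_runs=5)
print(result)
```

Keeping the warm-up iterations outside the timed loop keeps one-time costs, such as JIT compilation in JAX, out of the reported numbers.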
You can install the package locally via pip. It is recommended to do this in a
dedicated virtual environment:
pip install .
For editable mode (useful when developing custom benchmarks):
pip install -e .
If you are on a machine with available accelerators or want to test functionality on CPU, you can run the binary directly via Bazel:
bazel run //src/accelerator_microbenchmarks:main -- \
  --config configs/sample.yaml
To add a new benchmark, please refer to the detailed instructions in DEVELOPERS.md.
The YAML configuration supports discrete values and sweep definitions.
global:
  warmup_tries: 2
  num_runs: 5
  dtype: "bfloat16"
benchmarks:
  # 1. Fixed parameters
  - name: my_custom_op
    size: 2048
  # 2. List sweep
  - name: my_custom_op
    sweep:
      size: [1024, 2048, 4096]
  # 3. Geometric/Range sweep
  - name: hbm_bandwidth
    sweep:
      size:
        start: 1024
        end: 8192
        multiplier: 2 # Will test 1024, 2048, 4096, 8192
By default, the benchmark runner aggregates results and writes them to the
results/ directory as detailed.json and summary.csv.
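The three sweep forms in the configuration above (fixed value, list, geometric range) can be expanded into concrete parameter sets roughly as follows. This is a sketch of one plausible expansion, not the framework's parser: `expand_axis` and `expand_benchmark` are hypothetical names. It treats `end` as inclusive, matching the "Will test 1024, 2048, 4096, 8192" comment, and assumes multiple swept parameters are combined as a Cartesian product.

```python
import itertools
from typing import Any, Dict, Iterator, List


def expand_axis(spec: Any) -> List[Any]:
    """Turn one sweep entry into an explicit list of values."""
    if isinstance(spec, dict):  # geometric range: start, end, multiplier
        start, end, mult = spec["start"], spec["end"], spec["multiplier"]
        if start <= 0 or mult <= 1:
            raise ValueError("geometric sweep needs start > 0 and multiplier > 1")
        values = []
        value = start
        while value <= end:  # `end` is inclusive
            values.append(value)
            value *= mult
        return values
    if isinstance(spec, list):  # list sweep: use the values as given
        return list(spec)
    return [spec]  # single value


def expand_benchmark(entry: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    """Yield one flat parameter dict per permutation of the swept axes."""
    fixed = {k: v for k, v in entry.items() if k not in ("name", "sweep")}
    axes = {k: expand_axis(v) for k, v in entry.get("sweep", {}).items()}
    keys = list(axes)
    # With no sweep, product() yields a single empty combination,
    # so a fixed-parameter entry produces exactly one configuration.
    for combo in itertools.product(*(axes[k] for k in keys)):
        yield {"name": entry["name"], **fixed, **dict(zip(keys, combo))}


hbm = {
    "name": "hbm_bandwidth",
    "sweep": {"size": {"start": 1024, "end": 8192, "multiplier": 2}},
}
sizes = [cfg["size"] for cfg in expand_benchmark(hbm)]
print(sizes)  # [1024, 2048, 4096, 8192]
```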
This code has gone through significant refactoring. If you depend heavily on the old version of the code, you can pin your dependencies to the v1.1-legacy tag.