
Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128) #3637

Open
yohann-bearzi wants to merge 1 commit into ml-explore:main from yohann-bearzi:sdpa-asym-headdim


@yohann-bearzi

The sdpa_vector and sdpa_vector_2pass_1 kernels are already templated on a separate value head dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. Models with asymmetric head dims therefore fall back to a compiled-graph attention decomposition (multiple GatherAxis dispatches per layer) instead of the fused kernel.
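To illustrate what a separate value head dim means, here is a plain C++ analogue of a routine templated on (D, V) with V defaulting to D. It is only a sketch: `attend_one` is a hypothetical name, it runs on the CPU for a single query vector, and it uses `std::size_t` template parameters where the Metal kernels use `int`. It is not the MLX kernel code.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference routine (not the MLX Metal kernel).
// D is the Q/K head dim, V is the value head dim and defaults to D.
template <typename T, std::size_t D, std::size_t V = D>
std::array<T, V> attend_one(
    const std::array<T, D>& q,
    const std::vector<std::array<T, D>>& keys,
    const std::vector<std::array<T, V>>& values,
    T scale) {
  std::array<T, V> out{}; // zero-initialized accumulator of width V
  if (keys.empty() || keys.size() != values.size()) {
    return out;
  }

  // Scores are dot products over the Q/K head dim D.
  std::vector<T> scores(keys.size());
  T max_score = std::numeric_limits<T>::lowest();
  for (std::size_t n = 0; n < keys.size(); ++n) {
    T dot = T(0);
    for (std::size_t d = 0; d < D; ++d) {
      dot += q[d] * keys[n][d];
    }
    scores[n] = dot * scale;
    if (scores[n] > max_score) {
      max_score = scores[n];
    }
  }

  // Numerically stable softmax over the scores.
  T denom = T(0);
  for (T& s : scores) {
    s = std::exp(s - max_score);
    denom += s;
  }

  // The weighted sum runs over the value head dim V, which may differ from D.
  for (std::size_t n = 0; n < keys.size(); ++n) {
    const T w = scores[n] / denom;
    for (std::size_t v = 0; v < V; ++v) {
      out[v] += w * values[n][v];
    }
  }
  return out;
}
```

Calling `attend_one<float, 192, 128>(q, keys, values, scale)` mirrors the asymmetric case: the scores use 192-wide vectors and the output is 128 wide. With `attend_one<float, 128>` the value dim falls back to D.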

This adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. All other head dims are unchanged.
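The relaxed guard can be pictured roughly as follows. This is a sketch rather than the actual diff: `sdpa_vector_supports_head_dims` is a hypothetical helper name, and the list of symmetric head dims is an assumption about what the vector kernel already instantiates. The only part taken from this PR is the (192, 128) pair.

```cpp
// Hypothetical helper, not the MLX source: reports whether the fused vector
// kernel has an instantiation for the given (query, value) head dims.
bool sdpa_vector_supports_head_dims(int query_head_dim, int value_head_dim) {
  // Symmetric cases. The exact set of dims listed here is an assumption.
  if (query_head_dim == value_head_dim) {
    return query_head_dim == 64 || query_head_dim == 96 ||
        query_head_dim == 128 || query_head_dim == 256;
  }
  // The single asymmetric pair enabled by this PR.
  return query_head_dim == 192 && value_head_dim == 128;
}
```

Allowing only the one instantiated pair, instead of dropping the equality requirement altogether, keeps every other asymmetric shape on the fallback path, where no matching kernel instantiation exists.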

Motivation: MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V. On current main it hits the fallback path.

Testing: Verified on MiMo-V2.5 (Metal, M3 Ultra): decode throughput 28.7 → 29.8 tok/s (+4%), generated output (top-1) unchanged. The kernel templates already supported V != D; this only instantiates and enables the existing code path. Incremental build on current main is clean.

Scope: Minimal — 2 files, +5/-3. No new kernel code, just one instantiation + the fallback guard.

The sdpa_vector kernel template already supported a separate value
head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs
with D != V were instantiated, and use_fallback required
query_head_dim == value_head_dim.

MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, falling
through to a compiled-graph decomposition (multiple GatherAxis dispatches
per attention layer) instead of the fused kernel.

Adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to
allow this specific asymmetric case. Other head dims remain unchanged.

Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%). Top-1 output unchanged.
The remaining decode bottleneck is MoE routing, not attention.
@yohann-bearzi
Author

For models using block_fp8 quantization (e.g. MiMo-V2.5, which motivated this asymmetric-head-dim fix), the block_fp8 matmul and MoE kernels are available as a standalone MLX extension: https://github.com/yohann-bearzi/mlx-block-fp8 — it builds against stock upstream MLX (kernels vendored, no fork) and pairs with this SDPA change for full MiMo-V2.5 decode throughput.

The MiMo-V2.5 MLX weights (block_fp8) are on Hugging Face: https://huggingface.co/bearzi/MiMo-V2.5-MLX

@yohann-bearzi
Author

@zcbenz would you be able to review this?
