Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128)#3637
Open
yohann-bearzi wants to merge 1 commit into
The sdpa_vector kernel template already supported a separate value head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, so it fell through to a compiled-graph decomposition (multiple GatherAxis dispatches per attention layer) instead of the fused kernel. This PR adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. Other head dims remain unchanged. Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%), with top-1 output unchanged. The remaining decode bottleneck is MoE routing, not attention.
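The relaxed guard described above can be sketched roughly as follows. This is a standalone illustration, not the MLX source: the function names are hypothetical, and the list of symmetric head dims is an assumption made only so the example compiles. The one detail taken from this PR is the extra (192, 128) pair.

```cpp
// Hypothetical stand-in for the head-dim part of the fallback guard.
// ASSUMPTION: the symmetric dims below are illustrative, not MLX's actual list.
constexpr bool is_supported_symmetric_dim(int d) {
  return d == 64 || d == 96 || d == 128 || d == 256;
}

// True when the fused vector kernel could be used for this (Q/K, V) pair.
constexpr bool vector_head_dims_supported(int query_head_dim, int value_head_dim) {
  // Old behaviour: only D == V pairs qualified.
  // New behaviour: the single asymmetric pair (192, 128) also qualifies.
  return (query_head_dim == value_head_dim &&
          is_supported_symmetric_dim(query_head_dim)) ||
         (query_head_dim == 192 && value_head_dim == 128);
}

// Hypothetical helper: the fallback is taken when the pair is not supported.
constexpr bool use_fallback_for_head_dims(int query_head_dim, int value_head_dim) {
  return !vector_head_dims_supported(query_head_dim, value_head_dim);
}
```

In this sketch the guard is deliberately narrow: the reversed pair (128, 192) still takes the fallback, matching the PR's statement that only this specific asymmetric case is enabled.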
Author
For models using block_fp8 quantization (e.g. MiMo-V2.5, which motivated this asymmetric-head-dim fix), the block_fp8 matmul and MoE kernels are available as a standalone MLX extension: https://github.com/yohann-bearzi/mlx-block-fp8. It builds against stock upstream MLX (kernels vendored, no fork) and pairs with this SDPA change for full MiMo-V2.5 decode throughput. The MiMo-V2.5 MLX weights (block_fp8) are on Hugging Face: https://huggingface.co/bearzi/MiMo-V2.5-MLX
Author
@zcbenz would you please be able to review?
The sdpa_vector and sdpa_vector_2pass_1 kernels are already templated on a separate value head dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. Models with asymmetric head dims therefore fall back to a compiled-graph attention decomposition (multiple GatherAxis dispatches per layer) instead of the fused kernel.

This adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. All other head dims are unchanged.

Motivation: MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V. On current main it hits the fallback path.

Testing: Verified on MiMo-V2.5 (Metal, M3 Ultra): decode throughput 28.7 → 29.8 tok/s (+4%), generated output (top-1) unchanged. The kernel templates already supported V != D; this only instantiates and enables the existing code path. Incremental build on current main is clean.

Scope: Minimal (2 files, +5/-3). No new kernel code, just one instantiation + the fallback guard.
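For readers wondering why V != D is well defined at all, here is a minimal CPU reference for single-query attention on one head. It is an illustrative sketch, not the MLX Metal kernel; the function name attend_one_query and its signature are hypothetical. It shows that the scores depend only on the Q/K head dim, while the output length follows the value head dim.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference: q has D elements, each key row has D elements,
// each value row has V elements. Returns a vector of V elements.
std::vector<float> attend_one_query(
    const std::vector<float>& q,
    const std::vector<std::vector<float>>& keys,
    const std::vector<std::vector<float>>& values,
    float scale) {
  assert(!keys.empty() && keys.size() == values.size());
  const std::size_t n = keys.size();
  const std::size_t v_dim = values[0].size();

  // Scores use only the Q/K head dim (D).
  std::vector<float> scores(n);
  float max_score = -INFINITY;
  for (std::size_t i = 0; i < n; ++i) {
    assert(keys[i].size() == q.size());
    float dot = 0.0f;
    for (std::size_t j = 0; j < q.size(); ++j) {
      dot += q[j] * keys[i][j];
    }
    scores[i] = dot * scale;
    max_score = std::max(max_score, scores[i]);
  }

  // Numerically stable softmax over the scores.
  float denom = 0.0f;
  for (float& s : scores) {
    s = std::exp(s - max_score);
    denom += s;
  }

  // The output is a weighted sum of value rows, so its length is V, not D.
  std::vector<float> out(v_dim, 0.0f);
  for (std::size_t i = 0; i < n; ++i) {
    assert(values[i].size() == v_dim);
    const float w = scores[i] / denom;
    for (std::size_t j = 0; j < v_dim; ++j) {
      out[j] += w * values[i][j];
    }
  }
  return out;
}
```

Calling this with a 192-element query, 192-element key rows, and 128-element value rows yields a 128-element output, which is the shape combination the new instantiation covers.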