[Pytorch] Add variable-K Cutlass GroupGEMM for fine-grained MoE wgrad #3069
cassiewilliam wants to merge 2 commits into
Greptile Summary
This PR adds a dedicated variable-K (ragged-K) CUTLASS GroupGEMM kernel for the BF16 NT wgrad path in fine-grained MoE training on H100 (SM90). It routes BF16-in/(FP32|BF16)-out wgrad calls with ragged per-expert token counts through a new
Confidence Score: 4/5
Safe to merge for the primary path (SM90, use_cutlass=1, properly aligned shapes); the new code is only reachable behind a Hopper + env-var gate and falls back to cuBLAS when neither CUTLASS path matches. The A/B swap, empty-group zero-init, stream ordering, workspace sizing, and dtype dispatch are all correct. The shape-eligibility guard in transformer_engine/common/gemm/cutlass_grouped_gemm.cuh - the two
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["nvte_multi_tensor_gemm"] --> B{"is_hopper AND use_cutlass"}
B -- No --> CUBLAS["multi_stream_cublas_gemm cuBLAS fallback"]
B -- Yes --> C{"no bias/gelu AND supported dtype AND uniform K div by 128"}
C -- Yes --> D["cutlass_grouped_gemm Uniform-K FP16/BF16"]
C -- No --> E{"no bias/gelu AND bf16 wgrad dtype AND NT layout AND bf16 wgrad shape"}
E -- No --> CUBLAS
E -- Yes --> F["cutlass_grouped_gemm_varlen_k"]
F --> G["collect_bf16_wgrad_nt_groups Zero-init K=0 groups"]
G --> H{"all groups K=0"}
H -- Yes --> DONE["return outputs zeroed"]
H -- No --> I{"out_dtype == FP32"}
I -- Yes --> J["CutlassGroupedGemmWgrad float Cooperative 128x128x64"]
I -- No --> K["CutlassGroupedGemmWgrad bf16 Pingpong 128x128x128"]
J --> L["can_implement / initialize / run"]
K --> L
Reviews (3): Last reviewed commit: "Merge branch 'main' into feat/varlenk_gr..."
Signed-off-by: Min Yang <min.yang@shopee.com>
How does this kernel compare performance-wise with the cuBLASLt grouped GEMM? Ideally, if cuBLAS is better, we would like to move towards that solution instead.
Description
This PR extends the CUTLASS Group GEMM support added in #2045 to cover the variable-K (K-grouped / ragged-K) BF16 weight-gradient (wgrad) path of fine-grained MoE models on H100 (SM90).

In expert-parallel MoE training, the per-expert token counts, which form the contraction dimension of the wgrad GEMM `D_i = B_iᵀ @ A_i`, are ragged and generally not 128-aligned, so the existing uniform-K CUTLASS grouped-GEMM fast path from #2045 cannot serve them. This PR adds a dedicated path that handles ragged per-expert token counts directly (SM90 TMA/WGMMA), zero-initializes empty (`K=0`) groups, and writes each per-expert `D_i` in place. Inputs are BF16; output is FP32 (default) or BF16. The standard uniform-K and Multi-Stream cuBLAS paths are unchanged.
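To make the semantics concrete, here is a minimal pure-Python reference for the ragged-K grouped wgrad. It is only a sketch for checking results on tiny inputs, not the kernel in this PR. The function name `grouped_wgrad_reference`, the `accumulate` flag, and the assumed shapes (`A_i` is `k_i x n`, `B_i` is `k_i x m`, `D_i` is `m x n`) are illustrative assumptions.

```python
def grouped_wgrad_reference(a_groups, b_groups, out_groups, accumulate=False):
    """Reference for D_i = B_i^T @ A_i over groups with ragged K.

    Assumed (hypothetical) shapes: A_i is k_i x n, B_i is k_i x m,
    D_i is m x n, where k_i is the per-expert token count.
    """
    for a, b, d in zip(a_groups, b_groups, out_groups):
        assert len(a) == len(b), "A_i and B_i must share the contraction length"
        m, n = len(d), len(d[0])
        k = len(a)  # ragged contraction length; 0 for an expert with no tokens
        for row in range(m):
            for col in range(n):
                # An empty group (k == 0) contributes an empty sum, i.e. zero.
                acc = sum(b[t][row] * a[t][col] for t in range(k))
                # Without accumulation the output is overwritten, so K=0
                # groups end up zeroed; with accumulation they are untouched.
                d[row][col] = d[row][col] + acc if accumulate else acc
    return out_groups
```

With `accumulate=False`, an empty group's output is overwritten with zeros, which matches the zero-initialization described above; with `accumulate=True` it is left as-is.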
Performance on H100 80GB, BF16, wgrad (`D_i = B_iᵀ @ A_i`), CUTLASS vs. the Multi-Stream cuBLAS baseline. Shape is `(g, m, n, k[mink, avgk, maxk])`: `g` groups, `m` = expert dim, `n` = hidden dim, `k` = the per-group routed-token count, i.e. the ragged contraction this kernel is built for.

Run the benchmark with:

`NVTE_USE_CUTLASS_GROUPED_GEMM=1 python benchmarks/gemm/benchmark_grouped_gemm_fwd_bwd.py --use-cutlass --dtype bf16 --num-experts <E> --ep-size 8 --hidden-dim 2048 --expert-dim 512 [--jagged-splits ...]`

The gain grows as the per-group K shrinks: small, ragged groups are where the Multi-Stream cuBLAS per-group launch overhead dominates.
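As a reading aid for the shape notation, the sketch below derives a `(g, m, n, k[min, avg, max])` tuple from a list of per-expert token counts and reports whether the splits would qualify as uniform and 128-aligned. The helper `describe_splits` and its return format are hypothetical and are not part of the benchmark script.

```python
def describe_splits(splits, expert_dim, hidden_dim, alignment=128):
    """Summarize per-expert token counts in the (g, m, n, k[...]) notation.

    `splits` holds one token count per expert (the ragged K of each group).
    Returns the shape tuple and a flag that is True only when every group has
    the same K and that K is a multiple of `alignment`.
    """
    g = len(splits)
    k_stats = (min(splits), sum(splits) / g, max(splits))
    shape = (g, expert_dim, hidden_dim, k_stats)
    uniform_aligned = len(set(splits)) == 1 and splits[0] % alignment == 0
    return shape, uniform_aligned
```

For example, `describe_splits([64, 0, 128], 512, 2048)` reports a ragged case, while `describe_splits([128, 128], 512, 2048)` reports a uniform, aligned one.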
Correctness reuses the existing test harness from #2045 (unchanged in this PR): the parametrized `tests/pytorch/test_grouped_linear.py::test_grouped_gemm` with `layout=NT` (the wgrad case), `use_cutlass=True`, `dtype=bfloat16` over ragged group splits exercises exactly this path and passes on SM90.
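For readers who want to try ragged inputs locally, one possible way to generate group splits is sketched below. This is not how `test_grouped_gemm` builds its splits; `ragged_splits` is a hypothetical helper that simply cuts a token budget at random points, so zero-sized groups can occur.

```python
import random


def ragged_splits(total_tokens, num_groups, seed=0):
    """Split `total_tokens` into `num_groups` non-negative, ragged counts.

    Cut points are drawn uniformly at random, so adjacent cuts may coincide
    and produce empty (K=0) groups, which is a useful edge case to cover.
    """
    rng = random.Random(seed)  # seeded for reproducible test inputs
    cuts = sorted(rng.randint(0, total_tokens) for _ in range(num_groups - 1))
    bounds = [0] + cuts + [total_tokens]
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]
```

The counts always sum to `total_tokens`, so they can be used directly as per-expert row counts of a concatenated token buffer.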
This path reuses the `NVTE_USE_CUTLASS_GROUPED_GEMM` toggle introduced in #2045 (default `0`): `export NVTE_USE_CUTLASS_GROUPED_GEMM=1` routes the BF16 NT wgrad through CUTLASS, `0` keeps the Multi-Stream cuBLAS implementation. `NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK` still warns on fallback.

Type of change
Changes
- `cutlass_grouped_gemm.cuh`: add `CutlassGroupedGemmWgrad<trans_a, trans_b, ElementD>`, an SM90 grouped-GEMM template specialized for the NT wgrad layout, with explicit instantiations for FP32 and BF16 output.
- `cutlass_grouped_gemm.cu`: add `cutlass_grouped_gemm_varlen_k(...)`. It validates the BF16 NT wgrad contract, splits groups into the non-empty set (excluding `K=0` groups, whose null A/B pointers would crash TMA descriptor construction, and zero-initializing their outputs when not accumulating), and dispatches on output dtype, mirroring the existing `cutlass_grouped_gemm` call path.
- `cublaslt_gemm.cu`: wire the path into the `nvte_multi_tensor_gemm` dispatch (uniform-K fast path → K-grouped wgrad → cuBLAS fallback).
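The dispatch order in the changes above can be modelled roughly as follows. This Python sketch follows the conditions shown in the Greptile flowchart; it is not the C++ dispatch code. The function names, the boolean parameters, the string return values, and the `"fp32"`/`"bf16"` dtype labels are all illustrative assumptions, and the shape-eligibility check is passed in as an opaque flag because its exact rule is not spelled out here.

```python
def select_grouped_gemm_path(*, is_hopper, use_cutlass, has_bias_or_gelu,
                             uniform_dtype_ok, uniform_k_div_128,
                             bf16_wgrad_dtype, nt_layout, wgrad_shape_ok):
    """Return which path a grouped GEMM call would take (sketch only)."""
    if not (is_hopper and use_cutlass):
        return "cublas"
    # 1) Existing uniform-K fast path from #2045.
    if not has_bias_or_gelu and uniform_dtype_ok and uniform_k_div_128:
        return "cutlass_grouped_gemm"
    # 2) New variable-K BF16 NT wgrad path.
    if not has_bias_or_gelu and bf16_wgrad_dtype and nt_layout and wgrad_shape_ok:
        return "cutlass_grouped_gemm_varlen_k"
    # 3) Everything else falls back to Multi-Stream cuBLAS.
    return "cublas"


def select_wgrad_kernel(group_ks, out_dtype):
    """Pick the kernel schedule and tile shape for the variable-K path."""
    if all(k == 0 for k in group_ks):
        return None  # nothing to launch; outputs were already zeroed
    if out_dtype == "fp32":
        return ("Cooperative", (128, 128, 64))
    if out_dtype == "bf16":
        return ("Pingpong", (128, 128, 128))
    raise ValueError(f"unsupported output dtype: {out_dtype}")
```

Keeping the uniform-K check first means calls already served by the #2045 fast path do not change behavior when the new path is added.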
Checklist: