
[Pytorch] Add variable-K Cutlass GroupGEMM for fine-grained MoE wgrad #3069

Open: cassiewilliam wants to merge 2 commits into NVIDIA:main from cassiewilliam:feat/varlenk_groupgemm

cassiewilliam commented Jun 1, 2026

Description

This PR extends the CUTLASS Group GEMM support added in #2045 to cover the variable-K
(K-grouped / ragged-K) BF16 weight-gradient (wgrad)
path of fine-grained MoE models on H100 (SM90).

In expert-parallel MoE training the per-expert token counts — the contraction dimension of the
wgrad GEMM D_i = B_iᵀ @ A_i — are ragged and generally not 128-aligned, which the existing
uniform-K CUTLASS grouped-GEMM fast path from #2045 cannot serve. This PR adds a dedicated path
that handles ragged per-expert token counts directly (SM90 TMA/WGMMA), zero-initializes empty
(K=0) groups, and writes each per-expert D_i in place. Inputs are BF16; output is FP32 (default)
or BF16. The standard uniform-K and Multi-Stream cuBLAS paths are unchanged.
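
For readers who want the semantics spelled out, here is a minimal pure-Python reference of the per-group contraction described above. It is a sketch, not the kernel: the function name, the list-of-rows layout, and the `splits` argument are illustrative assumptions, and the only things taken from the description are D_i = B_iᵀ @ A_i, the ragged per-expert token counts as the contraction dimension, and all-zero outputs for K=0 groups.

```python
def ragged_wgrad_reference(a_rows, b_rows, splits, n, m):
    """Reference for D_i = B_i^T @ A_i with a ragged contraction dimension.

    a_rows: token-major rows of A, each of length n (all groups concatenated).
    b_rows: token-major rows of B, each of length m (same token order as A).
    splits: per-group token counts k_i; these are the ragged K values.
    Returns one (m x n) matrix per group, as nested lists.
    """
    assert sum(splits) == len(a_rows) == len(b_rows)
    outputs, start = [], 0
    for k in splits:
        d = [[0.0] * n for _ in range(m)]      # a K=0 group stays all-zero
        for t in range(start, start + k):      # contract over this group's tokens
            a_t, b_t = a_rows[t], b_rows[t]
            for i in range(m):
                for j in range(n):
                    d[i][j] += b_t[i] * a_t[j]
        outputs.append(d)
        start += k
    return outputs


# Three tokens split across three groups; the middle group is empty.
a_rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b_rows = [[1.0], [1.0], [2.0]]
outs = ragged_wgrad_reference(a_rows, b_rows, splits=[2, 0, 1], n=2, m=1)
```

Nothing in the reference requires k_i to be a multiple of 128, which is the property the uniform-K fast path cannot offer.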

Performance on H100 80GB, BF16, wgrad (D_i = B_iᵀ @ A_i), CUTLASS vs. the Multi-Stream cuBLAS
baseline. Shape is (g, m, n, k[mink, avgk, maxk]): g groups, m = expert dim, n = hidden dim,
k = the per-group routed-token count — the ragged contraction this kernel is built for.

Run the benchmark with:

NVTE_USE_CUTLASS_GROUPED_GEMM=1 python benchmarks/gemm/benchmark_grouped_gemm_fwd_bwd.py --use-cutlass --dtype bf16 --num-experts <E> --ep-size 8 --hidden-dim 2048 --expert-dim 512 [--jagged-splits ...]

| Shape (g, m, n, k[mink, avgk, maxk]) | TE (cuBLAS, TFLOPS) | CUTLASS (TFLOPS) | Speed-up |
|---|---|---|---|
| (20, 512, 2048, k[3328, 3328, 3328]) | 445.50 | 567.65 | 1.27× |
| (20, 512, 2048, k[512, 3104, 6016]) | 444.52 | 520.92 | 1.17× |
| (32, 512, 2048, k[1024, 2048, 3072]) | 361.13 | 564.13 | 1.56× |
| (32, 512, 2048, k[512, 1024, 1536]) | 173.50 | 512.61 | 2.95× |

The gain grows as the per-group K shrinks: small, ragged groups are where the Multi-Stream cuBLAS
per-group launch overhead dominates.
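
The speed-up column can be rechecked from the two throughput columns with a few lines of Python. The throughput numbers below are copied from the table; the FLOP formula (2·m·n·k per group) is the conventional GEMM count and is an assumption about how the benchmark script reports TFLOPS, not something this description states.

```python
def grouped_gemm_flops(m, n, ks):
    """Conventional FLOP count for a grouped GEMM: 2*m*n*k summed over groups."""
    return sum(2 * m * n * k for k in ks)


def speedup(cutlass_tflops, cublas_tflops):
    """Ratio of the CUTLASS throughput to the cuBLAS baseline throughput."""
    return cutlass_tflops / cublas_tflops


# (k[mink, avgk, maxk], cuBLAS TFLOPS, CUTLASS TFLOPS), taken from the table.
rows = [
    ("k[3328, 3328, 3328]", 445.50, 567.65),
    ("k[512, 3104, 6016]", 444.52, 520.92),
    ("k[1024, 2048, 3072]", 361.13, 564.13),
    ("k[512, 1024, 1536]", 173.50, 512.61),
]
speedups = [round(speedup(cutlass, cublas), 2) for _, cublas, cutlass in rows]

# Total work for the first row: 20 groups, each with k = 3328.
uniform_flops = grouped_gemm_flops(512, 2048, [3328] * 20)
```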

Correctness reuses the existing test harness from #2045 (unchanged in this PR): the parametrized
tests/pytorch/test_grouped_linear.py::test_grouped_gemm with layout=NT (the wgrad case),
use_cutlass=True, dtype=bfloat16 over ragged group splits exercises exactly this path and passes
on SM90.

This path reuses the NVTE_USE_CUTLASS_GROUPED_GEMM toggle introduced in #2045 (default 0):
export NVTE_USE_CUTLASS_GROUPED_GEMM=1 routes the BF16 NT wgrad through CUTLASS, 0 keeps the
Multi-Stream cuBLAS implementation. NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK still warns on fallback.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • cutlass_grouped_gemm.cuh: add CutlassGroupedGemmWgrad<trans_a, trans_b, ElementD> — an SM90
    grouped-GEMM template specialized for the NT wgrad layout — with explicit instantiations for
    FP32 and BF16 output.
  • cutlass_grouped_gemm.cu: add cutlass_grouped_gemm_varlen_k(...). It validates the BF16 NT wgrad
    contract, splits groups into the non-empty set (excluding K=0 groups whose null A/B pointers
    would crash TMA descriptor construction, zero-initializing their outputs when not accumulating),
    and dispatches on output dtype — mirroring the existing cutlass_grouped_gemm call path.
  • cublaslt_gemm.cu: wire the path into the nvte_multi_tensor_gemm dispatch
    (uniform-K fast path → K-grouped wgrad → cuBLAS fallback).
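
The dispatch order in the last bullet can be sketched in Python as follows. This is an illustration only: `GroupedGemmCall`, `select_path`, and the string labels are hypothetical names, the conditions are paraphrased from this description rather than copied from `cublaslt_gemm.cu`, and the additional shape checks performed by the real guard are omitted.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GroupedGemmCall:
    """Hypothetical summary of one grouped-GEMM request (not a TE type)."""
    is_sm90: bool
    use_cutlass: bool          # NVTE_USE_CUTLASS_GROUPED_GEMM=1
    has_bias_or_gelu: bool
    in_dtype: str              # "bf16", "fp16", ...
    out_dtype: str             # "fp32", "bf16", ...
    layout: str                # "NT" is the wgrad case
    ks: List[int]              # per-group contraction sizes


def select_path(call: GroupedGemmCall) -> str:
    # Outside SM90 or with the toggle off, nothing changes: cuBLAS.
    if not (call.is_sm90 and call.use_cutlass) or call.has_bias_or_gelu:
        return "cublas"
    # 1) Existing uniform-K fast path: one K for all groups, 128-aligned.
    uniform_k = bool(call.ks) and len(set(call.ks)) == 1 and call.ks[0] % 128 == 0
    if call.in_dtype in ("bf16", "fp16") and uniform_k:
        return "cutlass_uniform_k"
    # 2) New K-grouped wgrad path: BF16 inputs, FP32 or BF16 output, NT layout.
    if (call.in_dtype == "bf16" and call.out_dtype in ("fp32", "bf16")
            and call.layout == "NT"):
        return "cutlass_varlen_k"
    # 3) Everything else falls back to Multi-Stream cuBLAS.
    return "cublas"
```

Because the new branch sits after the uniform-K check, shapes already served by the fast path keep their current behavior.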

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

github-actions bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) Jun 1, 2026
greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR adds a dedicated variable-K (ragged-K) CUTLASS GroupGEMM kernel for the BF16 NT wgrad path in fine-grained MoE training on H100 (SM90). It routes BF16-in/(FP32|BF16)-out wgrad calls with ragged per-expert token counts through a new cutlass_grouped_gemm_varlen_k path, zero-initializing empty (K=0) groups and dispatching non-empty groups to a new CutlassGroupedGemmWgrad template with separate FP32 (Cooperative 128×128×64) and BF16 (Pingpong 128×128×128, ClusterShape 1×2×1) schedule specialisations.

  • cublaslt_gemm.cu: adds is_bf16_wgrad_dtype() and is_bf16_wgrad_shape() guards, inserts a new else if branch between the uniform-K CUTLASS path and the cuBLAS fallback to capture ragged-K BF16 wgrad, and wires it to cutlass_grouped_gemm_varlen_k.
  • cutlass_grouped_gemm.cu: implements collect_bf16_wgrad_nt_groups (per-group empty filtering + async memset for K=0 outputs) and cutlass_grouped_gemm_varlen_k (outer A/B swap, non-zero subgroup dispatch on output dtype).
  • cutlass_grouped_gemm.cuh: adds GemmGivenScheduleWgradBase, two GemmGivenScheduleWgrad output-dtype specialisations (FP32 and BF16), the GemmGroupedWgrad alias, and the CutlassGroupedGemmWgrad kernel launcher (workspace sizing, host buffer pack, cudaMemcpyAsync, can_implement / initialize / run).
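
The empty-group handling summarized above can be pictured with a small host-side sketch. The helper name and the nested-list outputs are hypothetical stand-ins; in the PR the zeroing is an async memset on device buffers, and only the non-empty groups reach the kernel launcher.

```python
def split_nonempty_groups(ks, outputs, accumulate):
    """Return indices of groups with K > 0; zero the outputs of K=0 groups.

    ks:         per-group contraction sizes.
    outputs:    per-group output matrices (nested lists, modified in place).
    accumulate: if True, a K=0 group contributes nothing and its output is
                left untouched; if False, its output must read as all zeros.
    """
    active = []
    for idx, k in enumerate(ks):
        if k > 0:
            active.append(idx)            # forwarded to the grouped GEMM
        elif not accumulate:
            for row in outputs[idx]:      # stand-in for the device memset
                for j in range(len(row)):
                    row[j] = 0.0
    return active


# Stale values in every output; group 1 has no routed tokens.
outs_overwrite = [[[9.0]], [[9.0]], [[9.0]]]
active_overwrite = split_nonempty_groups([3, 0, 2], outs_overwrite, accumulate=False)

outs_accum = [[[9.0]], [[9.0]], [[9.0]]]
active_accum = split_nonempty_groups([3, 0, 2], outs_accum, accumulate=True)
```

If `active` comes back empty, there is nothing to launch and the call can return immediately, matching the "all groups K=0" branch in the flowchart.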

Confidence Score: 4/5

Safe to merge for the primary path (SM90, use_cutlass=1, properly aligned shapes); the new code is only reachable behind a Hopper+env-var gate and falls back to cuBLAS when neither CUTLASS path matches.

The A/B swap, empty-group zero-init, stream ordering, workspace sizing, and dtype dispatch are all correct. The shape-eligibility guard in is_bf16_wgrad_shape is narrower than the full set of CUTLASS kernel constraints, so certain misaligned shapes that pass the guard will still produce a hard error from can_implement rather than a cuBLAS fallback — a known design trade-off that the existing test coverage on H100 does exercise. No data corruption path exists for shapes that do meet the CUTLASS requirements.

transformer_engine/common/gemm/cutlass_grouped_gemm.cuh — the two GemmGivenScheduleWgrad specialisations carry redundant base-class alias re-declarations that could silently drift from the base on a future refactor.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | Dispatcher wired correctly: SM90+use_cutlass outer guard, is_bf16_wgrad_dtype/shape predicates, and the new else-if branch. Shape guard checks rank, K-matching, and uniform hidden/expert but not alignment (handled by can_implement hard-error or pre-existing cuBLAS fallback). |
| transformer_engine/common/gemm/cutlass_grouped_gemm.cu | collect_bf16_wgrad_nt_groups correctly filters K=0 groups (async memset for non-accumulate case), and cutlass_grouped_gemm_varlen_k correctly swaps outer A/B before dispatch. cudaMemcpyAsync return is unchecked (pre-existing pattern). |
| transformer_engine/common/gemm/cutlass_grouped_gemm.cuh | New GemmGivenScheduleWgradBase and two GemmGivenScheduleWgrad specialisations are functionally sound; FP32 Cooperative 128x128x64 and BF16 Pingpong 128x128x128 / ClusterShape 1x2x1 are correct CUTLASS 3.x configs for SM90. Both specialisations re-declare ~10 type aliases from the base that don't need overriding, creating a maintenance hazard. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_multi_tensor_gemm"] --> B{"is_hopper AND use_cutlass"}
    B -- No --> CUBLAS["multi_stream_cublas_gemm cuBLAS fallback"]
    B -- Yes --> C{"no bias/gelu AND supported dtype AND uniform K div by 128"}
    C -- Yes --> D["cutlass_grouped_gemm Uniform-K FP16/BF16"]
    C -- No --> E{"no bias/gelu AND bf16 wgrad dtype AND NT layout AND bf16 wgrad shape"}
    E -- No --> CUBLAS
    E -- Yes --> F["cutlass_grouped_gemm_varlen_k"]
    F --> G["collect_bf16_wgrad_nt_groups Zero-init K=0 groups"]
    G --> H{"all groups K=0"}
    H -- Yes --> DONE["return outputs zeroed"]
    H -- No --> I{"out_dtype == FP32"}
    I -- Yes --> J["CutlassGroupedGemmWgrad float Cooperative 128x128x64"]
    I -- No --> K["CutlassGroupedGemmWgrad bf16 Pingpong 128x128x128"]
    J --> L["can_implement / initialize / run"]
    K --> L

Reviews (3): Last reviewed commit: "Merge branch 'main' into feat/varlenk_gr..."

Comment thread transformer_engine/common/gemm/cutlass_grouped_gemm.cuh Outdated
Comment thread transformer_engine/common/gemm/cublaslt_gemm.cu
cassiewilliam force-pushed the feat/varlenk_groupgemm branch from d0edc9f to bda3dc3, June 1, 2026 12:27
cassiewilliam changed the title from "Add variable-K (K-grouped) BF16 wgrad grouped GEMM (CUTLASS, SM90)" to "[Pytorch] Add variable-K Cutlass GroupGEMM for fine-grained MoE wgrad", Jun 1, 2026
Signed-off-by: Min Yang <min.yang@shopee.com>
cassiewilliam force-pushed the feat/varlenk_groupgemm branch from f7a2b73 to e7a4db9, June 1, 2026 12:58

ptrendx commented Jun 2, 2026

How does this kernel compare performance-wise with the cuBLASLt grouped GEMM? Ideally, if cuBLAS is better, we would like to move towards that solution instead.
