Skip to content

Enable NVFP4 grouped MLP GLU RHT amax path#3073

Open
sraman-rgb wants to merge 9 commits into
NVIDIA:mainfrom
sraman-rgb:nvfp4-grouped-mlp-glu-rht-amax
Open

Enable NVFP4 grouped MLP GLU RHT amax path#3073
sraman-rgb wants to merge 9 commits into
NVIDIA:mainfrom
sraman-rgb:nvfp4-grouped-mlp-glu-rht-amax

Conversation

@sraman-rgb
Copy link
Copy Markdown
Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 1, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR adds the NVFP4 grouped MLP GLU + Random Hadamard Transform (RHT) amax path, enabling a new fused kernel (grouped_gemm_glu_hadamard_wrapper_sm100) that jointly runs the FC1 GEMM+GLU with NVFP4 RHT amax collection, then reuses those pre-computed amaxes for FC2 input quantization rather than re-scanning the activation output.

  • C++ layer: Refactors group_quantize_nvfp4_impl to accept a compute_amax flag (skipping the Hadamard amax step when pre-computed values are provided), adds allreduce_nvfp4_amax_tensors helper, and introduces nvfp4_group_quantize_with_amax for the multi-group NVFP4 path. quantize gains a compute_amax=true default parameter for the single-group path.
  • Python layer: Extracts _wrap_single_nvfp4_as_grouped as a reusable helper, adds _group_quantize_with_amax_for_grouped_mlp, and teaches ForwardGroupedMLP_CuTeGEMMGLU to use grouped_gemm_glu_hadamard_kernel when all conditions are met (NVFP4 + swiglu + with_rht + with_post_rht_amax). The TMEM env-var flag is cached via lru_cache to avoid hot-path lookups.

Confidence Score: 5/5

Safe to merge; the allreduce ordering is correct in both the multi-group and single-group paths, and the new GLU+Hadamard kernel selection is properly gated with graceful fallback.

The distributed allreduce logic is the highest-risk area. In the multi-group path, allreduce_nvfp4_amax_tensors is called before group_quantize_nvfp4_impl, so the cast kernel always sees globally-reduced amaxes. In the single-group path, reduce_amaxes() fires at line 2473 of quantize_impl, before the nvte_quantize_with_hadamard_transform call, so the same ordering guarantee holds. The empty-input early-return correctly participates in the collective when compute_amax=false. The only concern is that quantize_impl is now public, which is a minor API surface issue with no present defect.

transformer_engine/pytorch/csrc/common.h — the promotion of quantize_impl to public is worth a second look before wider adoption of this API.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/common.h Moves quantize_impl from private to public in NVFP4Quantizer to allow C++ extension code to call it directly with compute_amax=false; widens API surface of the class.
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds allreduce_nvfp4_amax_tensors helper, refactors group_quantize_nvfp4_impl to require RHT+post-RHT-amax and accept compute_amax flag, and introduces nvfp4_group_quantize_with_amax that allreduces externally-provided amaxes before the cast kernel.
transformer_engine/pytorch/csrc/quantizer.cpp Restructures quantize_impl so that reduce_amaxes() runs before the cast kernel for both compute_amax=true and compute_amax=false; the empty-input early-return correctly calls reduce_amaxes() in the compute_amax=false branch to keep distributed collectives synchronized.
transformer_engine/pytorch/csrc/extensions/pybind.cpp Adds compute_amax=true default to the quantize Python binding and registers nvfp4_group_quantize_with_amax via a helper; clean, backward-compatible.
transformer_engine/pytorch/ops/_common.py Extracts _wrap_single_nvfp4_as_grouped, adds _group_quantize_with_amax_for_grouped_mlp for both multi-group (C++ allreduce) and single-group (allreduce inside quantize_impl) NVFP4 paths.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Adds _use_tmem_post_rht_amax (properly lru_cached), detects and activates the GLU+Hadamard path when conditions are met, routes amax tensors from the kernel output into _group_quantize_with_amax_for_grouped_mlp.

Sequence Diagram

sequenceDiagram
    participant FW as ForwardGroupedMLP (forward)
    participant KG as grouped_gemm_glu_hadamard_kernel
    participant PY as _group_quantize_with_amax_for_grouped_mlp
    participant CPP as nvfp4_group_quantize_with_amax (C++)
    participant QI as quantize_impl (C++)
    participant AR as allreduce / reduce_amaxes
    participant CK as nvte cast kernel

    FW->>KG: FC1 GEMM + GLU + RHT amax
    KG-->>FW: "fc1_kernel_out {d_tensor, amax_tensor, post_rht_amax_tensor}"
    FW->>PY: _group_quantize_with_amax_for_grouped_mlp(...)
    alt "num_groups != 1"
        PY->>CPP: tex.nvfp4_group_quantize_with_amax
        CPP->>AR: allreduce_nvfp4_amax_tensors (before kernel)
        AR-->>CPP: global amaxes
        CPP->>CK: "group_quantize_nvfp4_impl(compute_amax=false)"
        CK-->>PY: GroupedTensor
    else "num_groups == 1"
        PY->>QI: "tex.quantize(compute_amax=False)"
        QI->>AR: reduce_amaxes() at line 2473 (before cast kernel)
        AR-->>QI: global amaxes
        QI->>CK: nvte_quantize_with_hadamard_transform
        CK-->>PY: FP4 tensor wrapped as GroupedTensor
    end
    FW->>FW: FC2 GEMM using grouped_fc2_x
Loading

Reviews (7): Last reviewed commit: "Address NVFP4 precomputed amax review co..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly LGTM. Left a few comments on code duplication and other minor issues.

sraman-rgb and others added 3 commits June 1, 2026 16:37
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Comment thread transformer_engine/pytorch/ops/_common.py
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag) {
std::optional<at::Tensor> noop_flag, bool compute_amax) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is leaking NVFP4-specific details into a generic API.

Comment thread transformer_engine/pytorch/ops/_common.py
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants