Enable NVFP4 grouped MLP GLU RHT amax path#3073
Conversation
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds the NVFP4 grouped MLP GLU + Random Hadamard Transform (RHT) amax path, enabling a new fused kernel (
Confidence Score: 5/5Safe to merge; the allreduce ordering is correct in both the multi-group and single-group paths, and the new GLU+Hadamard kernel selection is properly gated with graceful fallback. The distributed allreduce logic is the highest-risk area. In the multi-group path,
Important Files Changed
Sequence DiagramsequenceDiagram
participant FW as ForwardGroupedMLP (forward)
participant KG as grouped_gemm_glu_hadamard_kernel
participant PY as _group_quantize_with_amax_for_grouped_mlp
participant CPP as nvfp4_group_quantize_with_amax (C++)
participant QI as quantize_impl (C++)
participant AR as allreduce / reduce_amaxes
participant CK as nvte cast kernel
FW->>KG: FC1 GEMM + GLU + RHT amax
KG-->>FW: "fc1_kernel_out {d_tensor, amax_tensor, post_rht_amax_tensor}"
FW->>PY: _group_quantize_with_amax_for_grouped_mlp(...)
alt "num_groups != 1"
PY->>CPP: tex.nvfp4_group_quantize_with_amax
CPP->>AR: allreduce_nvfp4_amax_tensors (before kernel)
AR-->>CPP: global amaxes
CPP->>CK: "group_quantize_nvfp4_impl(compute_amax=false)"
CK-->>PY: GroupedTensor
else "num_groups == 1"
PY->>QI: "tex.quantize(compute_amax=False)"
QI->>AR: reduce_amaxes() at line 2473 (before cast kernel)
AR-->>QI: global amaxes
QI->>CK: nvte_quantize_with_hadamard_transform
CK-->>PY: FP4 tensor wrapped as GroupedTensor
end
FW->>FW: FC2 GEMM using grouped_fc2_x
Reviews (7): Last reviewed commit: "Address NVFP4 precomputed amax review co..." | Re-trigger Greptile |
vthumbe1503
left a comment
There was a problem hiding this comment.
Mostly LGTM. Left a few comments on code duplication and other minor issues.
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
|
|
||
| py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, | ||
| std::optional<at::Tensor> noop_flag) { | ||
| std::optional<at::Tensor> noop_flag, bool compute_amax) { |
There was a problem hiding this comment.
This is leaking NVFP4-specific details into a generic API.
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: