Skip to content

[PyTorch] Refactor function to prepare pointers for grouped MLP discrete weights#3076

Merged
timmoon10 merged 12 commits into
NVIDIA:mainfrom
timmoon10:tmoon/refactor-grouped-mlp-prepare-ptr-func-for-discrete-weights
Jun 3, 2026
Merged

[PyTorch] Refactor function to prepare pointers for grouped MLP discrete weights#3076
timmoon10 merged 12 commits into
NVIDIA:mainfrom
timmoon10:tmoon/refactor-grouped-mlp-prepare-ptr-func-for-discrete-weights

Conversation

@timmoon10
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 commented Jun 2, 2026

Description

#3001 refactored the logic for preparing grouped MLP discrete weight pointers for the cuDNN CuTe DSL kernel. However, #3070 identified a correctness bug when dealing with padded scales.

This PR refactors this logic by putting the function (swizzle_scales_and_pack_ptrs_for_discrete_weights) within a submodule dedicated to grouped MLP (transformer_engine_torch.grouped_mlp_experimental). Given how much hacky custom logic we need for this very narrow, experimental use-case (also see #2991), it seems practical to make a dumping ground that is easier to iterate on. The new function copies both data and scale pointers to device in a single kernel launch and it fixes the bug in #3070.

Closes #3070.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Create transformer_engine_torch.grouped_mlp_experimental submodule
  • Create swizzle_scales_and_pack_ptrs_for_discrete_weights function to prepare discrete weight pointers for cuDNN grouped MLP module
  • Fix bug where C++ NVFP4 quantizer would not respect optimize_for_gemm option
  • Add MXFP8 and NVFP4 test for discrete weight pointer function, including with poisoned scale pads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 7 commits June 1, 2026 23:07
… submodule

Introduces a `grouped_mlp_experimental` pybind submodule as a labeled
home for hyperspecific helpers that exist to satisfy the cuDNN CuTe
DSL grouped GEMM kernels. The submodule itself is documented as
unstable, so callers can see at the import path that these helpers
are not part of the supported surface.

`copy_data_ptrs_to_device` is genuinely general-purpose and stays at
the top level; only `transform_and_copy_data_ptrs_to_device` moves
into the submodule, and its four call sites in the fused grouped MLP
forward/backward are updated accordingly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replaces `transform_and_copy_data_ptrs_to_device` with a more focused
helper, `swizzle_scales_and_pack_ptrs_for_discrete_weights`. The new
function takes both the FP8/FP4 weight data tensors and their scale
tensors, swizzles the scales, and copies both pointer arrays to device
in a single kernel launch (down from two — one from
`copy_data_ptrs_to_device` for data and one from the old transform
helper for scales). The two returned pointer arrays are views into a
single packed device buffer.

The general "transform_type" string dispatch is gone: the function
only supports `mxfp8_rowwise`, `mxfp8_columnwise`, and `nvfp4`, which
were the only modes ever used. The four discrete-weight call sites in
the fused grouped MLP forward/backward collapse their paired
`copy_data_ptrs_to_device` + transform calls into a single call.

The implementation moves to a dedicated source file,
`csrc/extensions/grouped_mlp_experimental.cpp`, so the experimental
submodule has a clear home for future helpers tied to the cuDNN CuTe
DSL grouped GEMM kernels. The declaration in `extensions.h` is
grouped under a matching banner. `copy_data_ptrs_to_device` stays in
`utils.cpp` since it remains a general-purpose helper.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
…le helper

In swizzle_scales_and_pack_ptrs_for_discrete_weights, take the data
shape directly from the data tensors instead of inferring it from the
padded scale shape. NVFP4 packs two 4-bit values per byte, so the
byte-shape's inner dim is doubled to recover the logical element
count.

Also replace the trio of is_mxfp8_rowwise / is_mxfp8_columnwise /
is_nvfp4 booleans with a function-local TensorFormat enum. Tensor
properties (scaling mode, dtypes, swizzle param names) are assigned
together per case in a single switch so adding a future format is a
single-point change rather than a fresh boolean threaded through the
function.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
After moving the experimental grouped-MLP helper out, the only thing
left in extensions/utils.cpp was copy_data_ptrs_to_device, which fits
naturally alongside the cublasLt/cuDNN version getters and
splits_to_offsets already in extensions/misc.cpp. Move it there and
delete the now-empty utils.cpp. Build picks up sources via glob, so
no manifest update is needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Wraps the C++ implementation of
swizzle_scales_and_pack_ptrs_for_discrete_weights in a
`grouped_mlp_experimental` namespace and renames the format-selector
argument from `format` to `swizzle_type` across the declaration,
implementation, and pybind binding. The pybind submodule name was
already `grouped_mlp_experimental`, so the C++ namespace now mirrors
it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
After NVFP4Quantizer::quantize runs, run inplace_swizzle_scale_for_gemm
on the output when optimize_for_gemm is set and the quantize kernel
hasn't already produced swizzled scales. The NVFP4 quantize kernel
rejects with_gemm_swizzled_scales=true and emits compact scales, so
without this hook callers had to follow up with a manual swizzle in
Python (see ops/_common.py:85). The hook is a no-op for MXFP8 (its
quantize kernel sets the flag itself) and for any quantizer with
optimize_for_gemm=false.

Also fixes a latent state-consistency bug in
NVFP4Quantizer::convert_and_update_tensor: it was resetting the C++
wrapper's with_gemm_swizzled_scales to false but never touching the
Python tensor's _with_gemm_swizzled_scales attribute. Re-quantizing
into a tensor that previously held swizzled scales would leave the
Python flag stuck at true while the buffer was compact, mismatched
state that downstream code could mis-read. The Python attribute is
now reset alongside the C++ wrapper, matching what
MXFP8Quantizer::convert_and_update_tensor already does.

Adds test_swizzle_scales_and_pack_ptrs_for_discrete_weights covering
mxfp8_rowwise, mxfp8_columnwise, and nvfp4, comparing the helper's
swizzled output against scales produced by the quantizer with
optimize_for_gemm=true. NVFP4 was the case that surfaced the
quantizer-side issues fixed above.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested a review from vthumbe1503 June 2, 2026 03:20
@timmoon10 timmoon10 requested a review from ksivaman as a code owner June 2, 2026 03:20
@timmoon10 timmoon10 added bug Something isn't working refactor MoE labels Jun 2, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 2, 2026

Greptile Summary

This PR refactors the grouped MLP discrete-weight pointer preparation into a new transformer_engine_torch.grouped_mlp_experimental submodule and fixes two bugs: (1) the NVFP4 quantizer silently ignored optimize_for_gemm, and (2) the old scale-swizzle code derived the data shape from the padded scale tensor rather than from the actual data tensor, causing incorrect swizzle output when scales had padding.

  • New swizzle_scales_and_pack_ptrs_for_discrete_weights takes both data and scale tensors explicitly, uses the actual data-tensor shape for the swizzle kernel (so padding rows/columns in the scale tensor are not swizzled), and packs data + scale pointers in a single kernel launch instead of two.
  • NVFP4 optimize_for_gemm fix: cast.cpp now applies inplace_swizzle_scale_for_gemm as a post-quantize step when the quantizer kernel doesn't bake the swizzle in directly; quantizer.cpp resets _with_gemm_swizzled_scales on re-used tensors in convert_and_update_tensor so the flag is never stale.
  • utils.cpp is deleted; its generic copy_data_ptrs_to_device moves to misc.cpp and the more complex swizzle-and-pack logic moves to the new experimental file.

Confidence Score: 5/5

Safe to merge; the core bug fixes are correctly implemented, CUDA stream ordering is sound, and PyTorch reference counting for the swizzled-scale buffer is properly maintained.

All changed code paths look correct: the new function uses the actual data tensor shape so padding rows/columns are excluded from swizzling, the NVFP4 post-quantize swizzle in cast.cpp is properly guarded against double-swizzling, and the stale-flag reset in convert_and_update_tensor is placed before the TensorWrapper is returned.

tests/pytorch/test_grouped_linear.py — the padded-scale poison section is dead code for the chosen (160, 96) shape and should use a non-exact-multiple shape to actually exercise the main bug fix.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp New file implementing the unified scale-swizzle + pointer-pack function; correctly uses actual data tensor shape (rather than the padded scale shape from the old code) so the swizzle kernel only touches valid scale entries. Stream ordering and PyTorch reference counting are correct.
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds a post-quantize swizzle step for quantizers (e.g. NVFP4) whose kernel cannot bake the GEMM-swizzled layout in directly; guarded by both optimize_for_gemm and !get_with_gemm_swizzled_scales() to prevent double-swizzling.
transformer_engine/pytorch/csrc/quantizer.cpp Removes the TODO stubs for NVFP4 optimize_for_gemm and adds an explicit _with_gemm_swizzled_scales = false reset in convert_and_update_tensor so a reused tensor doesn't carry a stale true flag into the post-quantize swizzle check in cast.cpp.
transformer_engine/pytorch/csrc/extensions/utils.cpp Deleted; its two functions are split across misc.cpp (copy_data_ptrs_to_device) and the new grouped_mlp_experimental.cpp (swizzle + pack).
tests/pytorch/test_grouped_linear.py New test covers pointer values and swizzled-scale content for all three swizzle types, but the chosen shape (160, 96) has exact-multiple dimensions so the padded-scale poison section is a no-op and the core bug fix goes untested.

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Comment on lines +1825 to +1831
# Check scale pointer values
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The expected scale-pointer stride uses scale_bytes (raw size), but the C++ code allocates the swizzled-scale buffer with swizzled_scales_stride = roundup(scale_bytes, 16). For the current shape (160, 96) every scale size happens to be a multiple of 16 (480 and 960), so both values agree and the assertion passes, but any shape that produces a non-16-aligned scale count would compute incorrect expected pointers. The same implicit assumption appears in the view_as call below — if scale_bytes != swizzled_scales_stride, that call throws a RuntimeError rather than silently validating padding-separated data.

Suggested change
# Check scale pointer values
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)
# Check scale pointer values
# Use the same 16-byte-aligned stride as the C++ implementation
import math
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
scale_stride = math.ceil(scale_bytes / 16) * 16
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_stride for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)

Comment on lines +93 to +96
// Scale shape
const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes());
NVTE_CHECK(scale_shape.ndim == 2,
"Expected 2D scale tensor, but got shape=", getTensorShape(scale_tensors[0]), ".");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing per-tensor shape consistency check — the function validates that data_tensors.size() == scale_tensors.size() but uses only data_tensors[0] and scale_tensors[0] as the reference shapes for all tensors. If any subsequent tensor has a different shape, the NVTETensor configuration will be wrong for that tensor, leading to an out-of-bounds scale swizzle without any diagnostic. Adding an NVTE_CHECK loop over i > 0 to assert shape equality would catch this early.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function expects uniformly-sized tensors. Actually performing the checks would incur significant overhead.

@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci pytorch

Comment on lines +65 to +69
// Post-quantize swizzle for quantizers whose kernel does not bake
// the GEMM-swizzled scale layout in directly
if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) {
inplace_swizzle_scale_for_gemm(output_py);
}
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes here and in the C++ NVFP4 quantizer are to fix a bug uncovered by the test. When the NVFP4 quantizer was configured with optimize_for_gemm, it would not actually produce swizzled scales.

vthumbe1503
vthumbe1503 previously approved these changes Jun 2, 2026
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Comment thread tests/pytorch/test_grouped_linear.py Outdated

// Allocate single buffer for swizzled scales. Uses a uniform stride since
// all tensors share the same scale shape.
const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont need roundup here right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's paranoid and somewhat redundant, but it's also cheap.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci arm64 amd64

@timmoon10 timmoon10 merged commit c1e827f into NVIDIA:main Jun 3, 2026
12 of 15 checks passed
@timmoon10 timmoon10 deleted the tmoon/refactor-grouped-mlp-prepare-ptr-func-for-discrete-weights branch June 3, 2026 01:56
cael-ling added a commit to cael-ling/TransformerEngine that referenced this pull request Jun 4, 2026
The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks
that mainline NVIDIA#3076 split apart. _with_gemm_swizzled_scales only controls
WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM
swizzles lazily at call time; when True, the tensor is pre-swizzled and the
GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or
cols%128!=0) ended quantize with the flag False. NVIDIA#3076 then added a
post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles
ineligible shapes and flips the flag back to True, so under optimize_for_gemm
the end-to-end flag is now True for all shapes. The old False expectations
encoded pre-NVIDIA#3076 behavior and started failing CI on (64,144), (128,144),
(48,128).

Split into two self-consistent tests:
- shape_gate: probes make_empty() (runs create_tensor only -- no quantize,
  no fallback), so it observes the fused-kernel shape gate in isolation and
  keeps the original True/False eligibility table.
- end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes
  and must always yield _with_gemm_swizzled_scales=True (eligible via the
  fused cast-fusion kernel, ineligible via the NVIDIA#3076 swizzle fallback).

Signed-off-by: Cael Ling <caell@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working MoE refactor

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants