[PyTorch] Refactor function to prepare pointers for grouped MLP discrete weights#3076
Conversation
… submodule Introduces a `grouped_mlp_experimental` pybind submodule as a labeled home for hyperspecific helpers that exist to satisfy the cuDNN CuTe DSL grouped GEMM kernels. The submodule itself is documented as unstable, so callers can see at the import path that these helpers are not part of the supported surface. `copy_data_ptrs_to_device` is genuinely general-purpose and stays at the top level; only `transform_and_copy_data_ptrs_to_device` moves into the submodule, and its four call sites in the fused grouped MLP forward/backward are updated accordingly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replaces `transform_and_copy_data_ptrs_to_device` with a more focused helper, `swizzle_scales_and_pack_ptrs_for_discrete_weights`. The new function takes both the FP8/FP4 weight data tensors and their scale tensors, swizzles the scales, and copies both pointer arrays to device in a single kernel launch (down from two — one from `copy_data_ptrs_to_device` for data and one from the old transform helper for scales). The two returned pointer arrays are views into a single packed device buffer. The general "transform_type" string dispatch is gone: the function only supports `mxfp8_rowwise`, `mxfp8_columnwise`, and `nvfp4`, which were the only modes ever used. The four discrete-weight call sites in the fused grouped MLP forward/backward collapse their paired `copy_data_ptrs_to_device` + transform calls into a single call. The implementation moves to a dedicated source file, `csrc/extensions/grouped_mlp_experimental.cpp`, so the experimental submodule has a clear home for future helpers tied to the cuDNN CuTe DSL grouped GEMM kernels. The declaration in `extensions.h` is grouped under a matching banner. `copy_data_ptrs_to_device` stays in `utils.cpp` since it remains a general-purpose helper. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
…le helper In swizzle_scales_and_pack_ptrs_for_discrete_weights, take the data shape directly from the data tensors instead of inferring it from the padded scale shape. NVFP4 packs two 4-bit values per byte, so the byte-shape's inner dim is doubled to recover the logical element count. Also replace the trio of is_mxfp8_rowwise / is_mxfp8_columnwise / is_nvfp4 booleans with a function-local TensorFormat enum. Tensor properties (scaling mode, dtypes, swizzle param names) are assigned together per case in a single switch so adding a future format is a single-point change rather than a fresh boolean threaded through the function. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
After moving the experimental grouped-MLP helper out, the only thing left in extensions/utils.cpp was copy_data_ptrs_to_device, which fits naturally alongside the cublasLt/cuDNN version getters and splits_to_offsets already in extensions/misc.cpp. Move it there and delete the now-empty utils.cpp. Build picks up sources via glob, so no manifest update is needed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Wraps the C++ implementation of swizzle_scales_and_pack_ptrs_for_discrete_weights in a `grouped_mlp_experimental` namespace and renames the format-selector argument from `format` to `swizzle_type` across the declaration, implementation, and pybind binding. The pybind submodule name was already `grouped_mlp_experimental`, so the C++ namespace now mirrors it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
After NVFP4Quantizer::quantize runs, run inplace_swizzle_scale_for_gemm on the output when optimize_for_gemm is set and the quantize kernel hasn't already produced swizzled scales. The NVFP4 quantize kernel rejects with_gemm_swizzled_scales=true and emits compact scales, so without this hook callers had to follow up with a manual swizzle in Python (see ops/_common.py:85). The hook is a no-op for MXFP8 (its quantize kernel sets the flag itself) and for any quantizer with optimize_for_gemm=false. Also fixes a latent state-consistency bug in NVFP4Quantizer::convert_and_update_tensor: it was resetting the C++ wrapper's with_gemm_swizzled_scales to false but never touching the Python tensor's _with_gemm_swizzled_scales attribute. Re-quantizing into a tensor that previously held swizzled scales would leave the Python flag stuck at true while the buffer was compact, mismatched state that downstream code could mis-read. The Python attribute is now reset alongside the C++ wrapper, matching what MXFP8Quantizer::convert_and_update_tensor already does. Adds test_swizzle_scales_and_pack_ptrs_for_discrete_weights covering mxfp8_rowwise, mxfp8_columnwise, and nvfp4, comparing the helper's swizzled output against scales produced by the quantizer with optimize_for_gemm=true. NVFP4 was the case that surfaced the quantizer-side issues fixed above. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR refactors the grouped MLP discrete-weight pointer preparation into a new
Confidence Score: 5/5Safe to merge; the core bug fixes are correctly implemented, CUDA stream ordering is sound, and PyTorch reference counting for the swizzled-scale buffer is properly maintained. All changed code paths look correct: the new function uses the actual data tensor shape so padding rows/columns are excluded from swizzling, the NVFP4 post-quantize swizzle in cast.cpp is properly guarded against double-swizzling, and the stale-flag reset in convert_and_update_tensor is placed before the TensorWrapper is returned. tests/pytorch/test_grouped_linear.py — the padded-scale poison section is dead code for the chosen (160, 96) shape and should use a non-exact-multiple shape to actually exercise the main bug fix. Important Files Changed
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
Signed-off-by: Tim Moon <tmoon@nvidia.com>
| # Check scale pointer values | ||
| scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() | ||
| expected_scale_ptrs = torch.tensor( | ||
| [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], | ||
| dtype=torch.int64, | ||
| device="cpu", | ||
| ) |
There was a problem hiding this comment.
The expected scale-pointer stride uses
scale_bytes (raw size), but the C++ code allocates the swizzled-scale buffer with swizzled_scales_stride = roundup(scale_bytes, 16). For the current shape (160, 96) every scale size happens to be a multiple of 16 (480 and 960), so both values agree and the assertion passes, but any shape that produces a non-16-aligned scale count would compute incorrect expected pointers. The same implicit assumption appears in the view_as call below — if scale_bytes != swizzled_scales_stride, that call throws a RuntimeError rather than silently validating padding-separated data.
| # Check scale pointer values | |
| scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() | |
| expected_scale_ptrs = torch.tensor( | |
| [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], | |
| dtype=torch.int64, | |
| device="cpu", | |
| ) | |
| # Check scale pointer values | |
| # Use the same 16-byte-aligned stride as the C++ implementation | |
| import math | |
| scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() | |
| scale_stride = math.ceil(scale_bytes / 16) * 16 | |
| expected_scale_ptrs = torch.tensor( | |
| [swizzled_scales_buffer.data_ptr() + i * scale_stride for i in range(num_tensors)], | |
| dtype=torch.int64, | |
| device="cpu", | |
| ) |
| // Scale shape | ||
| const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes()); | ||
| NVTE_CHECK(scale_shape.ndim == 2, | ||
| "Expected 2D scale tensor, but got shape=", getTensorShape(scale_tensors[0]), "."); |
There was a problem hiding this comment.
Missing per-tensor shape consistency check — the function validates that
data_tensors.size() == scale_tensors.size() but uses only data_tensors[0] and scale_tensors[0] as the reference shapes for all tensors. If any subsequent tensor has a different shape, the NVTETensor configuration will be wrong for that tensor, leading to an out-of-bounds scale swizzle without any diagnostic. Adding an NVTE_CHECK loop over i > 0 to assert shape equality would catch this early.
There was a problem hiding this comment.
The function expects uniformly-sized tensors. Actually performing the checks would incur significant overhead.
|
/te-ci pytorch |
| // Post-quantize swizzle for quantizers whose kernel does not bake | ||
| // the GEMM-swizzled scale layout in directly | ||
| if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) { | ||
| inplace_swizzle_scale_for_gemm(output_py); | ||
| } |
There was a problem hiding this comment.
The changes here and in the C++ NVFP4 quantizer are to fix a bug uncovered by the test. When the NVFP4 quantizer was configured with optimize_for_gemm, it would not actually produce swizzled scales.
|
|
||
| // Allocate single buffer for swizzled scales. Uses a uniform stride since | ||
| // all tensors share the same scale shape. | ||
| const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes |
There was a problem hiding this comment.
We dont need roundup here right?
There was a problem hiding this comment.
It's paranoid and somewhat redundant, but it's also cheap.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
|
/te-ci arm64 amd64 |
for more information, see https://pre-commit.ci
The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks that mainline NVIDIA#3076 split apart. _with_gemm_swizzled_scales only controls WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM swizzles lazily at call time; when True, the tensor is pre-swizzled and the GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or cols%128!=0) ended quantize with the flag False. NVIDIA#3076 then added a post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles ineligible shapes and flips the flag back to True, so under optimize_for_gemm the end-to-end flag is now True for all shapes. The old False expectations encoded pre-NVIDIA#3076 behavior and started failing CI on (64,144), (128,144), (48,128). Split into two self-consistent tests: - shape_gate: probes make_empty() (runs create_tensor only -- no quantize, no fallback), so it observes the fused-kernel shape gate in isolation and keeps the original True/False eligibility table. - end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes and must always yield _with_gemm_swizzled_scales=True (eligible via the fused cast-fusion kernel, ineligible via the NVIDIA#3076 swizzle fallback). Signed-off-by: Cael Ling <caell@nvidia.com>
Description
#3001 refactored the logic for preparing grouped MLP discrete weight pointers for the cuDNN CuTe DSL kernel. However, #3070 identified a correctness bug when dealing with padded scales.
This PR refactors this logic by putting the function (
swizzle_scales_and_pack_ptrs_for_discrete_weights) within a submodule dedicated to grouped MLP (transformer_engine_torch.grouped_mlp_experimental). Given how much hacky custom logic we need for this very narrow, experimental use-case (also see #2991), it seems practical to make a dumping ground that is easier to iterate on. The new function copies both data and scale pointers to device in a single kernel launch and it fixes the bug in #3070.Closes #3070.
Type of change
Changes
transformer_engine_torch.grouped_mlp_experimentalsubmoduleswizzle_scales_and_pack_ptrs_for_discrete_weightsfunction to prepare discrete weight pointers for cuDNN grouped MLP moduleoptimize_for_gemmoptionChecklist: