
[PyTorch] Use data shape for MXFP8 pointer swizzle#3070

Closed
jepio wants to merge 1 commit into NVIDIA:main from jepio:jpiotrowski/mxfp8-swizzle-data-shape-fix

Conversation

@jepio
Member

@jepio jepio commented Jun 1, 2026

Description

We've started seeing NaN loss in gpt_oss_20b training using MXFP8. gpt_oss_20b uses expert MLP dimensions based on hidden size 2880, producing FC1 weights of (5760, 2880) and FC2 weights of (2880, 2880). I bisected this to 9e5a847. That change is not safe in the presence of MXFP8 scale padding: without the actual unpadded data shape, transform_and_copy_data_ptrs_to_device infers the data shape from the padded scale shape, so padded E8M0 scale bytes can be copied into the swizzled GEMM scale buffer and consumed by GEMM, causing NaNs. Fix it by passing the actual data shape to transform_and_copy_data_ptrs_to_device. The patch includes a regression test that poisons padded scale entries and verifies they are masked out.
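The arithmetic behind the failure can be sketched for the two weight shapes above. This is a minimal illustration, not Transformer Engine code: the 32-element block size comes from the MXFP8 format, while the 128 x 4 alignment of the row-wise scale tensor and both helper names are assumptions made for this example.

```python
import math

MXFP8_BLOCK = 32                 # elements covered by one E8M0 scale
ROW_ALIGN, COL_ALIGN = 128, 4    # ASSUMED alignment of the row-wise scale tensor


def round_up(value, multiple):
    """Round value up to the nearest multiple."""
    return -(-value // multiple) * multiple


def padded_rowwise_scale_shape(rows, cols):
    """Hypothetical helper: padded shape of the row-wise scale tensor."""
    return (round_up(rows, ROW_ALIGN),
            round_up(math.ceil(cols / MXFP8_BLOCK), COL_ALIGN))


def data_shape_inferred_from_scales(scale_shape):
    """Hypothetical helper: data shape reconstructed from the scale shape alone."""
    return (scale_shape[0], scale_shape[1] * MXFP8_BLOCK)


for name, shape in (("FC1", (5760, 2880)), ("FC2", (2880, 2880))):
    scales = padded_rowwise_scale_shape(*shape)
    print(name, "data", shape, "scales", scales,
          "inferred", data_shape_inferred_from_scales(scales))
```

Under these assumptions, 2880 columns need 90 scale blocks, which pad to 92, so the reconstructed width is 2944 rather than 2880. Two scale columns per row (plus 64 padded rows for FC2) would then be treated as real data.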

Fixes #

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fix MXFP8 scale padding being read for non-128-aligned dimensions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

MXFP8 scale tensors are padded to GEMM tile boundaries. The pointer swizzle
helper inferred the logical data shape from the padded scale shape, so padded
E8M0 values could be treated as valid scale blocks and propagated into grouped
MLP GEMM scale buffers.

Accept an optional actual data shape for MXFP8 pointer swizzle and pass the
weight shapes from fused grouped MLP rowwise and columnwise paths. Add a
regression test that poisons padded scale entries with E8M0 NaNs and checks
that the swizzle masks them out.

Signed-off-by: Jeremi Piotrowski <jpiotrowski@nvidia.com>
@jepio jepio force-pushed the jpiotrowski/mxfp8-swizzle-data-shape-fix branch from ce3d904 to 044ac53 Compare June 1, 2026 13:59
@greptile-apps
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR fixes a NaN-loss regression in MXFP8 grouped MLP training for non-128-aligned dimensions: transform_and_copy_data_ptrs_to_device was inferring the data shape from the padded scale tensor shape, causing padded E8M0 bytes to be treated as valid scale values and passed to the GEMM kernel. The fix threads the actual unpadded weight shape through to the C++ swizzle function so it can correctly mask out those padding blocks.

  • Adds an optional actual_data_shape parameter to transform_and_copy_data_ptrs_to_device (defaulting to empty so existing call sites remain valid) and updates all four MXFP8 grouped-MLP call sites (two forward, two backward) to supply it.
  • Adds a focused regression test that poisons padded scale entries with E8M0 NaNs and asserts they are masked out after the swizzle, directly validating the fix.

Confidence Score: 4/5

The fix is targeted and well-tested; the only rough edge is a missing positive-value guard on the new int64_t to size_t cast in the C++ core.

All four production call sites are updated consistently, the default empty-vector preserves backwards compatibility, and the new test directly exercises the failure mode described in the bug report. The sole concern is the unguarded cast of actual_data_shape values from int64_t to size_t: a negative dimension would silently wrap to a huge size_t and be forwarded to the GPU kernel with no error.

transformer_engine/pytorch/csrc/extensions/utils.cpp — the new int64_t to size_t cast lacks a non-negative check.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/utils.cpp | Core fix: actual_data_shape parameter added and used to override the previously incorrect padded-scale inference; no negative-value guard on the cast from int64_t to size_t |
| transformer_engine/pytorch/csrc/extensions.h | Declaration updated to add actual_data_shape with a default empty-vector, matching the implementation change |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | pybind11 registration updated with matching default std::vector<int64_t>{} for actual_data_shape; backwards-compatible with existing call sites |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Both FC1 and FC2 forward calls now pass fc1_weight_shape / fc2_weight_shape as actual_data_shape, preventing padded scale bytes from entering the GEMM |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | FC2 dactivation and FC1 dgrad backward paths both updated to supply the actual weight shape, mirroring the forward fix |
| tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py | Adds regression test test_mxfp8_pointer_swizzle_uses_unpadded_data_shape that poisons padded scale bytes and verifies they are masked out by the fix |

Sequence Diagram

sequenceDiagram
    participant PY as Python MLP
    participant CPP as transform_and_copy_data_ptrs_to_device
    participant GPU as CUDA swizzle kernel
    PY->>CPP: "scale_inv tensors + actual_data_shape=(N,K)"
    alt actual_data_shape provided
        CPP->>CPP: "data_shape = (N, K)"
    else fallback
        CPP->>CPP: "data_shape = (N_padded, K_padded)"
    end
    CPP->>GPU: swizzle with masked padding
    GPU-->>CPP: swizzled buffer
    CPP-->>PY: ptrs + swizzled buffer
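
The two branches of the diagram can be sketched as a small shape-selection function. This is an illustration in Python rather than the C++ from the PR: `resolve_data_shape` and `count_valid_scale_blocks` are hypothetical names, and the padded scale shape (2944, 92) for the FC2 weight assumes 128 x 4 scale alignment.

```python
MXFP8_BLOCK = 32  # elements covered by one E8M0 scale


def resolve_data_shape(padded_scale_shape, actual_data_shape=()):
    """Hypothetical mirror of the optional-argument logic in the diagram."""
    if actual_data_shape:
        # Explicit shape supplied by the caller: trust it.
        if len(actual_data_shape) != 2:
            raise ValueError(f"Expected 2D data shape, got {len(actual_data_shape)} dimensions")
        return tuple(actual_data_shape)
    # Fallback: reconstruct the shape from the padded scale tensor.
    rows, col_blocks = padded_scale_shape
    return (rows, col_blocks * MXFP8_BLOCK)


def count_valid_scale_blocks(data_shape):
    """Number of scale entries treated as real for a given data shape."""
    rows, cols = data_shape
    return rows * -(-cols // MXFP8_BLOCK)


fc2_scale_shape = (2944, 92)  # ASSUMED padded scale shape for a (2880, 2880) weight
fallback = resolve_data_shape(fc2_scale_shape)
explicit = resolve_data_shape(fc2_scale_shape, (2880, 2880))
extra = count_valid_scale_blocks(fallback) - count_valid_scale_blocks(explicit)
print(fallback, explicit, extra)
```

Keeping the empty default means existing callers still take the fallback branch, which is why the call sites had to be updated one by one.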

Reviews (1): Last reviewed commit: "[PyTorch] Use data shape for MXFP8 point..."

Comment on lines +100 to +103
NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ",
actual_data_shape.size(), " dimensions.");
data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]);
data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]);
Contributor

P2 The actual_data_shape values are int64_t but are cast directly to size_t without a non-negative guard. A negative dimension (e.g., from a programming mistake at a call site) would silently wrap around to a very large size_t, which would then be forwarded to the GPU swizzle kernel as the data extents, potentially causing out-of-bounds memory accesses or incorrect masking.

Suggested change
- NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ",
- actual_data_shape.size(), " dimensions.");
- data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]);
- data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]);
+ NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ",
+ actual_data_shape.size(), " dimensions.");
+ NVTE_CHECK(actual_data_shape[0] > 0 && actual_data_shape[1] > 0,
+ "Expected positive data shape dimensions, but got [", actual_data_shape[0], ", ",
+ actual_data_shape[1], "].");
+ data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]);
+ data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]);
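
The wrap-around this comment warns about can be reproduced outside C++. The following is a rough Python model, assuming a 64-bit size_t; `cast_int64_to_size_t` and `checked_data_shape` are hypothetical names that imitate the unguarded cast and the suggested check, not functions from the PR.

```python
SIZE_T_BITS = 64  # ASSUMPTION: size_t is 64 bits wide on the target platform


def cast_int64_to_size_t(value):
    """Model static_cast<size_t>(int64_t): conversion modulo 2**64."""
    return value % (1 << SIZE_T_BITS)


def checked_data_shape(actual_data_shape):
    """Imitate the suggested guard: reject bad shapes before casting."""
    if len(actual_data_shape) != 2:
        raise ValueError(f"Expected 2D data shape, but got {len(actual_data_shape)} dimensions.")
    if any(dim <= 0 for dim in actual_data_shape):
        raise ValueError(f"Expected positive data shape dimensions, but got {list(actual_data_shape)}.")
    return tuple(cast_int64_to_size_t(dim) for dim in actual_data_shape)


print(cast_int64_to_size_t(-1))          # huge extent instead of an error
print(checked_data_shape((2880, 2880)))  # valid shapes pass through unchanged
```

With the guard in place, a shape such as (-1, 2880) raises an error at the call boundary instead of reaching the kernel as an enormous extent.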

# Poison padded scale values with E8M0 NaNs. If swizzle reconstructs the
# data shape from the padded scale shape, these values are considered real
# and survive.
scale = poison_mxfp8_scale_padding_with_nan(scale, valid_scale_shape)
Collaborator

@vthumbe1503 vthumbe1503 Jun 1, 2026

Why are the scales getting padded with poison values in the first place?

Member Author

@jepio jepio Jun 1, 2026

@vthumbe1503 I think that part of the scales is normally uninitialized. I poison the
buffer in the test to illustrate what would happen if the uninitialized
values were actually read and included in the computation. This is what
seems to happen in actual training.
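
The idea behind the poisoning test can be sketched on a toy buffer. This is not the regression test from the PR: the helper names are hypothetical, the scales are kept in plain row-major order rather than the swizzled layout, the toy shapes do not use the real alignment, and "masked out" is assumed to mean "overwritten with zero".

```python
import math

E8M0_NAN = 0xFF    # E8M0 reserves the all-ones byte for NaN
MXFP8_BLOCK = 32   # elements covered by one scale


def poison_padding(scales, padded_shape, valid_shape):
    """Return a copy with every entry outside valid_shape set to E8M0 NaN."""
    rows, cols = padded_shape
    valid_rows, valid_cols = valid_shape
    out = bytearray(scales)
    for r in range(rows):
        for c in range(cols):
            if r >= valid_rows or c >= valid_cols:
                out[r * cols + c] = E8M0_NAN
    return out


def mask_padding(scales, padded_shape, data_shape):
    """Zero every scale that does not correspond to a block of real data."""
    valid_rows = data_shape[0]
    valid_cols = math.ceil(data_shape[1] / MXFP8_BLOCK)
    rows, cols = padded_shape
    out = bytearray(scales)
    for r in range(rows):
        for c in range(cols):
            if r >= valid_rows or c >= valid_cols:
                out[r * cols + c] = 0
    return out


# Toy sizes: 4 x 96 data (3 scale blocks per row), scales padded to 8 x 4.
data_shape, padded_shape, valid_shape = (4, 96), (8, 4), (4, 3)
scales = bytearray([127] * (8 * 4))  # 127 encodes 2**0 in E8M0
poisoned = poison_padding(scales, padded_shape, valid_shape)

# Shape reconstructed from the padded scales versus the actual data shape.
inferred_shape = (padded_shape[0], padded_shape[1] * MXFP8_BLOCK)
with_fallback = mask_padding(poisoned, padded_shape, inferred_shape)
with_fix = mask_padding(poisoned, padded_shape, data_shape)

print(with_fallback.count(E8M0_NAN), with_fix.count(E8M0_NAN))
```

When the shape is inferred from the padded scales, every entry counts as valid and all poisoned bytes survive. With the actual data shape, none do, which is the property the regression test asserts.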

@jepio

This comment was marked as duplicate.

Member

@timmoon10 timmoon10 left a comment

Thanks for the bugfix. One of the goals of #3001 was to generalize the helper functions and avoid hard-coding swizzle-specific logic in places where it wasn't needed. However, it seems that this is probably unavoidable.

I'll take a crack at fixing up the function interface. While we're at it, this gives us the opportunity to consolidate the copy-to-device kernel launches for the data pointers and scale pointers.

Comment on lines 41 to +43
  std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_copy_data_ptrs_to_device(
  const std::string &transform_type, const std::vector<at::Tensor> &tensors,
- const c10::Device &device) {
+ const c10::Device &device, const std::vector<int64_t> &actual_data_shape) {
Member

@timmoon10 timmoon10 Jun 1, 2026

transform_and_copy_data_ptrs_to_device is named the way it is so that it can be a general-purpose function. You could imagine, for instance, using it for things like SReLU or dtype casts. In practice, though, we only use it for swizzling uniformly-sized MXFP8 and NVFP4 scales.

Hard-coding swizzle-specific logic like actual_data_shape into a general API is not great. But your bug report makes me realize that the general-purpose function is probably not possible, so we should give in and make this a swizzle-specific function. The biggest problem is not the implementation (easy), but that the function name will have to be something ugly like uniform_swizzle_scales_and_copy_data_ptrs_to_device.

@ptrendx
Member

ptrendx commented Jun 1, 2026

I'm having trouble understanding where the problem comes from here. Where is the padding of the weights happening? It seems that the padding happens only after the weight has already been cast to MXFP8, and that the weight and the scale are padded separately? Why are we not padding with 0s? That padding should happen only once.

@timmoon10
Member

#3076 fixes this bug.

@jepio jepio deleted the jpiotrowski/mxfp8-swizzle-data-shape-fix branch June 3, 2026 07:03

4 participants