[PyTorch] Use data shape for MXFP8 pointer swizzle#3070
Conversation
MXFP8 scale tensors are padded to GEMM tile boundaries. The pointer swizzle helper inferred the logical data shape from the padded scale shape, so padded E8M0 values could be treated as valid scale blocks and propagated into grouped MLP GEMM scale buffers. Accept an optional actual data shape for MXFP8 pointer swizzle and pass the weight shapes from fused grouped MLP rowwise and columnwise paths. Add a regression test that poisons padded scale entries with E8M0 NaNs and checks that the swizzle masks them out. Signed-off-by: Jeremi Piotrowski <jpiotrowski@nvidia.com>
ce3d904 to
044ac53
Compare
Greptile SummaryThis PR fixes a NaN-loss regression in MXFP8 grouped MLP training for non-128-aligned dimensions:
Confidence Score: 4/5The fix is targeted and well-tested; the only rough edge is a missing positive-value guard on the new int64_t to size_t cast in the C++ core. All four production call sites are updated consistently, the default empty-vector preserves backwards compatibility, and the new test directly exercises the failure mode described in the bug report. The sole concern is the unguarded cast of actual_data_shape values from int64_t to size_t: a negative dimension would silently wrap to a huge size_t and be forwarded to the GPU kernel with no error. transformer_engine/pytorch/csrc/extensions/utils.cpp — the new int64_t to size_t cast lacks a non-negative check. Important Files Changed
Sequence DiagramsequenceDiagram
participant PY as Python MLP
participant CPP as transform_and_copy_data_ptrs_to_device
participant GPU as CUDA swizzle kernel
PY->>CPP: "scale_inv tensors + actual_data_shape=(N,K)"
alt actual_data_shape provided
CPP->>CPP: "data_shape = (N, K)"
else fallback
CPP->>CPP: "data_shape = (N_padded, K_padded)"
end
CPP->>GPU: swizzle with masked padding
GPU-->>CPP: swizzled buffer
CPP-->>PY: ptrs + swizzled buffer
Reviews (1): Last reviewed commit: "[PyTorch] Use data shape for MXFP8 point..." | Re-trigger Greptile |
| NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ", | ||
| actual_data_shape.size(), " dimensions."); | ||
| data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]); | ||
| data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]); |
There was a problem hiding this comment.
The
actual_data_shape values are int64_t but are cast directly to size_t without a non-negative guard. A negative dimension (e.g., from a programming mistake at a call site) would silently overflow to a very large size_t, which would then be forwarded to the GPU swizzle kernel as the data extents, potentially causing out-of-bounds memory accesses or incorrect masking.
| NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ", | |
| actual_data_shape.size(), " dimensions."); | |
| data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]); | |
| data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]); | |
| NVTE_CHECK(actual_data_shape.size() == 2, "Expected 2D data shape, but got ", | |
| actual_data_shape.size(), " dimensions."); | |
| NVTE_CHECK(actual_data_shape[0] > 0 && actual_data_shape[1] > 0, | |
| "Expected positive data shape dimensions, but got [", actual_data_shape[0], ", ", | |
| actual_data_shape[1], "]."); | |
| data_shape.data[0] = static_cast<size_t>(actual_data_shape[0]); | |
| data_shape.data[1] = static_cast<size_t>(actual_data_shape[1]); |
| # Poison padded scale values with E8M0 NaNs. If swizzle reconstructs the | ||
| # data shape from the padded scale shape, these values are considered real | ||
| # and survive. | ||
| scale = poison_mxfp8_scale_padding_with_nan(scale, valid_scale_shape) |
There was a problem hiding this comment.
Why is the scales getting padded with poision values in the first place?
There was a problem hiding this comment.
@vthumbe1503 I think that part of the scales is normally uninitialized, I poison the
buffer in the test to illustrate what would happen if the uninitialized
values are actually read and included in the computation. This is what
seems to happen in actual training.
This comment was marked as duplicate.
This comment was marked as duplicate.
There was a problem hiding this comment.
Thanks for the bugfix. One of the goals of #3001 was to generalize the helper functions and avoid hard-coding swizzle-specific logic in places where it wasn't needed. However, it seems that this is probably unavoidable.
I'll take a crack at fixing up the function interface. While we're at it, this gives us the opportunity to consolidate the copy-to-device kernel launches for the data pointers and scale pointers.
| std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_copy_data_ptrs_to_device( | ||
| const std::string &transform_type, const std::vector<at::Tensor> &tensors, | ||
| const c10::Device &device) { | ||
| const c10::Device &device, const std::vector<int64_t> &actual_data_shape) { |
There was a problem hiding this comment.
transform_and_copy_data_ptrs_to_device is named the way it is so that it can be a general-purpose function. You could imagine, for instance, using it for things like SReLU or dtype casts. In practice, though, we only use it for swizzling uniformly-sized MXFP8 and NVFP4 scales.
Hard-coding swizzle-specific logic like actual_data_shape into a general API is not great. But your bug report makes me realize that the general-purpose function is probably not possible, so we should give in and make this a swizzle-specific function. The biggest problem is not the implementation (easy), but that the function name will have to be something ugly like uniform_swizzle_scales_and_copy_data_ptrs_to_device.
|
I have some problem understanding where the problem comes from here. Where is the padding of the weights happening? It seems that the padding happens only after the weight was already cast to MXFP8 and the weight and the scale are padded separately? Why are we not padding with 0s - that padding should happen only once? |
|
#3076 fixes this bug. |
Description
We've started seeing NaN loss in gpt_oss_20b training using MXFP8. gpt_oss_20b uses expert MLP dimensions based on hidden size 2880, producing FC1 weights of (5760, 2880) and FC2 weights of (2880, 2880). I bisected this to 9e5a847. That change is not safe in the presence of MXFP8 scale padding: without the actual unpadded data shape,
transform_and_copy_data_ptrs_to_deviceinfers the data shape from the padded scale shape, so padded E8M0 scale bytes can be copied into the swizzled GEMM scale buffer and consumed by GEMM, causing NaNs. Fix it by passing the actual data shape to transform_and_copy_data_ptrs_to_device. The patch includes a regression test that poisons padded scale entries and verifies they are masked out.Fixes #
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: