From 8e06855a68c52fb37a360e9caa55f51f5c89a7bc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 1 Jun 2026 23:06:51 +0000 Subject: [PATCH 01/12] [PyTorch] Move transform_and_copy_data_ptrs_to_device to experimental submodule Introduces a `grouped_mlp_experimental` pybind submodule as a labeled home for hyperspecific helpers that exist to satisfy the cuDNN CuTe DSL grouped GEMM kernels. The submodule itself is documented as unstable, so callers can see at the import path that these helpers are not part of the supported surface. `copy_data_ptrs_to_device` is genuinely general-purpose and stays at the top level; only `transform_and_copy_data_ptrs_to_device` moves into the submodule, and its four call sites in the fused grouped MLP forward/backward are updated accordingly. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/pybind.cpp | 17 ++++++++++++---- .../pytorch/ops/fused/backward_grouped_mlp.py | 20 +++++++++++-------- .../pytorch/ops/fused/forward_grouped_mlp.py | 20 +++++++++++-------- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c1b38a2275..944ba9b233 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -494,10 +494,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("copy_data_ptrs_to_device", &transformer_engine::pytorch::copy_data_ptrs_to_device, py::arg("tensors"), py::arg("device"), py::call_guard()); - m.def("transform_and_copy_data_ptrs_to_device", - &transformer_engine::pytorch::transform_and_copy_data_ptrs_to_device, - py::arg("transform_type"), py::arg("tensors"), py::arg("device"), - py::call_guard()); m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); @@ -607,6 +603,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("allgather_communicator"), py::arg("send_stream"), py::arg("recv_stream")); + // Helpers for experimental fused grouped MLP + // + // These are intended for compatibility with cuDNN CuTe DSL grouped + // GEMM kernels. Since those are unstable and under active + // development, these helpers should also be considered unstable. + auto grouped_mlp_experimental = m.def_submodule( + "grouped_mlp_experimental", + "Experimental helpers for the fused grouped MLP (unstable, may change or disappear)."); + grouped_mlp_experimental.def("transform_and_copy_data_ptrs_to_device", + &transformer_engine::pytorch::transform_and_copy_data_ptrs_to_device, + py::arg("transform_type"), py::arg("tensors"), py::arg("device"), + py::call_guard()); + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 792b6d7811..0f06c96cfa 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -656,10 +656,12 @@ def fuser_backward( swizzle_type = ( "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" ) - fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._columnwise_scale_inv for w in grouped_fc2_weight], - device, + fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._columnwise_scale_inv for w in grouped_fc2_weight], + device, + ) ) fc2_dactivation_kwargs["b_ptrs"] = fc2_b_ptrs fc2_dactivation_kwargs["sfb_ptrs"] = fc2_sfb_ptrs @@ -918,10 +920,12 @@ def fuser_backward( swizzle_type = ( "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" ) - fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._columnwise_scale_inv for w in grouped_fc1_weight], - device, + fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._columnwise_scale_inv for w in grouped_fc1_weight], + device, + ) ) fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f4f2108578..9a2113a8ba 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -454,10 +454,12 @@ def fuser_forward( device, ) swizzle_type = "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" - fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._rowwise_scale_inv for w in grouped_fc1_weight], - device, + fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._rowwise_scale_inv for w in grouped_fc1_weight], + device, + ) ) fc1_activation_kwargs["b_ptrs"] = fc1_b_ptrs fc1_activation_kwargs["sfb_ptrs"] = fc1_sfb_ptrs @@ -632,10 +634,12 @@ def fuser_forward( swizzle_type = ( "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" ) - fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._rowwise_scale_inv for w in grouped_fc2_weight], - device, + fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._rowwise_scale_inv for w in grouped_fc2_weight], + device, + ) ) fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs From 1f42b3ffbe75b2363d03ed59feedaaaa8f608575 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 1 Jun 2026 23:53:06 +0000 Subject: [PATCH 02/12] [PyTorch] Fuse data + scale ptr packing for discrete-weight grouped MLP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces `transform_and_copy_data_ptrs_to_device` with a more focused helper, `swizzle_scales_and_pack_ptrs_for_discrete_weights`. The new function takes both the FP8/FP4 weight data tensors and their scale tensors, swizzles the scales, and copies both pointer arrays to device in a single kernel launch (down from two — one from `copy_data_ptrs_to_device` for data and one from the old transform helper for scales). The two returned pointer arrays are views into a single packed device buffer. The general "transform_type" string dispatch is gone: the function only supports `mxfp8_rowwise`, `mxfp8_columnwise`, and `nvfp4`, which were the only modes ever used. The four discrete-weight call sites in the fused grouped MLP forward/backward collapse their paired `copy_data_ptrs_to_device` + transform calls into a single call. The implementation moves to a dedicated source file, `csrc/extensions/grouped_mlp_experimental.cpp`, so the experimental submodule has a clear home for future helpers tied to the cuDNN CuTe DSL grouped GEMM kernels. The declaration in `extensions.h` is grouped under a matching banner. `copy_data_ptrs_to_device` stays in `utils.cpp` since it remains a general-purpose helper. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/extensions.h | 18 ++- .../extensions/grouped_mlp_experimental.cpp | 138 +++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 13 +- .../pytorch/csrc/extensions/utils.cpp | 140 ------------------ .../pytorch/ops/fused/backward_grouped_mlp.py | 28 +--- .../pytorch/ops/fused/forward_grouped_mlp.py | 26 +--- 6 files changed, 175 insertions(+), 188 deletions(-) create mode 100644 transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1f9974448b..b0bea69c88 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -490,8 +490,22 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, const c10::Device &device); -std::tuple> transform_and_copy_data_ptrs_to_device( - const std::string &transform_type, const std::vector &tensors, +/*************************************************************************************************** + * Experimental helpers for the fused grouped MLP + * + * These primarily exist to support cuDNN CuTe DSL grouped GEMM + * kernels. Since those are unstable and under active development, + * these helpers should also be considered unstable. + **************************************************************************************************/ + +// Prepare discrete weight tensors for the cuDNN CuTe DSL grouped GEMM +// kernel by swizzling scales and copying data and scale pointers to +// device. All tensors must share a uniform shape and `format` must be +// one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4". Returns +// {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}. +std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( + const std::vector &data_tensors, + const std::vector &scale_tensors, const std::string &format, const c10::Device &device); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp new file mode 100644 index 0000000000..5aac02f6ef --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -0,0 +1,138 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +// Experimental helpers for the fused grouped MLP. + +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "extensions.h" + +namespace transformer_engine::pytorch { + +std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( + const std::vector &data_tensors, + const std::vector &scale_tensors, const std::string &format, + const c10::Device &device) { + const size_t num_tensors = data_tensors.size(); + NVTE_CHECK(scale_tensors.size() == num_tensors, + "Expected data_tensors and scale_tensors to have matching sizes, but got ", + num_tensors, " and ", scale_tensors.size(), "."); + + // Decode format + const bool is_mxfp8_rowwise = format == "mxfp8_rowwise"; + const bool is_mxfp8_columnwise = format == "mxfp8_columnwise"; + const bool is_nvfp4 = format == "nvfp4"; + NVTE_CHECK(is_mxfp8_rowwise || is_mxfp8_columnwise || is_nvfp4, "Unsupported format (", format, + "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); + + // Trivial case: no tensors. Return three empty tensors. + if (num_tensors == 0) { + auto empty_ptrs = at::empty({0}, at::TensorOptions().dtype(at::kLong).device(device)); + auto empty_scales = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device)); + return {empty_ptrs, empty_ptrs.clone(), std::move(empty_scales)}; + } + + // CUDA stream + auto stream = at::cuda::getCurrentCUDAStream(); + + // Tensor format + const NVTEScalingMode scaling_mode = is_nvfp4 ? NVTE_NVFP4_1D_SCALING : NVTE_MXFP8_1D_SCALING; + const transformer_engine::DType data_dtype = + is_nvfp4 ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const transformer_engine::DType scale_dtype = + is_nvfp4 ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + + // Scale shape + const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes()); + NVTE_CHECK(scale_shape.ndim == 2, + "Expected 2D scale tensor, but got shape=", getTensorShape(scale_tensors[0]), "."); + const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1]; + const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); + const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); + + // Expected data shape. + // Note: May not match actual data shape since the scales are padded. + // This is fine since the swizzle kernel does not touch the data. + NVTEShape data_shape; + data_shape.ndim = 2; + if (is_mxfp8_rowwise) { + data_shape.data[0] = scale_shape.data[0]; + data_shape.data[1] = scale_shape.data[1] * 32; + } else if (is_mxfp8_columnwise) { + data_shape.data[0] = scale_shape.data[0] * 32; + data_shape.data[1] = scale_shape.data[1]; + } else { // nvfp4 + data_shape.data[0] = scale_shape.data[0]; + data_shape.data[1] = scale_shape.data[1] * 16; + } + + // Allocate single buffer for swizzled scales. Uses a uniform stride since + // all tensors share the same scale shape. + const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes + auto swizzled_scales = at::empty({static_cast(swizzled_scales_stride * num_tensors)}, + at::TensorOptions().dtype(at::kByte).device(device)); + uint8_t *swizzled_scales_dptr = reinterpret_cast(swizzled_scales.data_ptr()); + + // Allocate input/output NVTETensors as a single batch. The first + // num_tensors entries are inputs; the next num_tensors are outputs. + MultiTensorWrapper nvte_tensors(2 * num_tensors, scaling_mode); + NVTETensor *inputs_nvte = nvte_tensors.data(); + NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors; + + auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, + transformer_engine::DType dtype, const NVTEShape &shape) { + NVTEBasicTensor data{dptr, static_cast(dtype), shape}; + nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); + }; + + // MXFP8 columnwise tags its data/scale params on the columnwise side; the + // other two formats (mxfp8_rowwise, nvfp4) tag on the rowwise side. + const NVTETensorParam data_param = is_mxfp8_columnwise ? kNVTEColumnwiseData : kNVTERowwiseData; + const NVTETensorParam scale_param = + is_mxfp8_columnwise ? kNVTEColumnwiseScaleInv : kNVTERowwiseScaleInv; + + for (size_t i = 0; i < num_tensors; ++i) { + const uint8_t swizzled_flag = 1; + nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, + sizeof(swizzled_flag)); + void *in_scale_ptr = scale_tensors[i].data_ptr(); + void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; + set_param(inputs_nvte[i], data_param, nullptr, data_dtype, data_shape); + set_param(inputs_nvte[i], scale_param, in_scale_ptr, scale_dtype, scale_shape); + set_param(outputs_nvte[i], data_param, nullptr, data_dtype, data_shape); + set_param(outputs_nvte[i], scale_param, out_scale_ptr, scale_dtype, scale_shape); + } + + // Launch swizzle kernel + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream); + + // Pack data pointers (first half) and swizzled scale pointers (second half) + // into a single host buffer and copy to device with one kernel launch. + std::vector packed_ptrs_host(2 * num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + packed_ptrs_host[i] = reinterpret_cast(data_tensors[i].data_ptr()); + packed_ptrs_host[num_tensors + i] = + reinterpret_cast(swizzled_scales_dptr + i * swizzled_scales_stride); + } + auto packed_ptrs_device = at::empty({static_cast(2 * num_tensors)}, + at::TensorOptions().dtype(at::kLong).device(device)); + nvte_copy_host_to_device_via_kernel(packed_ptrs_host.data(), packed_ptrs_device.data_ptr(), + 2 * num_tensors * sizeof(uint64_t), stream); + + // Return the two pointer arrays as views into the packed device buffer. + auto data_ptrs = packed_ptrs_device.narrow(0, 0, static_cast(num_tensors)); + auto scale_ptrs = packed_ptrs_device.narrow(0, static_cast(num_tensors), + static_cast(num_tensors)); + return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 944ba9b233..3029b22dda 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -603,17 +603,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("allgather_communicator"), py::arg("send_stream"), py::arg("recv_stream")); - // Helpers for experimental fused grouped MLP - // - // These are intended for compatibility with cuDNN CuTe DSL grouped - // GEMM kernels. Since those are unstable and under active - // development, these helpers should also be considered unstable. + // Experimental fused grouped MLP auto grouped_mlp_experimental = m.def_submodule( "grouped_mlp_experimental", "Experimental helpers for the fused grouped MLP (unstable, may change or disappear)."); - grouped_mlp_experimental.def("transform_and_copy_data_ptrs_to_device", - &transformer_engine::pytorch::transform_and_copy_data_ptrs_to_device, - py::arg("transform_type"), py::arg("tensors"), py::arg("device"), + grouped_mlp_experimental.def("swizzle_scales_and_pack_ptrs_for_discrete_weights", + &transformer_engine::pytorch::swizzle_scales_and_pack_ptrs_for_discrete_weights, + py::arg("data_tensors"), py::arg("scale_tensors"), + py::arg("format"), py::arg("device"), py::call_guard()); // Data structures diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp index 453f238c0d..228c831ec0 100644 --- a/transformer_engine/pytorch/csrc/extensions/utils.cpp +++ b/transformer_engine/pytorch/csrc/extensions/utils.cpp @@ -6,10 +6,6 @@ #include -#include -#include -#include -#include #include #include "common/common.h" @@ -38,140 +34,4 @@ at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, return ptrs_device; } -std::tuple> transform_and_copy_data_ptrs_to_device( - const std::string &transform_type, const std::vector &tensors, - const c10::Device &device) { - const size_t num_tensors = tensors.size(); - - // Trivial cases - if (transform_type.empty()) { - // No transform, just load pointers on device - return {copy_data_ptrs_to_device(tensors, device), std::nullopt}; - } - if (num_tensors == 0) { - // No input tensors, return tensor with no elements - return {at::empty({int64_t{0}}, at::TensorOptions().dtype(at::kLong).device(device)), - std::nullopt}; - } - - // CUDA stream - auto stream = at::cuda::getCurrentCUDAStream(); - - // Swizzle scales for GEMM, with uniform tensor sizes - const bool uniform_mxfp8_rowwise_swizzle = transform_type == "uniform_mxfp8_rowwise_swizzle"; - const bool uniform_mxfp8_colwise_swizzle = transform_type == "uniform_mxfp8_columnwise_swizzle"; - const bool uniform_nvfp4_swizzle = transform_type == "uniform_nvfp4_swizzle"; - if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle || uniform_nvfp4_swizzle) { - // Tensor format - NVTEScalingMode scaling_mode = NVTE_INVALID_SCALING; - if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle) { - scaling_mode = NVTE_MXFP8_1D_SCALING; - } else if (uniform_nvfp4_swizzle) { - scaling_mode = NVTE_NVFP4_1D_SCALING; - } - - // Data types - transformer_engine::DType data_dtype, scale_dtype; - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - data_dtype = transformer_engine::DType::kFloat8E4M3; - scale_dtype = transformer_engine::DType::kFloat8E8M0; - break; - case NVTE_NVFP4_1D_SCALING: - data_dtype = transformer_engine::DType::kFloat4E2M1; - scale_dtype = transformer_engine::DType::kFloat8E4M3; - break; - default: - NVTE_ERROR("Unsupported case."); - } - - // Scale shape - const NVTEShape scale_shape = convertTorchShape(tensors[0].sizes()); - NVTE_CHECK(scale_shape.ndim == 2, - "Expected 2D scale tensor, but got shape=", getTensorShape(tensors[0]), "."); - const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1]; - const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); - const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); - - // Expected data shape - // Note: May not match actual data shape since the scales are padded. - // This is fine since we're not actually touching the data. - NVTEShape data_shape; - data_shape.ndim = 2; - if (uniform_mxfp8_rowwise_swizzle) { - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 32; - } else if (uniform_mxfp8_colwise_swizzle) { - data_shape.data[0] = scale_shape.data[0] * 32; - data_shape.data[1] = scale_shape.data[1]; - } else if (uniform_nvfp4_swizzle) { - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 16; - } else { - NVTE_ERROR("Unsupported case."); - } - - // Allocate single buffer for swizzled scales. - // Uses a uniform stride since all tensors share the same scale shape. - const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes - auto swizzled_scales = at::empty({static_cast(swizzled_scales_stride * num_tensors)}, - at::TensorOptions().dtype(at::kByte).device(device)); - uint8_t *swizzled_scales_dptr = reinterpret_cast(swizzled_scales.data_ptr()); - - // Allocate input/output NVTETensors as a single batch. The first - // num_tensors entries are inputs; the next num_tensors are outputs. - MultiTensorWrapper nvte_tensors(2 * num_tensors, scaling_mode); - NVTETensor *inputs_nvte = nvte_tensors.data(); - NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors; - - auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, - transformer_engine::DType dtype, const NVTEShape &shape) { - NVTEBasicTensor data{dptr, static_cast(dtype), shape}; - nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); - }; - - for (size_t i = 0; i < num_tensors; ++i) { - const uint8_t swizzled_flag = 1; - nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, - sizeof(swizzled_flag)); - void *in_scale_ptr = tensors[i].data_ptr(); - void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; - if (uniform_mxfp8_rowwise_swizzle || uniform_nvfp4_swizzle) { - set_param(inputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); - set_param(inputs_nvte[i], kNVTERowwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); - set_param(outputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); - set_param(outputs_nvte[i], kNVTERowwiseScaleInv, out_scale_ptr, scale_dtype, scale_shape); - } else if (uniform_mxfp8_colwise_swizzle) { - set_param(inputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); - set_param(inputs_nvte[i], kNVTEColumnwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); - set_param(outputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); - set_param(outputs_nvte[i], kNVTEColumnwiseScaleInv, out_scale_ptr, scale_dtype, - scale_shape); - } - } - - // Launch kernel - nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream); - - // Collect data pointers - std::vector ptrs_host; - ptrs_host.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - ptrs_host.push_back( - reinterpret_cast(swizzled_scales_dptr + i * swizzled_scales_stride)); - } - - // Load pointers on device - auto ptrs_device = at::empty({static_cast(num_tensors)}, - at::TensorOptions().dtype(at::kLong).device(device)); - nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), - num_tensors * sizeof(uint64_t), stream); - - return {std::move(ptrs_device), std::move(swizzled_scales)}; - } - - // Unsupported transform - NVTE_ERROR("Unsupported transform type (", transform_type, ")"); -} - } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 0f06c96cfa..3e4b898ee4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -649,17 +649,11 @@ def fuser_backward( fc2_dactivation_kwargs["b_tensor"] = fc2_w_data fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs = tex.copy_data_ptrs_to_device( - [w._columnwise_data for w in grouped_fc2_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" - ) - fc2_sfb_ptrs, _fc2_sfb_buffer = ( - tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( - swizzle_type, + fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._columnwise_data for w in grouped_fc2_weight], [w._columnwise_scale_inv for w in grouped_fc2_weight], + "nvfp4" if use_nvfp4 else "mxfp8_columnwise", device, ) ) @@ -913,17 +907,11 @@ def fuser_backward( fc1_dgrad_kwargs["b_tensor"] = fc1_w_data fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales else: - fc1_b_ptrs = tex.copy_data_ptrs_to_device( - [w._columnwise_data for w in grouped_fc1_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" - ) - fc1_sfb_ptrs, _fc1_sfb_buffer = ( - tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( - swizzle_type, + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._columnwise_data for w in grouped_fc1_weight], [w._columnwise_scale_inv for w in grouped_fc1_weight], + "nvfp4" if use_nvfp4 else "mxfp8_columnwise", device, ) ) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 9a2113a8ba..ef9319918f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -449,15 +449,11 @@ def fuser_forward( fc1_activation_kwargs["sfb_tensor"] = fc1_w_scales else: # Discrete-weight kernel: per-expert data/scale pointers - fc1_b_ptrs = tex.copy_data_ptrs_to_device( - [w._rowwise_data for w in grouped_fc1_weight], - device, - ) - swizzle_type = "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" - fc1_sfb_ptrs, _fc1_sfb_buffer = ( - tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( - swizzle_type, + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._rowwise_data for w in grouped_fc1_weight], [w._rowwise_scale_inv for w in grouped_fc1_weight], + "nvfp4" if use_nvfp4 else "mxfp8_rowwise", device, ) ) @@ -627,17 +623,11 @@ def fuser_forward( fc2_quant_kwargs["b_tensor"] = fc2_w_data fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs = tex.copy_data_ptrs_to_device( - [w._rowwise_data for w in grouped_fc2_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" - ) - fc2_sfb_ptrs, _fc2_sfb_buffer = ( - tex.grouped_mlp_experimental.transform_and_copy_data_ptrs_to_device( - swizzle_type, + fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._rowwise_data for w in grouped_fc2_weight], [w._rowwise_scale_inv for w in grouped_fc2_weight], + "nvfp4" if use_nvfp4 else "mxfp8_rowwise", device, ) ) From 80fcdcfac13b7bf31be70dd60fb77e665184ce0a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 00:38:46 +0000 Subject: [PATCH 03/12] [PyTorch] Use real data shapes and dispatch on a format enum in swizzle helper In swizzle_scales_and_pack_ptrs_for_discrete_weights, take the data shape directly from the data tensors instead of inferring it from the padded scale shape. NVFP4 packs two 4-bit values per byte, so the byte-shape's inner dim is doubled to recover the logical element count. Also replace the trio of is_mxfp8_rowwise / is_mxfp8_columnwise / is_nvfp4 booleans with a function-local TensorFormat enum. Tensor properties (scaling mode, dtypes, swizzle param names) are assigned together per case in a single switch so adding a future format is a single-point change rather than a fresh boolean threaded through the function. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- .../extensions/grouped_mlp_experimental.cpp | 94 +++++++++++-------- 1 file changed, 54 insertions(+), 40 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp index 5aac02f6ef..b89ebefea5 100644 --- a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -20,21 +20,25 @@ namespace transformer_engine::pytorch { std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( const std::vector &data_tensors, - const std::vector &scale_tensors, const std::string &format, + const std::vector &scale_tensors, const std::string &format_str, const c10::Device &device) { const size_t num_tensors = data_tensors.size(); NVTE_CHECK(scale_tensors.size() == num_tensors, "Expected data_tensors and scale_tensors to have matching sizes, but got ", num_tensors, " and ", scale_tensors.size(), "."); - // Decode format - const bool is_mxfp8_rowwise = format == "mxfp8_rowwise"; - const bool is_mxfp8_columnwise = format == "mxfp8_columnwise"; - const bool is_nvfp4 = format == "nvfp4"; - NVTE_CHECK(is_mxfp8_rowwise || is_mxfp8_columnwise || is_nvfp4, "Unsupported format (", format, - "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); + // Parse tensor format + enum class TensorFormat { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 }; + TensorFormat format = TensorFormat::Invalid; + if (format_str == "mxfp8_rowwise") { format = TensorFormat::MXFP8Rowwise; } + else if (format_str == "mxfp8_columnwise") { format = TensorFormat::MXFP8Columnwise; } + else if (format_str == "nvfp4") { format = TensorFormat::NVFP4; } + else { + NVTE_ERROR("Unsupported format (", format_str, + "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); + } - // Trivial case: no tensors. Return three empty tensors. + // Trivial case: no tensors. Return empty tensors. if (num_tensors == 0) { auto empty_ptrs = at::empty({0}, at::TensorOptions().dtype(at::kLong).device(device)); auto empty_scales = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device)); @@ -44,12 +48,43 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); - // Tensor format - const NVTEScalingMode scaling_mode = is_nvfp4 ? NVTE_NVFP4_1D_SCALING : NVTE_MXFP8_1D_SCALING; - const transformer_engine::DType data_dtype = - is_nvfp4 ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; - const transformer_engine::DType scale_dtype = - is_nvfp4 ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + // Tensor properties + NVTEScalingMode scaling_mode; + transformer_engine::DType data_dtype, scale_dtype; + NVTETensorParam data_param_name, scale_param_name; + switch (format) { + case TensorFormat::MXFP8Rowwise: + case TensorFormat::MXFP8Columnwise: + scaling_mode = NVTE_MXFP8_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat8E4M3; + scale_dtype = transformer_engine::DType::kFloat8E8M0; + if (format == TensorFormat::MXFP8Rowwise) { + data_param_name = kNVTERowwiseData; + scale_param_name = kNVTERowwiseScaleInv; + } else { + data_param_name = kNVTEColumnwiseData; + scale_param_name = kNVTEColumnwiseScaleInv; + } + break; + case TensorFormat::NVFP4: + scaling_mode = NVTE_NVFP4_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat4E2M1; + scale_dtype = transformer_engine::DType::kFloat8E4M3; + data_param_name = kNVTERowwiseData; + scale_param_name = kNVTERowwiseScaleInv; + break; + default: + NVTE_ERROR("Unsupported format (", static_cast(format), ")."); + } + + // Data shape + NVTEShape data_shape = convertTorchShape(data_tensors[0].sizes()); + if (format == TensorFormat::NVFP4) { + // NVFP4 packs two 4-bit values per byte + NVTE_CHECK(data_shape.ndim > 0, + "Invalid shape for NVFP4 data tensor (", getTensorShape(data_tensors[0]), ")."); + data_shape.data[data_shape.ndim - 1] *= 2; + } // Scale shape const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes()); @@ -59,22 +94,6 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); - // Expected data shape. - // Note: May not match actual data shape since the scales are padded. - // This is fine since the swizzle kernel does not touch the data. - NVTEShape data_shape; - data_shape.ndim = 2; - if (is_mxfp8_rowwise) { - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 32; - } else if (is_mxfp8_columnwise) { - data_shape.data[0] = scale_shape.data[0] * 32; - data_shape.data[1] = scale_shape.data[1]; - } else { // nvfp4 - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 16; - } - // Allocate single buffer for swizzled scales. Uses a uniform stride since // all tensors share the same scale shape. const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes @@ -94,22 +113,17 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); }; - // MXFP8 columnwise tags its data/scale params on the columnwise side; the - // other two formats (mxfp8_rowwise, nvfp4) tag on the rowwise side. - const NVTETensorParam data_param = is_mxfp8_columnwise ? kNVTEColumnwiseData : kNVTERowwiseData; - const NVTETensorParam scale_param = - is_mxfp8_columnwise ? kNVTEColumnwiseScaleInv : kNVTERowwiseScaleInv; - + // Configure NVTETensors for (size_t i = 0; i < num_tensors; ++i) { const uint8_t swizzled_flag = 1; nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, sizeof(swizzled_flag)); void *in_scale_ptr = scale_tensors[i].data_ptr(); void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; - set_param(inputs_nvte[i], data_param, nullptr, data_dtype, data_shape); - set_param(inputs_nvte[i], scale_param, in_scale_ptr, scale_dtype, scale_shape); - set_param(outputs_nvte[i], data_param, nullptr, data_dtype, data_shape); - set_param(outputs_nvte[i], scale_param, out_scale_ptr, scale_dtype, scale_shape); + set_param(inputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape); + set_param(inputs_nvte[i], scale_param_name, in_scale_ptr, scale_dtype, scale_shape); + set_param(outputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape); + set_param(outputs_nvte[i], scale_param_name, out_scale_ptr, scale_dtype, scale_shape); } // Launch swizzle kernel From c03cb80c3a381040c601b5c71b85d275416c4313 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 00:40:02 +0000 Subject: [PATCH 04/12] [PyTorch] Consolidate utils.cpp into misc.cpp After moving the experimental grouped-MLP helper out, the only thing left in extensions/utils.cpp was copy_data_ptrs_to_device, which fits naturally alongside the cublasLt/cuDNN version getters and splits_to_offsets already in extensions/misc.cpp. Move it there and delete the now-empty utils.cpp. Build picks up sources via glob, so no manifest update is needed. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/misc.cpp | 26 +++++++++++++ .../pytorch/csrc/extensions/utils.cpp | 37 ------------------- 2 files changed, 26 insertions(+), 37 deletions(-) delete mode 100644 transformer_engine/pytorch/csrc/extensions/utils.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index c5707fa53c..c3bb6d64d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -4,6 +4,11 @@ * See LICENSE for license information. ************************************************************************/ +#include + +#include + +#include "common/common.h" #include "../extensions.h" namespace transformer_engine::pytorch { @@ -30,4 +35,25 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ return output; } +at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, + const c10::Device &device) { + // Collect data pointers + std::vector ptrs_host; + ptrs_host.reserve(tensors.size()); + for (const auto &tensor : tensors) { + ptrs_host.push_back(reinterpret_cast(tensor.data_ptr())); + } + + // Allocate device buffer + auto ptrs_device = at::empty({static_cast(tensors.size())}, + at::TensorOptions().dtype(at::kLong).device(device)); + + // Load pointers on device + nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), + tensors.size() * sizeof(uint64_t), + at::cuda::getCurrentCUDAStream()); + + return ptrs_device; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp deleted file mode 100644 index 228c831ec0..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/utils.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include - -#include - -#include "common/common.h" -#include "extensions.h" - -namespace transformer_engine::pytorch { - -at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, - const c10::Device &device) { - // Collect data pointers - std::vector ptrs_host; - ptrs_host.reserve(tensors.size()); - for (const auto &tensor : tensors) { - ptrs_host.push_back(reinterpret_cast(tensor.data_ptr())); - } - - // Allocate device buffer - auto ptrs_device = at::empty({static_cast(tensors.size())}, - at::TensorOptions().dtype(at::kLong).device(device)); - - // Load pointers on device - nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), - tensors.size() * sizeof(uint64_t), - at::cuda::getCurrentCUDAStream()); - - return ptrs_device; -} - -} // namespace transformer_engine::pytorch From 6ee43768e5460f0183a84da0e341729a69ebaf2b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 02:59:24 +0000 Subject: [PATCH 05/12] [PyTorch] Tidy the grouped MLP experimental helper Wraps the C++ implementation of swizzle_scales_and_pack_ptrs_for_discrete_weights in a `grouped_mlp_experimental` namespace and renames the format-selector argument from `format` to `swizzle_type` across the declaration, implementation, and pybind binding. The pybind submodule name was already `grouped_mlp_experimental`, so the C++ namespace now mirrors it. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/extensions.h | 12 ++++-- .../extensions/grouped_mlp_experimental.cpp | 39 +++++++++++-------- .../pytorch/csrc/extensions/pybind.cpp | 4 +- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b0bea69c88..6a2f039929 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -498,16 +498,20 @@ at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, * these helpers should also be considered unstable. **************************************************************************************************/ +namespace grouped_mlp_experimental { + // Prepare discrete weight tensors for the cuDNN CuTe DSL grouped GEMM // kernel by swizzling scales and copying data and scale pointers to -// device. All tensors must share a uniform shape and `format` must be -// one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4". Returns -// {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}. +// device. All tensors must share a uniform shape and `swizzle_type` +// must be one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4". +// Returns {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}. std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( const std::vector &data_tensors, - const std::vector &scale_tensors, const std::string &format, + const std::vector &scale_tensors, const std::string &swizzle_type, const c10::Device &device); +} // namespace grouped_mlp_experimental + /*************************************************************************************************** * Support THD format for Context Parallel **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp index b89ebefea5..948098d5cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -16,25 +16,28 @@ #include "common/common.h" #include "extensions.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { +namespace grouped_mlp_experimental { std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( const std::vector &data_tensors, - const std::vector &scale_tensors, const std::string &format_str, + const std::vector &scale_tensors, + const std::string &swizzle_type_str, const c10::Device &device) { const size_t num_tensors = data_tensors.size(); NVTE_CHECK(scale_tensors.size() == num_tensors, "Expected data_tensors and scale_tensors to have matching sizes, but got ", num_tensors, " and ", scale_tensors.size(), "."); - // Parse tensor format - enum class TensorFormat { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 }; - TensorFormat format = TensorFormat::Invalid; - if (format_str == "mxfp8_rowwise") { format = TensorFormat::MXFP8Rowwise; } - else if (format_str == "mxfp8_columnwise") { format = TensorFormat::MXFP8Columnwise; } - else if (format_str == "nvfp4") { format = TensorFormat::NVFP4; } + // Parse swizzle type + enum class SwizzleType { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 }; + SwizzleType swizzle_type = SwizzleType::Invalid; + if (swizzle_type_str == "mxfp8_rowwise") { swizzle_type = SwizzleType::MXFP8Rowwise; } + else if (swizzle_type_str == "mxfp8_columnwise") { swizzle_type = SwizzleType::MXFP8Columnwise; } + else if (swizzle_type_str == "nvfp4") { swizzle_type = SwizzleType::NVFP4; } else { - NVTE_ERROR("Unsupported format (", format_str, + NVTE_ERROR("Unsupported swizzle type (", swizzle_type_str, "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); } @@ -52,13 +55,13 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ NVTEScalingMode scaling_mode; transformer_engine::DType data_dtype, scale_dtype; NVTETensorParam data_param_name, scale_param_name; - switch (format) { - case TensorFormat::MXFP8Rowwise: - case TensorFormat::MXFP8Columnwise: + switch (swizzle_type) { + case SwizzleType::MXFP8Rowwise: + case SwizzleType::MXFP8Columnwise: scaling_mode = NVTE_MXFP8_1D_SCALING; data_dtype = transformer_engine::DType::kFloat8E4M3; scale_dtype = transformer_engine::DType::kFloat8E8M0; - if (format == TensorFormat::MXFP8Rowwise) { + if (swizzle_type == SwizzleType::MXFP8Rowwise) { data_param_name = kNVTERowwiseData; scale_param_name = kNVTERowwiseScaleInv; } else { @@ -66,7 +69,7 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ scale_param_name = kNVTEColumnwiseScaleInv; } break; - case TensorFormat::NVFP4: + case SwizzleType::NVFP4: scaling_mode = NVTE_NVFP4_1D_SCALING; data_dtype = transformer_engine::DType::kFloat4E2M1; scale_dtype = transformer_engine::DType::kFloat8E4M3; @@ -74,12 +77,12 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ scale_param_name = kNVTERowwiseScaleInv; break; default: - NVTE_ERROR("Unsupported format (", static_cast(format), ")."); + NVTE_ERROR("Unsupported swizzle type (", static_cast(swizzle_type), ")."); } // Data shape NVTEShape data_shape = convertTorchShape(data_tensors[0].sizes()); - if (format == TensorFormat::NVFP4) { + if (swizzle_type == SwizzleType::NVFP4) { // NVFP4 packs two 4-bit values per byte NVTE_CHECK(data_shape.ndim > 0, "Invalid shape for NVFP4 data tensor (", getTensorShape(data_tensors[0]), ")."); @@ -149,4 +152,6 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales)}; } -} // namespace transformer_engine::pytorch +} // namespace grouped_mlp_experimental +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 3029b22dda..47e2034ae9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -608,9 +608,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "grouped_mlp_experimental", "Experimental helpers for the fused grouped MLP (unstable, may change or disappear)."); grouped_mlp_experimental.def("swizzle_scales_and_pack_ptrs_for_discrete_weights", - &transformer_engine::pytorch::swizzle_scales_and_pack_ptrs_for_discrete_weights, + &transformer_engine::pytorch::grouped_mlp_experimental::swizzle_scales_and_pack_ptrs_for_discrete_weights, py::arg("data_tensors"), py::arg("scale_tensors"), - py::arg("format"), py::arg("device"), + py::arg("swizzle_type"), py::arg("device"), py::call_guard()); // Data structures From dbbb9b1d8a6212739877cad78bd7941bf1cd57cc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 02:59:38 +0000 Subject: [PATCH 06/12] [PyTorch] Hook NVFP4 optimize_for_gemm into the quantize path After NVFP4Quantizer::quantize runs, run inplace_swizzle_scale_for_gemm on the output when optimize_for_gemm is set and the quantize kernel hasn't already produced swizzled scales. The NVFP4 quantize kernel rejects with_gemm_swizzled_scales=true and emits compact scales, so without this hook callers had to follow up with a manual swizzle in Python (see ops/_common.py:85). The hook is a no-op for MXFP8 (its quantize kernel sets the flag itself) and for any quantizer with optimize_for_gemm=false. Also fixes a latent state-consistency bug in NVFP4Quantizer::convert_and_update_tensor: it was resetting the C++ wrapper's with_gemm_swizzled_scales to false but never touching the Python tensor's _with_gemm_swizzled_scales attribute. Re-quantizing into a tensor that previously held swizzled scales would leave the Python flag stuck at true while the buffer was compact, mismatched state that downstream code could mis-read. The Python attribute is now reset alongside the C++ wrapper, matching what MXFP8Quantizer::convert_and_update_tensor already does. Adds test_swizzle_scales_and_pack_ptrs_for_discrete_weights covering mxfp8_rowwise, mxfp8_columnwise, and nvfp4, comparing the helper's swizzled output against scales produced by the quantizer with optimize_for_gemm=true. NVFP4 was the case that surfaced the quantizer-side issues fixed above. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Tim Moon --- tests/pytorch/test_grouped_linear.py | 131 +++++++++++++++++- .../pytorch/csrc/extensions/cast.cpp | 6 + transformer_engine/pytorch/csrc/quantizer.cpp | 17 ++- 3 files changed, 149 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 0dc253c18c..a592c44528 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -4,7 +4,7 @@ import os import random -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence import pytest import torch @@ -20,6 +20,7 @@ GroupedLinear, Linear, MXFP8Quantizer, + NVFP4Quantizer, autocast, is_bf16_available, quantized_model_init, @@ -35,13 +36,19 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor import transformer_engine_torch as tex -from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override +from utils import ( + ModelConfig, + assert_close, + recipe_id, + reset_rng_states, + skip_unsupported_backward_override, +) # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) seed = 1234 reset_rng_states() @@ -1740,3 +1747,121 @@ def _train_step(x, dy, out_buf, *, use_graphed): for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): assert param.grad is not None torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + + +@pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) +def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( + swizzle_type: str, + num_tensors: int = 4, + shape: Sequence[int] = (160, 96), +): + """Helper function for preparing discrete weights for cuDNN group GEMM kernel""" + + # Skip unsupported configurations + if not mxfp8_available and swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + pytest.skip(reason_for_no_mxfp8) + if not nvfp4_available and swizzle_type == "nvfp4": + pytest.skip(reason_for_no_nvfp4) + + # Construct quantizer + quantizer = None + if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + elif swizzle_type == "nvfp4": + quantizer = NVFP4Quantizer( + columnwise=False, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + + # Per-expert tensors: unquantized, quantized with compact scales, + # quantized with swizzled scales + device = torch.device("cuda") + unquantized_tensors = [ + torch.randn(shape, dtype=torch.bfloat16, device=device) for _ in range(num_tensors) + ] + quantizer.optimize_for_gemm = False + tensors_with_compact_scales = [quantizer(t) for t in unquantized_tensors] + quantizer.optimize_for_gemm = True + tensors_with_swizzled_scales = [quantizer(t) for t in unquantized_tensors] + + # Extract data and scale buffers + if swizzle_type in ("mxfp8_rowwise", "nvfp4"): + data_tensors = [qx._rowwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_swizzled_scales] + elif swizzle_type == "mxfp8_columnwise": + data_tensors = [qx._columnwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_swizzled_scales] + else: + raise ValueError("Unrecogized swizzle type") + + # Call the helper function + data_ptrs, scale_ptrs, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, + ) + ) + + # Check data pointer values + expected_data_ptrs = torch.tensor( + [t.data_ptr() for t in data_tensors], dtype=torch.int64, device="cpu", + ) + assert_close(data_ptrs, expected_data_ptrs) + + # Check scale pointer values + scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() + expected_scale_ptrs = torch.tensor( + [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], + dtype=torch.int64, + device="cpu", + ) + assert_close(scale_ptrs, expected_scale_ptrs) + + # Check swizzled scale values + swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8) + expected_swizzled_scales_buffer = ( + torch.cat(ref_scale_tensors) + .view(torch.uint8) + .view_as(swizzled_scales_buffer) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) + + # Poison the padded compact scales + if swizzle_type == "mxfp8_rowwise": + unpadded_scale_shape = (shape[0], shape[1] // 32) + elif swizzle_type == "mxfp8_columnwise": + unpadded_scale_shape = (shape[0] // 32, shape[1]) + elif swizzle_type == "nvfp4": + unpadded_scale_shape = (shape[0], shape[1] // 16) + for scale in scale_tensors: + scale[unpadded_scale_shape[0]:, :].view(torch.uint8).fill_(-1) + scale[:, unpadded_scale_shape[1]:].view(torch.uint8).fill_(-1) + + # Check that swizzling removes poisoned pad scales + _, _, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, + ) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d1a9cd8587..e4110131ea 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -62,6 +62,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Perform quantization quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + // Post-quantize swizzle for quantizers whose kernel does not bake + // the GEMM-swizzled scale layout in directly + if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) { + inplace_swizzle_scale_for_gemm(output_py); + } + return output_py; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bc87b54ba8..a54a301664 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1781,7 +1781,12 @@ std::pair NVFP4Quantizer::create_tensor( using namespace pybind11::literals; // Scaling factor format - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm + // + // The NVFP4 quantize kernel requires `with_gemm_swizzled_scales=false` and + // emits compact scales. When `optimize_for_gemm` is set, the caller of + // `quantize` (cast.cpp) applies a post-quantize swizzle that flips this + // flag on the returned tensor. + const bool with_gemm_swizzled_scales = false; // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); @@ -2077,7 +2082,12 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); // Scaling factor format - const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm + // + // The NVFP4 quantize kernel requires `with_gemm_swizzled_scales=false` and + // emits compact scales. When `optimize_for_gemm` is set, the caller of + // `quantize` (cast.cpp) applies a post-quantize swizzle that flips this + // flag on the returned tensor. + const bool with_gemm_swizzled_scales = false; // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { @@ -2109,6 +2119,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + tensor.attr("_with_gemm_swizzled_scales") = py::cast(with_gemm_swizzled_scales); + + // Advanced NVFP4 modes const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool nvfp4_use_4over6 = this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; From d0c22e8613b485fbcdc45b0c54af3d3301b354bb Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 03:07:31 +0000 Subject: [PATCH 07/12] Fix linter warning Signed-off-by: Tim Moon --- .../csrc/extensions/grouped_mlp_experimental.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp index 948098d5cf..c49efd2a11 100644 --- a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -33,10 +33,13 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ // Parse swizzle type enum class SwizzleType { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 }; SwizzleType swizzle_type = SwizzleType::Invalid; - if (swizzle_type_str == "mxfp8_rowwise") { swizzle_type = SwizzleType::MXFP8Rowwise; } - else if (swizzle_type_str == "mxfp8_columnwise") { swizzle_type = SwizzleType::MXFP8Columnwise; } - else if (swizzle_type_str == "nvfp4") { swizzle_type = SwizzleType::NVFP4; } - else { + if (swizzle_type_str == "mxfp8_rowwise") { + swizzle_type = SwizzleType::MXFP8Rowwise; + } else if (swizzle_type_str == "mxfp8_columnwise") { + swizzle_type = SwizzleType::MXFP8Columnwise; + } else if (swizzle_type_str == "nvfp4") { + swizzle_type = SwizzleType::NVFP4; + } else { NVTE_ERROR("Unsupported swizzle type (", swizzle_type_str, "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); } From a268126427a24e011e16219e2b7cb1448dfe32cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 03:21:19 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_grouped_linear.py | 12 ++--- transformer_engine/pytorch/csrc/extensions.h | 5 +- .../extensions/grouped_mlp_experimental.cpp | 50 +++++++++---------- .../pytorch/csrc/extensions/misc.cpp | 2 +- .../pytorch/csrc/extensions/pybind.cpp | 3 +- 5 files changed, 35 insertions(+), 37 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index a592c44528..092f52d553 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -1816,7 +1816,9 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( # Check data pointer values expected_data_ptrs = torch.tensor( - [t.data_ptr() for t in data_tensors], dtype=torch.int64, device="cpu", + [t.data_ptr() for t in data_tensors], + dtype=torch.int64, + device="cpu", ) assert_close(data_ptrs, expected_data_ptrs) @@ -1832,9 +1834,7 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( # Check swizzled scale values swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8) expected_swizzled_scales_buffer = ( - torch.cat(ref_scale_tensors) - .view(torch.uint8) - .view_as(swizzled_scales_buffer) + torch.cat(ref_scale_tensors).view(torch.uint8).view_as(swizzled_scales_buffer) ) assert_close( swizzled_scales_buffer, @@ -1849,8 +1849,8 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( elif swizzle_type == "nvfp4": unpadded_scale_shape = (shape[0], shape[1] // 16) for scale in scale_tensors: - scale[unpadded_scale_shape[0]:, :].view(torch.uint8).fill_(-1) - scale[:, unpadded_scale_shape[1]:].view(torch.uint8).fill_(-1) + scale[unpadded_scale_shape[0] :, :].view(torch.uint8).fill_(-1) + scale[:, unpadded_scale_shape[1] :].view(torch.uint8).fill_(-1) # Check that swizzling removes poisoned pad scales _, _, swizzled_scales_buffer = ( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6a2f039929..b0ddf679f0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -506,9 +506,8 @@ namespace grouped_mlp_experimental { // must be one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4". // Returns {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}. std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( - const std::vector &data_tensors, - const std::vector &scale_tensors, const std::string &swizzle_type, - const c10::Device &device); + const std::vector &data_tensors, const std::vector &scale_tensors, + const std::string &swizzle_type, const c10::Device &device); } // namespace grouped_mlp_experimental diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp index c49efd2a11..0ab8bc8d61 100644 --- a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -21,10 +21,8 @@ namespace pytorch { namespace grouped_mlp_experimental { std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( - const std::vector &data_tensors, - const std::vector &scale_tensors, - const std::string &swizzle_type_str, - const c10::Device &device) { + const std::vector &data_tensors, const std::vector &scale_tensors, + const std::string &swizzle_type_str, const c10::Device &device) { const size_t num_tensors = data_tensors.size(); NVTE_CHECK(scale_tensors.size() == num_tensors, "Expected data_tensors and scale_tensors to have matching sizes, but got ", @@ -59,36 +57,36 @@ std::tuple swizzle_scales_and_pack_ptrs_for_ transformer_engine::DType data_dtype, scale_dtype; NVTETensorParam data_param_name, scale_param_name; switch (swizzle_type) { - case SwizzleType::MXFP8Rowwise: - case SwizzleType::MXFP8Columnwise: - scaling_mode = NVTE_MXFP8_1D_SCALING; - data_dtype = transformer_engine::DType::kFloat8E4M3; - scale_dtype = transformer_engine::DType::kFloat8E8M0; - if (swizzle_type == SwizzleType::MXFP8Rowwise) { + case SwizzleType::MXFP8Rowwise: + case SwizzleType::MXFP8Columnwise: + scaling_mode = NVTE_MXFP8_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat8E4M3; + scale_dtype = transformer_engine::DType::kFloat8E8M0; + if (swizzle_type == SwizzleType::MXFP8Rowwise) { + data_param_name = kNVTERowwiseData; + scale_param_name = kNVTERowwiseScaleInv; + } else { + data_param_name = kNVTEColumnwiseData; + scale_param_name = kNVTEColumnwiseScaleInv; + } + break; + case SwizzleType::NVFP4: + scaling_mode = NVTE_NVFP4_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat4E2M1; + scale_dtype = transformer_engine::DType::kFloat8E4M3; data_param_name = kNVTERowwiseData; scale_param_name = kNVTERowwiseScaleInv; - } else { - data_param_name = kNVTEColumnwiseData; - scale_param_name = kNVTEColumnwiseScaleInv; - } - break; - case SwizzleType::NVFP4: - scaling_mode = NVTE_NVFP4_1D_SCALING; - data_dtype = transformer_engine::DType::kFloat4E2M1; - scale_dtype = transformer_engine::DType::kFloat8E4M3; - data_param_name = kNVTERowwiseData; - scale_param_name = kNVTERowwiseScaleInv; - break; - default: - NVTE_ERROR("Unsupported swizzle type (", static_cast(swizzle_type), ")."); + break; + default: + NVTE_ERROR("Unsupported swizzle type (", static_cast(swizzle_type), ")."); } // Data shape NVTEShape data_shape = convertTorchShape(data_tensors[0].sizes()); if (swizzle_type == SwizzleType::NVFP4) { // NVFP4 packs two 4-bit values per byte - NVTE_CHECK(data_shape.ndim > 0, - "Invalid shape for NVFP4 data tensor (", getTensorShape(data_tensors[0]), ")."); + NVTE_CHECK(data_shape.ndim > 0, "Invalid shape for NVFP4 data tensor (", + getTensorShape(data_tensors[0]), ")."); data_shape.data[data_shape.ndim - 1] *= 2; } diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index c3bb6d64d6..727c79aea3 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -8,8 +8,8 @@ #include -#include "common/common.h" #include "../extensions.h" +#include "common/common.h" namespace transformer_engine::pytorch { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 47e2034ae9..d78d09817e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -608,7 +608,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "grouped_mlp_experimental", "Experimental helpers for the fused grouped MLP (unstable, may change or disappear)."); grouped_mlp_experimental.def("swizzle_scales_and_pack_ptrs_for_discrete_weights", - &transformer_engine::pytorch::grouped_mlp_experimental::swizzle_scales_and_pack_ptrs_for_discrete_weights, + &transformer_engine::pytorch::grouped_mlp_experimental:: + swizzle_scales_and_pack_ptrs_for_discrete_weights, py::arg("data_tensors"), py::arg("scale_tensors"), py::arg("swizzle_type"), py::arg("device"), py::call_guard()); From e09e1cabafe5735563b1a38cbe673a3f1c202141 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 03:27:35 +0000 Subject: [PATCH 09/12] Fix linter warning Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d78d09817e..a931bdd223 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -683,4 +683,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -} +} // NOLINT(readability/fn_size) From 2318c465d1b57835c9ebdff9a9821caa64abb15f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jun 2026 22:14:09 +0000 Subject: [PATCH 10/12] Avoid swizzle kernel when tensor size is zero Signed-off-by: Tim Moon --- transformer_engine/common/swizzle/swizzle.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 51969e10e3..38b526360e 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1049,6 +1049,11 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, "."); } + // Return early if tensor has no entries + if (m == 0 || k == 0) { + return; + } + // Choose swizzle implementation bool rowwise_swizzle{false}, columnwise_swizzle{false}; switch (scaling_mode) { From 6cd058475af4b764a24a2d07ba6669ca03ca41aa Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 2 Jun 2026 15:16:28 -0700 Subject: [PATCH 11/12] Review suggestion from @vthumbe1503 Co-authored-by: vthumbe1503 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_grouped_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 092f52d553..30a15da711 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -1768,8 +1768,8 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): quantizer = MXFP8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, + rowwise=swizzle_type=="mxfp8_rowwise", + columnwise=swizzle_type=="mxfp8_columnwise", ) elif swizzle_type == "nvfp4": quantizer = NVFP4Quantizer( From c14787543e5937fbf3e08c18a7fa1adf376868b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:17:21 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_grouped_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 30a15da711..bdb61c6e91 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -1768,8 +1768,8 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): quantizer = MXFP8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=swizzle_type=="mxfp8_rowwise", - columnwise=swizzle_type=="mxfp8_columnwise", + rowwise=swizzle_type == "mxfp8_rowwise", + columnwise=swizzle_type == "mxfp8_columnwise", ) elif swizzle_type == "nvfp4": quantizer = NVFP4Quantizer(