diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 0dc253c18c..bdb61c6e91 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -4,7 +4,7 @@ import os import random -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence import pytest import torch @@ -20,6 +20,7 @@ GroupedLinear, Linear, MXFP8Quantizer, + NVFP4Quantizer, autocast, is_bf16_available, quantized_model_init, @@ -35,13 +36,19 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor import transformer_engine_torch as tex -from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override +from utils import ( + ModelConfig, + assert_close, + recipe_id, + reset_rng_states, + skip_unsupported_backward_override, +) # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) seed = 1234 reset_rng_states() @@ -1740,3 +1747,121 @@ def _train_step(x, dy, out_buf, *, use_graphed): for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): assert param.grad is not None torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + + +@pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) +def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( + swizzle_type: str, + num_tensors: int = 4, + shape: Sequence[int] = (160, 96), +): + """Helper function for preparing discrete weights for cuDNN group GEMM kernel""" + + # Skip unsupported configurations + if not mxfp8_available and swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + pytest.skip(reason_for_no_mxfp8) + if not nvfp4_available and swizzle_type == "nvfp4": + pytest.skip(reason_for_no_nvfp4) + + # Construct quantizer + quantizer = None + if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=swizzle_type == "mxfp8_rowwise", + columnwise=swizzle_type == "mxfp8_columnwise", + ) + elif swizzle_type == "nvfp4": + quantizer = NVFP4Quantizer( + columnwise=False, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + + # Per-expert tensors: unquantized, quantized with compact scales, + # quantized with swizzled scales + device = torch.device("cuda") + unquantized_tensors = [ + torch.randn(shape, dtype=torch.bfloat16, device=device) for _ in range(num_tensors) + ] + quantizer.optimize_for_gemm = False + tensors_with_compact_scales = [quantizer(t) for t in unquantized_tensors] + quantizer.optimize_for_gemm = True + tensors_with_swizzled_scales = [quantizer(t) for t in unquantized_tensors] + + # Extract data and scale buffers + if swizzle_type in ("mxfp8_rowwise", "nvfp4"): + data_tensors = [qx._rowwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_swizzled_scales] + elif swizzle_type == "mxfp8_columnwise": + data_tensors = [qx._columnwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_swizzled_scales] + else: + raise ValueError("Unrecogized swizzle type") + + # Call the helper function + data_ptrs, scale_ptrs, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, + ) + ) + + # Check data pointer values + expected_data_ptrs = torch.tensor( + [t.data_ptr() for t in data_tensors], + dtype=torch.int64, + device="cpu", + ) + assert_close(data_ptrs, expected_data_ptrs) + + # Check scale pointer values + scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() + expected_scale_ptrs = torch.tensor( + [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], + dtype=torch.int64, + device="cpu", + ) + assert_close(scale_ptrs, expected_scale_ptrs) + + # Check swizzled scale values + swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8) + expected_swizzled_scales_buffer = ( + torch.cat(ref_scale_tensors).view(torch.uint8).view_as(swizzled_scales_buffer) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) + + # Poison the padded compact scales + if swizzle_type == "mxfp8_rowwise": + unpadded_scale_shape = (shape[0], shape[1] // 32) + elif swizzle_type == "mxfp8_columnwise": + unpadded_scale_shape = (shape[0] // 32, shape[1]) + elif swizzle_type == "nvfp4": + unpadded_scale_shape = (shape[0], shape[1] // 16) + for scale in scale_tensors: + scale[unpadded_scale_shape[0] :, :].view(torch.uint8).fill_(-1) + scale[:, unpadded_scale_shape[1] :].view(torch.uint8).fill_(-1) + + # Check that swizzling removes poisoned pad scales + _, _, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, + ) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 51969e10e3..38b526360e 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1049,6 +1049,11 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, "."); } + // Return early if tensor has no entries + if (m == 0 || k == 0) { + return; + } + // Choose swizzle implementation bool rowwise_swizzle{false}, columnwise_swizzle{false}; switch (scaling_mode) { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1f9974448b..b0ddf679f0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -490,9 +490,26 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, const c10::Device &device); -std::tuple> transform_and_copy_data_ptrs_to_device( - const std::string &transform_type, const std::vector &tensors, - const c10::Device &device); +/*************************************************************************************************** + * Experimental helpers for the fused grouped MLP + * + * These primarily exist to support cuDNN CuTe DSL grouped GEMM + * kernels. Since those are unstable and under active development, + * these helpers should also be considered unstable. + **************************************************************************************************/ + +namespace grouped_mlp_experimental { + +// Prepare discrete weight tensors for the cuDNN CuTe DSL grouped GEMM +// kernel by swizzling scales and copying data and scale pointers to +// device. All tensors must share a uniform shape and `swizzle_type` +// must be one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4". +// Returns {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}. +std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( + const std::vector &data_tensors, const std::vector &scale_tensors, + const std::string &swizzle_type, const c10::Device &device); + +} // namespace grouped_mlp_experimental /*************************************************************************************************** * Support THD format for Context Parallel diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d1a9cd8587..e4110131ea 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -62,6 +62,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Perform quantization quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + // Post-quantize swizzle for quantizers whose kernel does not bake + // the GEMM-swizzled scale layout in directly + if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) { + inplace_swizzle_scale_for_gemm(output_py); + } + return output_py; } diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp new file mode 100644 index 0000000000..0ab8bc8d61 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/grouped_mlp_experimental.cpp @@ -0,0 +1,158 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +// Experimental helpers for the fused grouped MLP. + +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "extensions.h" + +namespace transformer_engine { +namespace pytorch { +namespace grouped_mlp_experimental { + +std::tuple swizzle_scales_and_pack_ptrs_for_discrete_weights( + const std::vector &data_tensors, const std::vector &scale_tensors, + const std::string &swizzle_type_str, const c10::Device &device) { + const size_t num_tensors = data_tensors.size(); + NVTE_CHECK(scale_tensors.size() == num_tensors, + "Expected data_tensors and scale_tensors to have matching sizes, but got ", + num_tensors, " and ", scale_tensors.size(), "."); + + // Parse swizzle type + enum class SwizzleType { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 }; + SwizzleType swizzle_type = SwizzleType::Invalid; + if (swizzle_type_str == "mxfp8_rowwise") { + swizzle_type = SwizzleType::MXFP8Rowwise; + } else if (swizzle_type_str == "mxfp8_columnwise") { + swizzle_type = SwizzleType::MXFP8Columnwise; + } else if (swizzle_type_str == "nvfp4") { + swizzle_type = SwizzleType::NVFP4; + } else { + NVTE_ERROR("Unsupported swizzle type (", swizzle_type_str, + "). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4."); + } + + // Trivial case: no tensors. Return empty tensors. + if (num_tensors == 0) { + auto empty_ptrs = at::empty({0}, at::TensorOptions().dtype(at::kLong).device(device)); + auto empty_scales = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device)); + return {empty_ptrs, empty_ptrs.clone(), std::move(empty_scales)}; + } + + // CUDA stream + auto stream = at::cuda::getCurrentCUDAStream(); + + // Tensor properties + NVTEScalingMode scaling_mode; + transformer_engine::DType data_dtype, scale_dtype; + NVTETensorParam data_param_name, scale_param_name; + switch (swizzle_type) { + case SwizzleType::MXFP8Rowwise: + case SwizzleType::MXFP8Columnwise: + scaling_mode = NVTE_MXFP8_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat8E4M3; + scale_dtype = transformer_engine::DType::kFloat8E8M0; + if (swizzle_type == SwizzleType::MXFP8Rowwise) { + data_param_name = kNVTERowwiseData; + scale_param_name = kNVTERowwiseScaleInv; + } else { + data_param_name = kNVTEColumnwiseData; + scale_param_name = kNVTEColumnwiseScaleInv; + } + break; + case SwizzleType::NVFP4: + scaling_mode = NVTE_NVFP4_1D_SCALING; + data_dtype = transformer_engine::DType::kFloat4E2M1; + scale_dtype = transformer_engine::DType::kFloat8E4M3; + data_param_name = kNVTERowwiseData; + scale_param_name = kNVTERowwiseScaleInv; + break; + default: + NVTE_ERROR("Unsupported swizzle type (", static_cast(swizzle_type), ")."); + } + + // Data shape + NVTEShape data_shape = convertTorchShape(data_tensors[0].sizes()); + if (swizzle_type == SwizzleType::NVFP4) { + // NVFP4 packs two 4-bit values per byte + NVTE_CHECK(data_shape.ndim > 0, "Invalid shape for NVFP4 data tensor (", + getTensorShape(data_tensors[0]), ")."); + data_shape.data[data_shape.ndim - 1] *= 2; + } + + // Scale shape + const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes()); + NVTE_CHECK(scale_shape.ndim == 2, + "Expected 2D scale tensor, but got shape=", getTensorShape(scale_tensors[0]), "."); + const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1]; + const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); + const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); + + // Allocate single buffer for swizzled scales. Uses a uniform stride since + // all tensors share the same scale shape. + const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes + auto swizzled_scales = at::empty({static_cast(swizzled_scales_stride * num_tensors)}, + at::TensorOptions().dtype(at::kByte).device(device)); + uint8_t *swizzled_scales_dptr = reinterpret_cast(swizzled_scales.data_ptr()); + + // Allocate input/output NVTETensors as a single batch. The first + // num_tensors entries are inputs; the next num_tensors are outputs. + MultiTensorWrapper nvte_tensors(2 * num_tensors, scaling_mode); + NVTETensor *inputs_nvte = nvte_tensors.data(); + NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors; + + auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, + transformer_engine::DType dtype, const NVTEShape &shape) { + NVTEBasicTensor data{dptr, static_cast(dtype), shape}; + nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); + }; + + // Configure NVTETensors + for (size_t i = 0; i < num_tensors; ++i) { + const uint8_t swizzled_flag = 1; + nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, + sizeof(swizzled_flag)); + void *in_scale_ptr = scale_tensors[i].data_ptr(); + void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; + set_param(inputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape); + set_param(inputs_nvte[i], scale_param_name, in_scale_ptr, scale_dtype, scale_shape); + set_param(outputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape); + set_param(outputs_nvte[i], scale_param_name, out_scale_ptr, scale_dtype, scale_shape); + } + + // Launch swizzle kernel + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream); + + // Pack data pointers (first half) and swizzled scale pointers (second half) + // into a single host buffer and copy to device with one kernel launch. + std::vector packed_ptrs_host(2 * num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + packed_ptrs_host[i] = reinterpret_cast(data_tensors[i].data_ptr()); + packed_ptrs_host[num_tensors + i] = + reinterpret_cast(swizzled_scales_dptr + i * swizzled_scales_stride); + } + auto packed_ptrs_device = at::empty({static_cast(2 * num_tensors)}, + at::TensorOptions().dtype(at::kLong).device(device)); + nvte_copy_host_to_device_via_kernel(packed_ptrs_host.data(), packed_ptrs_device.data_ptr(), + 2 * num_tensors * sizeof(uint64_t), stream); + + // Return the two pointer arrays as views into the packed device buffer. + auto data_ptrs = packed_ptrs_device.narrow(0, 0, static_cast(num_tensors)); + auto scale_ptrs = packed_ptrs_device.narrow(0, static_cast(num_tensors), + static_cast(num_tensors)); + return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales)}; +} + +} // namespace grouped_mlp_experimental +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index c5707fa53c..727c79aea3 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -4,7 +4,12 @@ * See LICENSE for license information. ************************************************************************/ +#include + +#include + #include "../extensions.h" +#include "common/common.h" namespace transformer_engine::pytorch { @@ -30,4 +35,25 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ return output; } +at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, + const c10::Device &device) { + // Collect data pointers + std::vector ptrs_host; + ptrs_host.reserve(tensors.size()); + for (const auto &tensor : tensors) { + ptrs_host.push_back(reinterpret_cast(tensor.data_ptr())); + } + + // Allocate device buffer + auto ptrs_device = at::empty({static_cast(tensors.size())}, + at::TensorOptions().dtype(at::kLong).device(device)); + + // Load pointers on device + nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), + tensors.size() * sizeof(uint64_t), + at::cuda::getCurrentCUDAStream()); + + return ptrs_device; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c1b38a2275..a931bdd223 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -494,10 +494,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("copy_data_ptrs_to_device", &transformer_engine::pytorch::copy_data_ptrs_to_device, py::arg("tensors"), py::arg("device"), py::call_guard()); - m.def("transform_and_copy_data_ptrs_to_device", - &transformer_engine::pytorch::transform_and_copy_data_ptrs_to_device, - py::arg("transform_type"), py::arg("tensors"), py::arg("device"), - py::call_guard()); m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); @@ -607,6 +603,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("allgather_communicator"), py::arg("send_stream"), py::arg("recv_stream")); + // Experimental fused grouped MLP + auto grouped_mlp_experimental = m.def_submodule( + "grouped_mlp_experimental", + "Experimental helpers for the fused grouped MLP (unstable, may change or disappear)."); + grouped_mlp_experimental.def("swizzle_scales_and_pack_ptrs_for_discrete_weights", + &transformer_engine::pytorch::grouped_mlp_experimental:: + swizzle_scales_and_pack_ptrs_for_discrete_weights, + py::arg("data_tensors"), py::arg("scale_tensors"), + py::arg("swizzle_type"), py::arg("device"), + py::call_guard()); + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) @@ -676,4 +683,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -} +} // NOLINT(readability/fn_size) diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp deleted file mode 100644 index 453f238c0d..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/utils.cpp +++ /dev/null @@ -1,177 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include - -#include -#include -#include -#include -#include - -#include "common/common.h" -#include "extensions.h" - -namespace transformer_engine::pytorch { - -at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, - const c10::Device &device) { - // Collect data pointers - std::vector ptrs_host; - ptrs_host.reserve(tensors.size()); - for (const auto &tensor : tensors) { - ptrs_host.push_back(reinterpret_cast(tensor.data_ptr())); - } - - // Allocate device buffer - auto ptrs_device = at::empty({static_cast(tensors.size())}, - at::TensorOptions().dtype(at::kLong).device(device)); - - // Load pointers on device - nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), - tensors.size() * sizeof(uint64_t), - at::cuda::getCurrentCUDAStream()); - - return ptrs_device; -} - -std::tuple> transform_and_copy_data_ptrs_to_device( - const std::string &transform_type, const std::vector &tensors, - const c10::Device &device) { - const size_t num_tensors = tensors.size(); - - // Trivial cases - if (transform_type.empty()) { - // No transform, just load pointers on device - return {copy_data_ptrs_to_device(tensors, device), std::nullopt}; - } - if (num_tensors == 0) { - // No input tensors, return tensor with no elements - return {at::empty({int64_t{0}}, at::TensorOptions().dtype(at::kLong).device(device)), - std::nullopt}; - } - - // CUDA stream - auto stream = at::cuda::getCurrentCUDAStream(); - - // Swizzle scales for GEMM, with uniform tensor sizes - const bool uniform_mxfp8_rowwise_swizzle = transform_type == "uniform_mxfp8_rowwise_swizzle"; - const bool uniform_mxfp8_colwise_swizzle = transform_type == "uniform_mxfp8_columnwise_swizzle"; - const bool uniform_nvfp4_swizzle = transform_type == "uniform_nvfp4_swizzle"; - if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle || uniform_nvfp4_swizzle) { - // Tensor format - NVTEScalingMode scaling_mode = NVTE_INVALID_SCALING; - if (uniform_mxfp8_rowwise_swizzle || uniform_mxfp8_colwise_swizzle) { - scaling_mode = NVTE_MXFP8_1D_SCALING; - } else if (uniform_nvfp4_swizzle) { - scaling_mode = NVTE_NVFP4_1D_SCALING; - } - - // Data types - transformer_engine::DType data_dtype, scale_dtype; - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - data_dtype = transformer_engine::DType::kFloat8E4M3; - scale_dtype = transformer_engine::DType::kFloat8E8M0; - break; - case NVTE_NVFP4_1D_SCALING: - data_dtype = transformer_engine::DType::kFloat4E2M1; - scale_dtype = transformer_engine::DType::kFloat8E4M3; - break; - default: - NVTE_ERROR("Unsupported case."); - } - - // Scale shape - const NVTEShape scale_shape = convertTorchShape(tensors[0].sizes()); - NVTE_CHECK(scale_shape.ndim == 2, - "Expected 2D scale tensor, but got shape=", getTensorShape(tensors[0]), "."); - const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1]; - const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); - const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8); - - // Expected data shape - // Note: May not match actual data shape since the scales are padded. - // This is fine since we're not actually touching the data. - NVTEShape data_shape; - data_shape.ndim = 2; - if (uniform_mxfp8_rowwise_swizzle) { - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 32; - } else if (uniform_mxfp8_colwise_swizzle) { - data_shape.data[0] = scale_shape.data[0] * 32; - data_shape.data[1] = scale_shape.data[1]; - } else if (uniform_nvfp4_swizzle) { - data_shape.data[0] = scale_shape.data[0]; - data_shape.data[1] = scale_shape.data[1] * 16; - } else { - NVTE_ERROR("Unsupported case."); - } - - // Allocate single buffer for swizzled scales. - // Uses a uniform stride since all tensors share the same scale shape. - const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes - auto swizzled_scales = at::empty({static_cast(swizzled_scales_stride * num_tensors)}, - at::TensorOptions().dtype(at::kByte).device(device)); - uint8_t *swizzled_scales_dptr = reinterpret_cast(swizzled_scales.data_ptr()); - - // Allocate input/output NVTETensors as a single batch. The first - // num_tensors entries are inputs; the next num_tensors are outputs. - MultiTensorWrapper nvte_tensors(2 * num_tensors, scaling_mode); - NVTETensor *inputs_nvte = nvte_tensors.data(); - NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors; - - auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr, - transformer_engine::DType dtype, const NVTEShape &shape) { - NVTEBasicTensor data{dptr, static_cast(dtype), shape}; - nvte_set_tensor_param_v2(t, param, &data, sizeof(data)); - }; - - for (size_t i = 0; i < num_tensors; ++i) { - const uint8_t swizzled_flag = 1; - nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag, - sizeof(swizzled_flag)); - void *in_scale_ptr = tensors[i].data_ptr(); - void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride; - if (uniform_mxfp8_rowwise_swizzle || uniform_nvfp4_swizzle) { - set_param(inputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); - set_param(inputs_nvte[i], kNVTERowwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); - set_param(outputs_nvte[i], kNVTERowwiseData, nullptr, data_dtype, data_shape); - set_param(outputs_nvte[i], kNVTERowwiseScaleInv, out_scale_ptr, scale_dtype, scale_shape); - } else if (uniform_mxfp8_colwise_swizzle) { - set_param(inputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); - set_param(inputs_nvte[i], kNVTEColumnwiseScaleInv, in_scale_ptr, scale_dtype, scale_shape); - set_param(outputs_nvte[i], kNVTEColumnwiseData, nullptr, data_dtype, data_shape); - set_param(outputs_nvte[i], kNVTEColumnwiseScaleInv, out_scale_ptr, scale_dtype, - scale_shape); - } - } - - // Launch kernel - nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream); - - // Collect data pointers - std::vector ptrs_host; - ptrs_host.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - ptrs_host.push_back( - reinterpret_cast(swizzled_scales_dptr + i * swizzled_scales_stride)); - } - - // Load pointers on device - auto ptrs_device = at::empty({static_cast(num_tensors)}, - at::TensorOptions().dtype(at::kLong).device(device)); - nvte_copy_host_to_device_via_kernel(ptrs_host.data(), ptrs_device.data_ptr(), - num_tensors * sizeof(uint64_t), stream); - - return {std::move(ptrs_device), std::move(swizzled_scales)}; - } - - // Unsupported transform - NVTE_ERROR("Unsupported transform type (", transform_type, ")"); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bc87b54ba8..a54a301664 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1781,7 +1781,12 @@ std::pair NVFP4Quantizer::create_tensor( using namespace pybind11::literals; // Scaling factor format - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm + // + // The NVFP4 quantize kernel requires `with_gemm_swizzled_scales=false` and + // emits compact scales. When `optimize_for_gemm` is set, the caller of + // `quantize` (cast.cpp) applies a post-quantize swizzle that flips this + // flag on the returned tensor. + const bool with_gemm_swizzled_scales = false; // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); @@ -2077,7 +2082,12 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); // Scaling factor format - const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm + // + // The NVFP4 quantize kernel requires `with_gemm_swizzled_scales=false` and + // emits compact scales. When `optimize_for_gemm` is set, the caller of + // `quantize` (cast.cpp) applies a post-quantize swizzle that flips this + // flag on the returned tensor. + const bool with_gemm_swizzled_scales = false; // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { @@ -2109,6 +2119,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + tensor.attr("_with_gemm_swizzled_scales") = py::cast(with_gemm_swizzled_scales); + + // Advanced NVFP4 modes const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool nvfp4_use_4over6 = this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 792b6d7811..3e4b898ee4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -649,17 +649,13 @@ def fuser_backward( fc2_dactivation_kwargs["b_tensor"] = fc2_w_data fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs = tex.copy_data_ptrs_to_device( - [w._columnwise_data for w in grouped_fc2_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" - ) - fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._columnwise_scale_inv for w in grouped_fc2_weight], - device, + fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._columnwise_data for w in grouped_fc2_weight], + [w._columnwise_scale_inv for w in grouped_fc2_weight], + "nvfp4" if use_nvfp4 else "mxfp8_columnwise", + device, + ) ) fc2_dactivation_kwargs["b_ptrs"] = fc2_b_ptrs fc2_dactivation_kwargs["sfb_ptrs"] = fc2_sfb_ptrs @@ -911,17 +907,13 @@ def fuser_backward( fc1_dgrad_kwargs["b_tensor"] = fc1_w_data fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales else: - fc1_b_ptrs = tex.copy_data_ptrs_to_device( - [w._columnwise_data for w in grouped_fc1_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" - ) - fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._columnwise_scale_inv for w in grouped_fc1_weight], - device, + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._columnwise_data for w in grouped_fc1_weight], + [w._columnwise_scale_inv for w in grouped_fc1_weight], + "nvfp4" if use_nvfp4 else "mxfp8_columnwise", + device, + ) ) fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f4f2108578..ef9319918f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -449,15 +449,13 @@ def fuser_forward( fc1_activation_kwargs["sfb_tensor"] = fc1_w_scales else: # Discrete-weight kernel: per-expert data/scale pointers - fc1_b_ptrs = tex.copy_data_ptrs_to_device( - [w._rowwise_data for w in grouped_fc1_weight], - device, - ) - swizzle_type = "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" - fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._rowwise_scale_inv for w in grouped_fc1_weight], - device, + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._rowwise_data for w in grouped_fc1_weight], + [w._rowwise_scale_inv for w in grouped_fc1_weight], + "nvfp4" if use_nvfp4 else "mxfp8_rowwise", + device, + ) ) fc1_activation_kwargs["b_ptrs"] = fc1_b_ptrs fc1_activation_kwargs["sfb_ptrs"] = fc1_sfb_ptrs @@ -625,17 +623,13 @@ def fuser_forward( fc2_quant_kwargs["b_tensor"] = fc2_w_data fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales else: - fc2_b_ptrs = tex.copy_data_ptrs_to_device( - [w._rowwise_data for w in grouped_fc2_weight], - device, - ) - swizzle_type = ( - "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" - ) - fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - swizzle_type, - [w._rowwise_scale_inv for w in grouped_fc2_weight], - device, + fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sfb_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + [w._rowwise_data for w in grouped_fc2_weight], + [w._rowwise_scale_inv for w in grouped_fc2_weight], + "nvfp4" if use_nvfp4 else "mxfp8_rowwise", + device, + ) ) fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs