Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 128 additions & 3 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import random
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence

import pytest
import torch
Expand All @@ -20,6 +20,7 @@
GroupedLinear,
Linear,
MXFP8Quantizer,
NVFP4Quantizer,
autocast,
is_bf16_available,
quantized_model_init,
Expand All @@ -35,13 +36,19 @@
)
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
import transformer_engine_torch as tex
from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override
from utils import (
ModelConfig,
assert_close,
recipe_id,
reset_rng_states,
skip_unsupported_backward_override,
)

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, _ = te.is_nvfp4_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

seed = 1234
reset_rng_states()
Expand Down Expand Up @@ -1740,3 +1747,121 @@ def _train_step(x, dy, out_buf, *, use_graphed):
for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()):
assert param.grad is not None
torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols)


@pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"])
def test_swizzle_scales_and_pack_ptrs_for_discrete_weights(
swizzle_type: str,
num_tensors: int = 4,
shape: Sequence[int] = (160, 96),
):
"""Helper function for preparing discrete weights for cuDNN group GEMM kernel"""

# Skip unsupported configurations
if not mxfp8_available and swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"):
pytest.skip(reason_for_no_mxfp8)
if not nvfp4_available and swizzle_type == "nvfp4":
pytest.skip(reason_for_no_nvfp4)

# Construct quantizer
quantizer = None
if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"):
quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=swizzle_type == "mxfp8_rowwise",
columnwise=swizzle_type == "mxfp8_columnwise",
)
elif swizzle_type == "nvfp4":
quantizer = NVFP4Quantizer(
columnwise=False,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)

# Per-expert tensors: unquantized, quantized with compact scales,
# quantized with swizzled scales
device = torch.device("cuda")
unquantized_tensors = [
torch.randn(shape, dtype=torch.bfloat16, device=device) for _ in range(num_tensors)
]
quantizer.optimize_for_gemm = False
tensors_with_compact_scales = [quantizer(t) for t in unquantized_tensors]
quantizer.optimize_for_gemm = True
tensors_with_swizzled_scales = [quantizer(t) for t in unquantized_tensors]

# Extract data and scale buffers
if swizzle_type in ("mxfp8_rowwise", "nvfp4"):
data_tensors = [qx._rowwise_data for qx in tensors_with_compact_scales]
scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_compact_scales]
ref_scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_swizzled_scales]
elif swizzle_type == "mxfp8_columnwise":
data_tensors = [qx._columnwise_data for qx in tensors_with_compact_scales]
scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_compact_scales]
ref_scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_swizzled_scales]
else:
raise ValueError("Unrecogized swizzle type")

# Call the helper function
data_ptrs, scale_ptrs, swizzled_scales_buffer = (
tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights(
data_tensors,
scale_tensors,
swizzle_type,
device,
)
)

# Check data pointer values
expected_data_ptrs = torch.tensor(
[t.data_ptr() for t in data_tensors],
dtype=torch.int64,
device="cpu",
)
assert_close(data_ptrs, expected_data_ptrs)

# Check scale pointer values
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)
Comment on lines +1825 to +1831
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The expected scale-pointer stride uses scale_bytes (raw size), but the C++ code allocates the swizzled-scale buffer with swizzled_scales_stride = roundup(scale_bytes, 16). For the current shape (160, 96) every scale size happens to be a multiple of 16 (480 and 960), so both values agree and the assertion passes, but any shape that produces a non-16-aligned scale count would compute incorrect expected pointers. The same implicit assumption appears in the view_as call below — if scale_bytes != swizzled_scales_stride, that call throws a RuntimeError rather than silently validating padding-separated data.

Suggested change
# Check scale pointer values
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)
# Check scale pointer values
# Use the same 16-byte-aligned stride as the C++ implementation
import math
scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size()
scale_stride = math.ceil(scale_bytes / 16) * 16
expected_scale_ptrs = torch.tensor(
[swizzled_scales_buffer.data_ptr() + i * scale_stride for i in range(num_tensors)],
dtype=torch.int64,
device="cpu",
)

assert_close(scale_ptrs, expected_scale_ptrs)

# Check swizzled scale values
swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8)
expected_swizzled_scales_buffer = (
torch.cat(ref_scale_tensors).view(torch.uint8).view_as(swizzled_scales_buffer)
)
assert_close(
swizzled_scales_buffer,
expected_swizzled_scales_buffer,
)

# Poison the padded compact scales
if swizzle_type == "mxfp8_rowwise":
unpadded_scale_shape = (shape[0], shape[1] // 32)
elif swizzle_type == "mxfp8_columnwise":
unpadded_scale_shape = (shape[0] // 32, shape[1])
elif swizzle_type == "nvfp4":
unpadded_scale_shape = (shape[0], shape[1] // 16)
for scale in scale_tensors:
scale[unpadded_scale_shape[0] :, :].view(torch.uint8).fill_(-1)
scale[:, unpadded_scale_shape[1] :].view(torch.uint8).fill_(-1)

# Check that swizzling removes poisoned pad scales
_, _, swizzled_scales_buffer = (
tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights(
data_tensors,
scale_tensors,
swizzle_type,
device,
)
)
assert_close(
swizzled_scales_buffer,
expected_swizzled_scales_buffer,
)
5 changes: 5 additions & 0 deletions transformer_engine/common/swizzle/swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,11 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
" column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
}

// Return early if tensor has no entries
if (m == 0 || k == 0) {
return;
}

// Choose swizzle implementation
bool rowwise_swizzle{false}, columnwise_swizzle{false};
switch (scaling_mode) {
Expand Down
23 changes: 20 additions & 3 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,26 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_
at::Tensor copy_data_ptrs_to_device(const std::vector<at::Tensor> &tensors,
const c10::Device &device);

std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_copy_data_ptrs_to_device(
const std::string &transform_type, const std::vector<at::Tensor> &tensors,
const c10::Device &device);
/***************************************************************************************************
* Experimental helpers for the fused grouped MLP
*
* These primarily exist to support cuDNN CuTe DSL grouped GEMM
* kernels. Since those are unstable and under active development,
* these helpers should also be considered unstable.
**************************************************************************************************/

namespace grouped_mlp_experimental {

// Prepare discrete weight tensors for the cuDNN CuTe DSL grouped GEMM
// kernel by swizzling scales and copying data and scale pointers to
// device. All tensors must share a uniform shape and `swizzle_type`
// must be one of "mxfp8_rowwise", "mxfp8_columnwise", or "nvfp4".
// Returns {data_ptrs_device, scale_ptrs_device, swizzled_scales_buffer}.
//
// Arguments:
//   data_tensors:  Quantized data buffer of each weight. The data is
//                  not read by the swizzle; only the buffer addresses
//                  and the shape of the first tensor are used.
//   scale_tensors: Compact (unswizzled) scale buffer of each weight,
//                  one per entry of `data_tensors`. Must be 2D.
//   swizzle_type:  Selects the scaling mode and whether the row-wise
//                  or column-wise scales are swizzled.
//   device:        Device on which the returned tensors are allocated.
//
// Returned tensors:
//   data_ptrs_device:       int64 tensor holding the address of each
//                           entry of `data_tensors`.
//   scale_ptrs_device:      int64 tensor holding the address of each
//                           tensor's swizzled scales inside
//                           `swizzled_scales_buffer`.
//   swizzled_scales_buffer: uint8 tensor holding the swizzled scales
//                           of all tensors back to back, with a uniform
//                           per-tensor stride rounded up to 16 bytes.
//
// Notes:
//   - Both pointer tensors are views into one packed device buffer.
//   - The pointer tensors hold raw addresses: `data_tensors` and
//     `swizzled_scales_buffer` must outlive any use of those pointers.
//   - The data and scale shapes are taken from the first tensor only.
//     Uniformity across the remaining tensors is assumed, not checked.
//   - An empty tensor list yields three empty tensors.
//   - Raises an error if the two lists differ in length, if
//     `swizzle_type` is not recognized, or if the scales are not 2D.
//   - Work is enqueued on the current CUDA stream.
std::tuple<at::Tensor, at::Tensor, at::Tensor> swizzle_scales_and_pack_ptrs_for_discrete_weights(
const std::vector<at::Tensor> &data_tensors, const std::vector<at::Tensor> &scale_tensors,
const std::string &swizzle_type, const c10::Device &device);

} // namespace grouped_mlp_experimental

/***************************************************************************************************
* Support THD format for Context Parallel
Expand Down
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
// Perform quantization
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);

// Post-quantize swizzle for quantizers whose kernel does not bake
// the GEMM-swizzled scale layout in directly
if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) {
inplace_swizzle_scale_for_gemm(output_py);
}
Comment on lines +65 to +69
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes here and in the C++ NVFP4 quantizer are to fix a bug uncovered by the test. When the NVFP4 quantizer was configured with optimize_for_gemm, it would not actually produce swizzled scales.


return output_py;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

// Experimental helpers for the fused grouped MLP.

#include <ATen/cuda/CUDAContext.h>

#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "common/common.h"
#include "extensions.h"

namespace transformer_engine {
namespace pytorch {
namespace grouped_mlp_experimental {

std::tuple<at::Tensor, at::Tensor, at::Tensor> swizzle_scales_and_pack_ptrs_for_discrete_weights(
const std::vector<at::Tensor> &data_tensors, const std::vector<at::Tensor> &scale_tensors,
const std::string &swizzle_type_str, const c10::Device &device) {
const size_t num_tensors = data_tensors.size();
NVTE_CHECK(scale_tensors.size() == num_tensors,
"Expected data_tensors and scale_tensors to have matching sizes, but got ",
num_tensors, " and ", scale_tensors.size(), ".");

// Parse swizzle type
enum class SwizzleType { Invalid, MXFP8Rowwise, MXFP8Columnwise, NVFP4 };
SwizzleType swizzle_type = SwizzleType::Invalid;
if (swizzle_type_str == "mxfp8_rowwise") {
swizzle_type = SwizzleType::MXFP8Rowwise;
} else if (swizzle_type_str == "mxfp8_columnwise") {
swizzle_type = SwizzleType::MXFP8Columnwise;
} else if (swizzle_type_str == "nvfp4") {
swizzle_type = SwizzleType::NVFP4;
} else {
NVTE_ERROR("Unsupported swizzle type (", swizzle_type_str,
"). Expected one of: mxfp8_rowwise, mxfp8_columnwise, nvfp4.");
}

// Trivial case: no tensors. Return empty tensors.
if (num_tensors == 0) {
auto empty_ptrs = at::empty({0}, at::TensorOptions().dtype(at::kLong).device(device));
auto empty_scales = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device));
return {empty_ptrs, empty_ptrs.clone(), std::move(empty_scales)};
}

// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();

// Tensor properties
NVTEScalingMode scaling_mode;
transformer_engine::DType data_dtype, scale_dtype;
NVTETensorParam data_param_name, scale_param_name;
switch (swizzle_type) {
case SwizzleType::MXFP8Rowwise:
case SwizzleType::MXFP8Columnwise:
scaling_mode = NVTE_MXFP8_1D_SCALING;
data_dtype = transformer_engine::DType::kFloat8E4M3;
scale_dtype = transformer_engine::DType::kFloat8E8M0;
if (swizzle_type == SwizzleType::MXFP8Rowwise) {
data_param_name = kNVTERowwiseData;
scale_param_name = kNVTERowwiseScaleInv;
} else {
data_param_name = kNVTEColumnwiseData;
scale_param_name = kNVTEColumnwiseScaleInv;
}
break;
case SwizzleType::NVFP4:
scaling_mode = NVTE_NVFP4_1D_SCALING;
data_dtype = transformer_engine::DType::kFloat4E2M1;
scale_dtype = transformer_engine::DType::kFloat8E4M3;
data_param_name = kNVTERowwiseData;
scale_param_name = kNVTERowwiseScaleInv;
break;
default:
NVTE_ERROR("Unsupported swizzle type (", static_cast<int>(swizzle_type), ").");
}

// Data shape
NVTEShape data_shape = convertTorchShape(data_tensors[0].sizes());
if (swizzle_type == SwizzleType::NVFP4) {
// NVFP4 packs two 4-bit values per byte
NVTE_CHECK(data_shape.ndim > 0, "Invalid shape for NVFP4 data tensor (",
getTensorShape(data_tensors[0]), ").");
data_shape.data[data_shape.ndim - 1] *= 2;
}

// Scale shape
const NVTEShape scale_shape = convertTorchShape(scale_tensors[0].sizes());
NVTE_CHECK(scale_shape.ndim == 2,
"Expected 2D scale tensor, but got shape=", getTensorShape(scale_tensors[0]), ".");
Comment on lines +93 to +96
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing per-tensor shape consistency check — the function validates that data_tensors.size() == scale_tensors.size() but uses only data_tensors[0] and scale_tensors[0] as the reference shapes for all tensors. If any subsequent tensor has a different shape, the NVTETensor configuration will be wrong for that tensor, leading to an out-of-bounds scale swizzle without any diagnostic. Adding an NVTE_CHECK loop over i > 0 to assert shape equality would catch this early.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function expects uniformly-sized tensors. Actually performing the checks would incur significant overhead.

const size_t scale_numel = scale_shape.data[0] * scale_shape.data[1];
const size_t scale_dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype);
const size_t scale_bytes = ceildiv(scale_numel * scale_dtype_bits, 8);

// Allocate single buffer for swizzled scales. Uses a uniform stride since
// all tensors share the same scale shape.
const size_t swizzled_scales_stride = roundup(scale_bytes, 16); // Align to 16 bytes
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need the roundup here, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's paranoid and somewhat redundant, but it's also cheap.

auto swizzled_scales = at::empty({static_cast<int64_t>(swizzled_scales_stride * num_tensors)},
at::TensorOptions().dtype(at::kByte).device(device));
uint8_t *swizzled_scales_dptr = reinterpret_cast<uint8_t *>(swizzled_scales.data_ptr());

// Allocate input/output NVTETensors as a single batch. The first
// num_tensors entries are inputs; the next num_tensors are outputs.
MultiTensorWrapper nvte_tensors(2 * num_tensors, scaling_mode);
NVTETensor *inputs_nvte = nvte_tensors.data();
NVTETensor *outputs_nvte = nvte_tensors.data() + num_tensors;

auto set_param = [](NVTETensor t, NVTETensorParam param, void *dptr,
transformer_engine::DType dtype, const NVTEShape &shape) {
NVTEBasicTensor data{dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param_v2(t, param, &data, sizeof(data));
};

// Configure NVTETensors
for (size_t i = 0; i < num_tensors; ++i) {
const uint8_t swizzled_flag = 1;
nvte_set_tensor_param_v2(outputs_nvte[i], kNVTEWithGEMMSwizzledScales, &swizzled_flag,
sizeof(swizzled_flag));
void *in_scale_ptr = scale_tensors[i].data_ptr();
void *out_scale_ptr = swizzled_scales_dptr + i * swizzled_scales_stride;
set_param(inputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape);
set_param(inputs_nvte[i], scale_param_name, in_scale_ptr, scale_dtype, scale_shape);
set_param(outputs_nvte[i], data_param_name, nullptr, data_dtype, data_shape);
set_param(outputs_nvte[i], scale_param_name, out_scale_ptr, scale_dtype, scale_shape);
}

// Launch swizzle kernel
nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte, outputs_nvte, num_tensors, stream);

// Pack data pointers (first half) and swizzled scale pointers (second half)
// into a single host buffer and copy to device with one kernel launch.
std::vector<uint64_t> packed_ptrs_host(2 * num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
packed_ptrs_host[i] = reinterpret_cast<uintptr_t>(data_tensors[i].data_ptr());
packed_ptrs_host[num_tensors + i] =
reinterpret_cast<uintptr_t>(swizzled_scales_dptr + i * swizzled_scales_stride);
}
auto packed_ptrs_device = at::empty({static_cast<int64_t>(2 * num_tensors)},
at::TensorOptions().dtype(at::kLong).device(device));
nvte_copy_host_to_device_via_kernel(packed_ptrs_host.data(), packed_ptrs_device.data_ptr(),
2 * num_tensors * sizeof(uint64_t), stream);

// Return the two pointer arrays as views into the packed device buffer.
auto data_ptrs = packed_ptrs_device.narrow(0, 0, static_cast<int64_t>(num_tensors));
auto scale_ptrs = packed_ptrs_device.narrow(0, static_cast<int64_t>(num_tensors),
static_cast<int64_t>(num_tensors));
return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales)};
}

} // namespace grouped_mlp_experimental
} // namespace pytorch
} // namespace transformer_engine
Loading
Loading