Optimize NVFP4 4over6 candidate error path#3068
Conversation
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Greptile SummaryThis PR adds a fast NVFP4 4over6 candidate-error path controlled by
Confidence Score: 5/5Safe to merge; all paths verified correct and covered by zero-tolerance bitwise tests. The PTX register layout in compute_fp16_error_scales is verified: map4 is packed into bits [7:0] of fp8_pair, cvt.rn.f16x2.e4m3x2 converts lower→lower/upper→upper, and the two mov.b32 replications correctly broadcast each FP16 scale into an f16x2 for mul.rn.f16x2. The Python reference _ref_nvfp4_4over6_fp16_candidate correctly decomposes E2M1 and E4M3 bit fields, and the subnormal FP16 branch is dead code in practice (minimum product exponent is -10, well above the -14 threshold). All error accumulation differences and the tree-reduction ordering match between CUDA and Python. The 1824-passed, zero-tolerance test run provides strong bitwise correctness evidence. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["cvt_fp32_to_fp4_8x_with_error()"] --> B{err_use_fast_math?}
B -->|false| C["accumulate_dequant_error()\nFP32 decode: dequant x sf x global_amax / denom\n__fdiv_rn / __fmul_rn / __fsub_rn"]
B -->|true| D["accumulate_fp16_scaled_error_pair()\nFP16: mul.rn.f16x2(q_h2, scale_h2)\nthen cvt.f32.f16 to FP32 diff"]
D --> E["f16x2_scaled_to_float2()\nmul.rn.f16x2 then cvt.f32.f16 x2"]
D --> F["original = x x global_encode_scale\n(__fmul_rn)"]
E --> G["diff = candidate minus original"]
F --> G
G --> H["compute_error_rn: abs diff or diff squared"]
C --> H
H --> I["FP32 error accumulation\n(__fadd_rn)"]
Reviews (7): Last reviewed commit: "Drop scripts" | Re-trigger Greptile |
| def _sum_4over6_2d_error(err: torch.Tensor, tile_len_y: int) -> torch.Tensor: | ||
| """Reduce 16 row errors in the same tree order as the CUDA warp reduction.""" | ||
| rows = err.view(err.shape[0] // tile_len_y, tile_len_y, err.shape[1], 1) | ||
| rows = rows.squeeze(-1) | ||
| rows = rows[:, 0:8, :] + rows[:, 8:16, :] | ||
| rows = rows[:, 0:4, :] + rows[:, 4:8, :] | ||
| rows = rows[:, 0:2, :] + rows[:, 2:4, :] | ||
| return (rows[:, 0, :] + rows[:, 1, :]).unsqueeze(-1) |
There was a problem hiding this comment.
Hardcoded 16-row tree reduction without enforcement
The function accepts tile_len_y as a parameter and uses it for the initial view, but the subsequent binary-tree steps unconditionally slice at [8:16], [4:8], [2:4], [0:2], which are only correct when tile_len_y == 16. A caller passing tile_len_y = 32 would silently sum only the first 16 rows and discard the remaining 16; a caller passing tile_len_y = 8 would produce a shape mismatch on the first add. Adding an assert tile_len_y == 16 at the top would make the contract explicit and catch future regressions immediately.
There was a problem hiding this comment.
Addressed in 5136b58. _sum_4over6_2d_error now asserts tile_len_y == 16 before the tree reduction, and I also updated the reference docstring/comment to distinguish original-domain MAE/MSE from E4M3-scaled FP16 product-domain MAE_FP16/MSE_FP16.
Signed-off-by: Ziang Li <ziangli@umich.edu>
|
Since this is identical to the orginal mode in the algebraic sense, may be a better interface design it to let this mode replace the original fast math modes, instead of extending error modes. |
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
|
Interface refactor update: I refactored the core 4over6 error interface so the selection metric and implementation fast path are separate knobs:
This replaces the earlier public |
|
@greptile-apps I have updated PR body, review again. |
|
FlashInfer PR that implements the same contract: |
|
Do you have any data on how different the result is between the regular path and this new fast path? Maybe we could just make it a default path rather than introduce another env variable? |
|
Hi @ptrendx , I have some results using this script: """Compare NVFP4 4over6 E4M3 scales with and without error fast math."""
import os
from contextlib import contextmanager
import torch
from transformer_engine.pytorch import NVFP4Quantizer
import transformer_engine_torch as tex
M, K = 98304, 7168
@contextmanager
def _error_fast_math(enabled: bool):
old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH")
os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0"
try:
yield
finally:
if old_value is None:
os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None)
else:
os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value
def _quantize_scale_bytes(
x: torch.Tensor,
err_mode: str,
err_fast_math: bool,
row_scaled: bool,
with_2d_quantization: bool,
nvfp4_e4m3_max: int,
) -> torch.Tensor:
quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=False,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=with_2d_quantization,
row_scaled_nvfp4=row_scaled,
nvfp4_use_4over6=True,
nvfp4_e4m3_max=nvfp4_e4m3_max,
nvfp4_4over6_err_mode=err_mode,
)
with _error_fast_math(err_fast_math):
quantized = quantizer(x)
assert quantized._rowwise_scale_inv is not None
return quantized._rowwise_scale_inv.contiguous().view(torch.uint8)
def _compare_e4m3(
x: torch.Tensor,
dtype_name: str,
scale_mode: str,
row_scaled: bool,
quant_mode: str,
with_2d_quantization: bool,
nvfp4_e4m3_max: int,
) -> None:
for err_mode in ("MAE", "MSE"):
regular = _quantize_scale_bytes(
x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max
)
fast = _quantize_scale_bytes(
x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max
)
same = torch.count_nonzero(regular == fast).item()
total = regular.numel()
print(
f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} "
f"{dtype_name:>5} {err_mode:>3} "
f"{100.0 * same / total:12.6f} {total - same:15d} {total}"
)
def main():
torch.set_grad_enabled(False)
print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}")
print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3")
for scale_mode, row_scaled, quant_mode, with_2d_quantization in (
("tensor", False, "1d", False),
("tensor", False, "2d", True),
("row", True, "1d", False),
):
for nvfp4_e4m3_max in (256, 448):
for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")):
torch.manual_seed(1234)
x = torch.randn((M, K), dtype=dtype, device="cuda")
_compare_e4m3(
x,
dtype_name,
scale_mode,
row_scaled,
quant_mode,
with_2d_quantization,
nvfp4_e4m3_max,
)
del x
torch.cuda.empty_cache()
if __name__ == "__main__":
main()The agreement rate is consistently >99.9% on random We did not introduce any extra env var in this PR. I do think we can make |
| template <typename Cfg, int E4M3_MAX> | ||
| __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], | ||
| const float block_scale_inverse, | ||
| __device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte, |
There was a problem hiding this comment.
Why do we only compute 2 values in this function? Also, we don't use half of the scale_h2 here. Why don't we instead try to convert here values from both of the branches (so both 4 and 6 would be there, the scaling factors for both of these branches would be converted in a single instruction). Ideally we would then reuse those scaling factors rather than recasting them for every element in a block - considering we are math bound here, we need to make sure that we eliminate as many instructions as possible.
There was a problem hiding this comment.
I tried a refactored implementation in 54797b3 but it did not show meaningful speedup:
NVTE_NVFP4_4OVER6=activations \
NVTE_NVFP4_4OVER6_ERR_MODE=<MAE|MSE> \
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 \
NVTE_NVFP4_DISABLE_RHT=1 \
NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \
python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4Extended Fast-Path Timing Table:
| m | k | n | recipe | num_gemms | baseline_ms | old_MAE_fast1_ms | refactor_MAE_fast1_ms | MAE_refactor_speedup | old_MSE_fast1_ms | refactor_MSE_fast1_ms | MSE_refactor_speedup |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 16384 | 7168 | 2048 | nvfp4 | 4 | 0.768440 | 0.979436 | 0.984238 | 0.995x | 0.978000 | 0.984061 | 0.994x |
| 32768 | 7168 | 2048 | nvfp4 | 4 | 1.246045 | 1.648099 | 1.643575 | 1.003x | 1.644019 | 1.645252 | 0.999x |
| 65536 | 7168 | 2048 | nvfp4 | 4 | 2.226334 | 2.966817 | 2.977467 | 0.996x | 2.974658 | 2.990481 | 0.995x |
| 98304 | 7168 | 2048 | nvfp4 | 4 | 3.220651 | 4.395422 | 4.396148 | 1.000x | 4.401849 | 4.409521 | 0.998x |
| 16384 | 7168 | 2048 | nvfp4 | 8 | 0.999235 | 1.176896 | 1.182225 | 0.995x | 1.176988 | 1.258703 | 0.935x |
| 32768 | 7168 | 2048 | nvfp4 | 8 | 1.428313 | 1.848563 | 1.853778 | 0.997x | 1.846523 | 1.857125 | 0.994x |
| 65536 | 7168 | 2048 | nvfp4 | 8 | 2.400536 | 3.148987 | 3.152150 | 0.999x | 3.150669 | 3.153690 | 0.999x |
| 98304 | 7168 | 2048 | nvfp4 | 8 | 3.387845 | 4.552140 | 4.549876 | 1.000x | 4.554252 | 4.560287 | 0.999x |
the refactor-vs-old geomean was:
MAE: 0.9983x
MSE: 0.9889x
combined: 0.9936x
This refactor is essentially common instruction lifting/reuse, and it keeps the same core PTX instructions (cvt.rn.f16x2.e4m3x2, cvt.rn.f16x2.e2m1x2, mul.rn.f16x2) rather than introducing a different PTX operation. I think the compiler can already do this in the old implementation but I am not sure.
There was a problem hiding this comment.
Kernel level benchmark does not show speedup either:
| mode | kernel | err | metric | old us | refactor us | old/refactor |
|---|---|---|---|---|---|---|
| 1d | nvfp4 | - | strict | 103.260 | 103.259 | 1.000x |
| 1d | 4over6 | MAE | strict | 799.852 | 800.892 | 0.999x |
| 1d | 4over6 | MAE | fast | 288.834 | 289.100 | 0.999x |
| 1d | 4over6 | MSE | strict | 829.294 | 829.693 | 1.000x |
| 1d | 4over6 | MSE | fast | 287.426 | 289.148 | 0.994x |
| 2d | nvfp4 | - | strict | 126.692 | 126.737 | 1.000x |
| 2d | 4over6 | MAE | strict | 847.700 | 847.070 | 1.001x |
| 2d | 4over6 | MAE | fast | 306.286 | 306.012 | 1.001x |
| 2d | 4over6 | MSE | strict | 866.949 | 867.081 | 1.000x |
| 2d | 4over6 | MSE | fast | 306.748 | 305.842 | 1.003x |
Script: 83e2308 , shape (16384, 6144), with --warmup 20 --iters 2000
There was a problem hiding this comment.
I keep the refactoring in beaed67 .The perf behavior of the explicit pattern is more robust to compiler optimizations.
Signed-off-by: Ziang Li <ziangli@umich.edu>
This reverts commit 54797b3. Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Description
@HumansAnd
This PR adds a fast NVFP4 4over6 candidate-error path that compares map-to-4 and map-to-6 candidates in the E4M3-scaled E2M1 product domain after the E2M1 and E4M3 values are rounded through FP16 conversion.
Earlier revisions exposed this as two additional public error modes,The interface has been refactored so the public 4over6 error mode remains the selection metric,MAE_FP16andMSE_FP16.MAEorMSE, whileNVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1selects the faster FP16 product-domain implementation for that metric.Fixes # (issue): N/A
Motivation:
cvt.rn.f16x2.e2m1x2andcvt.rn.f16x2.e4m3x2). This lets the 4over6 kernel construct the candidate E4M3 x E2M1 product with fewer scalar FP32 operations.Type of change
Changes
Please list the changes introduced in this PR:
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATHto select the fast FP16 product-domain candidate-error implementation for NVFP4 4over6.NVTE_NVFP4_4OVER6_ERR_MODEfocused on the selection metric,MAEorMSE, rather than encoding implementation details in additional public modes.NVTEQuantizationConfigas a dedicated boolean config.rnmodifier and FP16 multiply.MAE/MSE, E4M3 max 448 / 256, and error fast-math enabled / disabled.NVTE_NVFP4_4OVER6_ERR_MODEandNVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH.Testing note:
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH, the tests still require bitwise exactness against the PyTorch reference implementation. The reference emulates the FP16 conversion/multiply sequence that PyTorch does not expose directly, and the exact quantization tests use zero tolerance.Validation:
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py -k 4over6 # 1824 passed, 1632 skipped, 432 deselected, 2 warningspre-commit run --all-files # PassedPerformance:
No backend implementation changed in the interface refactor, so the benchmark values below are unchanged from the original FP16 error-path implementation.
Commands for the 2D activation sweep:
# Baseline, no 4over6. NVTE_NVFP4_DISABLE_RHT=1 \ NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \ python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4Raw 2D activation grouped forward/backward timings, in ms per microbatch:
2D activation slowdown relative to baseline:
The fast error path consistently reduces 4over6 overhead compared with the default
MAE/MSEoriginal-domain error path in this sweep. We also see the same speedup trend on another grouped-linear NVFP4 recipe.Checklist: