
Optimize NVFP4 4over6 candidate error path#3068

Open
zianglih wants to merge 11 commits into NVIDIA:main from zianglih:nvfp4-4over6-fp16-error-modes

Contributor

@zianglih zianglih commented Jun 1, 2026

Description

@HumansAnd
This PR adds a fast NVFP4 4over6 candidate-error path that compares map-to-4 and map-to-6 candidates in the E4M3-scaled E2M1 product domain after the E2M1 and E4M3 values are rounded through FP16 conversion.

Earlier revisions exposed this as two additional public error modes, MAE_FP16 and MSE_FP16. The interface has been refactored so the public 4over6 error mode remains the selection metric, MAE or MSE, while NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 selects the faster FP16 product-domain implementation for that metric.
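The selection step described above can be sketched in plain Python. This is a simplified illustration, not the kernel's implementation: the names `quantize_e2m1`, `candidate_error`, and `select_candidate` are hypothetical, the block scale is left unrounded (the real path rounds it to E4M3), ties in E2M1 rounding go to the smaller magnitude rather than round-to-nearest-even, and equal errors are resolved in favor of map-to-6 as an assumption. The input is taken to be already multiplied by the global encode scale.

```python
# Simplified sketch under the assumptions stated in the lead-in.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def quantize_e2m1(value):
    """Snap a value to the nearest E2M1 magnitude, preserving the sign."""
    magnitude = min(E2M1_MAGNITUDES, key=lambda m: abs(m - abs(value)))
    return -magnitude if value < 0 else magnitude


def candidate_error(x_scaled, target, metric):
    """Block error when the block amax is mapped to `target` (4.0 or 6.0)."""
    amax = max(abs(v) for v in x_scaled)
    if amax == 0.0:
        return 0.0
    block_scale = amax / target  # assumption: not rounded to E4M3 here
    total = 0.0
    for v in x_scaled:
        # Candidate reconstructed in the same scaled domain as the input.
        candidate = quantize_e2m1(v / block_scale) * block_scale
        diff = candidate - v
        total += abs(diff) if metric == "MAE" else diff * diff
    return total


def select_candidate(x_scaled, metric="MAE"):
    """Return 4 or 6, whichever target yields the smaller block error."""
    err4 = candidate_error(x_scaled, 4.0, metric)
    err6 = candidate_error(x_scaled, 6.0, metric)
    return 4 if err4 < err6 else 6


print(select_candidate([4.0, 3.0, 2.0, 1.0]))  # map-to-4 represents this block exactly
print(select_candidate([6.0, 1.0]))            # map-to-6 represents this block exactly
```

Because both candidates are compared against the same scaled input, the error values are directly comparable without decoding either candidate back to the original input domain.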

Fixes # (issue): N/A

Motivation:

  • The parent 4over6 PR, NVIDIA/TransformerEngine#2972, showed that 4over6 quantization can become compute-bound. In particular, the previous error path is bottlenecked by E2M1 dequantization and FP32 arithmetic instruction count.
  • Blackwell exposes dedicated PTX conversion instructions for both FP4 E2M1 and FP8 E4M3 into FP16 (cvt.rn.f16x2.e2m1x2 and cvt.rn.f16x2.e4m3x2). This lets the 4over6 kernel construct the candidate E4M3 x E2M1 product with fewer scalar FP32 operations.
  • E4M3 x E2M1 products can be represented exactly in FP16 for this use case, while BF16 does not have enough mantissa bits for the same guarantee.
  • The fast error path compares candidates in the E4M3 x E2M1 scaled range. The final error difference and accumulation are still FP32, matching the previous design, but candidate reconstruction uses the Blackwell FP16 conversion/multiply path instead of the heavier FP32 dequantization expression.
  • We scale the original input into the E4M3-scaled domain instead of fully decoding each candidate back to the original input domain. That applies the FP32 global scaling once to the original input; fully decoding both candidates would require per-element FP32 scaling on each decoded candidate, increasing the arithmetic cost.
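The FP16 exactness claim in the list above can be checked by brute force. The sketch below assumes the finite-only E4M3 layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, with exponent 15 and mantissa 7 reserved for NaN) and uses Python's `struct` half-precision format to round through FP16; the helper names are hypothetical and it does not touch the CUDA kernel or the PyTorch reference.

```python
import struct

E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e4m3(byte):
    """Decode an E4M3 byte (assumed finite-only variant) to a float, or None for NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return None  # NaN encoding
    if exponent == 0:
        return sign * mantissa * 2.0 ** -9  # subnormal: (mantissa / 8) * 2^-6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def round_to_fp16(value):
    """Round a Python float to the nearest FP16 value and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def count_inexact_products():
    """Count E4M3 x E2M1 products that change when rounded to FP16."""
    inexact = 0
    for byte in range(256):
        scale = decode_e4m3(byte)
        if scale is None:
            continue
        for magnitude in E2M1_MAGNITUDES:
            for q in (magnitude, -magnitude):
                exact = scale * q  # exact in double precision
                if round_to_fp16(exact) != exact:
                    inexact += 1
    return inexact


print(count_inexact_products())  # 0
```

The largest magnitude involved is 448 x 6 = 2688 and the smallest nonzero one is 2^-9 x 0.5 = 2^-10, both inside the FP16 normal range, and each product needs at most six significant bits.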

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH to select the fast FP16 product-domain candidate-error implementation for NVFP4 4over6.
  • Keep NVTE_NVFP4_4OVER6_ERR_MODE focused on the selection metric, MAE or MSE, rather than encoding implementation details in additional public modes.
  • Thread the 4over6 error fast-math option through NVTEQuantizationConfig as a dedicated boolean config.
  • Add the FP16 candidate-error path to the NVFP4 4over6 CUDA kernel using Blackwell FP16 conversion PTX with the rn modifier and FP16 multiply.
  • Extend the PyTorch reference implementation with an emulated FP16 candidate path. PyTorch does not expose this exact FP16 round-to-nearest conversion/multiply sequence, so the reference reconstructs the FP16 result from integer fields and checks bitwise exactness against the kernel.
  • Extend NVFP4 exact quantization tests to cover 4over6 MAE / MSE, E4M3 max 448 / 256, and error fast-math enabled / disabled.
  • Update the environment variable docs for NVTE_NVFP4_4OVER6_ERR_MODE and NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH.

Testing note:

  • The C++ operator test was intentionally not expanded for the FP16 product-domain path. That test suite checks that each 4over6 block matches either map-to-4 or map-to-6 exactly; it does not validate candidate selection. Since this PR changes error computation and candidate selection, the strict PyTorch reference tests provide the meaningful coverage.
  • Although this path is controlled by NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH, the tests still require bitwise exactness against the PyTorch reference implementation. The reference emulates the FP16 conversion/multiply sequence that PyTorch does not expose directly, and the exact quantization tests use zero tolerance.

Validation:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py -k 4over6
# 1824 passed, 1632 skipped, 432 deselected, 2 warnings
pre-commit run --all-files
# Passed

Performance:

No backend implementation changed in the interface refactor, so the benchmark values below are unchanged from the original FP16 error-path implementation.

Commands for the 2D activation sweep:

# Baseline, no 4over6.
NVTE_NVFP4_DISABLE_RHT=1 \
NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \
python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4
# 4over6 activation modes.
NVTE_NVFP4_4OVER6=activations \
NVTE_NVFP4_4OVER6_ERR_MODE=<MAE|MSE> \
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=<0|1> \
NVTE_NVFP4_DISABLE_RHT=1 \
NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \
python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4

Raw 2D activation grouped forward/backward timings, in ms per microbatch:

err_mode err_use_fast_math m k n recipe num_gemms grouped_fwd_bwd_time_ms
baseline - 16384 7168 2048 nvfp4 4 0.768440
baseline - 32768 7168 2048 nvfp4 4 1.246045
baseline - 65536 7168 2048 nvfp4 4 2.226334
baseline - 98304 7168 2048 nvfp4 4 3.220651
baseline - 16384 7168 2048 nvfp4 8 0.999235
baseline - 32768 7168 2048 nvfp4 8 1.428313
baseline - 65536 7168 2048 nvfp4 8 2.400536
baseline - 98304 7168 2048 nvfp4 8 3.387845
MAE 0 16384 7168 2048 nvfp4 4 1.638763
MAE 0 32768 7168 2048 nvfp4 4 2.870079
MAE 0 65536 7168 2048 nvfp4 4 5.330869
MAE 0 98304 7168 2048 nvfp4 4 7.818854
MAE 0 16384 7168 2048 nvfp4 8 1.914825
MAE 0 32768 7168 2048 nvfp4 8 3.168478
MAE 0 65536 7168 2048 nvfp4 8 5.612699
MAE 0 98304 7168 2048 nvfp4 8 8.076360
MSE 0 16384 7168 2048 nvfp4 4 1.672350
MSE 0 32768 7168 2048 nvfp4 4 2.936419
MSE 0 65536 7168 2048 nvfp4 4 5.464571
MSE 0 98304 7168 2048 nvfp4 4 8.020864
MSE 0 16384 7168 2048 nvfp4 8 1.944005
MSE 0 32768 7168 2048 nvfp4 8 3.234871
MSE 0 65536 7168 2048 nvfp4 8 5.745626
MSE 0 98304 7168 2048 nvfp4 8 8.274712
MAE 1 16384 7168 2048 nvfp4 4 0.979436
MAE 1 32768 7168 2048 nvfp4 4 1.648099
MAE 1 65536 7168 2048 nvfp4 4 2.966817
MAE 1 98304 7168 2048 nvfp4 4 4.395422
MAE 1 16384 7168 2048 nvfp4 8 1.176896
MAE 1 32768 7168 2048 nvfp4 8 1.848563
MAE 1 65536 7168 2048 nvfp4 8 3.148987
MAE 1 98304 7168 2048 nvfp4 8 4.552140
MSE 1 16384 7168 2048 nvfp4 4 0.978000
MSE 1 32768 7168 2048 nvfp4 4 1.644019
MSE 1 65536 7168 2048 nvfp4 4 2.974658
MSE 1 98304 7168 2048 nvfp4 4 4.401849
MSE 1 16384 7168 2048 nvfp4 8 1.176988
MSE 1 32768 7168 2048 nvfp4 8 1.846523
MSE 1 65536 7168 2048 nvfp4 8 3.150669
MSE 1 98304 7168 2048 nvfp4 8 4.554252

2D activation slowdown relative to baseline:

| m | k | n | recipe | num_gemms | MAE, fast=0 | MSE, fast=0 | MAE, fast=1 | MSE, fast=1 |
|---|---|---|---|---|---|---|---|---|
| 16384 | 7168 | 2048 | nvfp4 | 4 | 2.133x | 2.176x | 1.275x | 1.273x |
| 32768 | 7168 | 2048 | nvfp4 | 4 | 2.303x | 2.357x | 1.323x | 1.319x |
| 65536 | 7168 | 2048 | nvfp4 | 4 | 2.394x | 2.455x | 1.333x | 1.336x |
| 98304 | 7168 | 2048 | nvfp4 | 4 | 2.428x | 2.490x | 1.365x | 1.367x |
| 16384 | 7168 | 2048 | nvfp4 | 8 | 1.916x | 1.945x | 1.178x | 1.178x |
| 32768 | 7168 | 2048 | nvfp4 | 8 | 2.218x | 2.265x | 1.294x | 1.293x |
| 65536 | 7168 | 2048 | nvfp4 | 8 | 2.338x | 2.393x | 1.312x | 1.312x |
| 98304 | 7168 | 2048 | nvfp4 | 8 | 2.384x | 2.442x | 1.344x | 1.344x |

The fast error path consistently reduces 4over6 overhead compared with the default MAE / MSE original-domain error path in this sweep. We also see the same speedup trend on another grouped-linear NVFP4 recipe.
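For readers who want to reproduce the slowdown column from the raw timings, here is a minimal sketch using the m=16384, num_gemms=4 rows. The timing values are copied from the raw table; the helper names and the "overhead removed" figure are illustrative additions, not part of the benchmark script.

```python
# Timings (ms per microbatch) for m=16384, k=7168, n=2048, num_gemms=4.
BASELINE_MS = 0.768440
TIMINGS_MS = {
    ("MAE", 0): 1.638763,
    ("MSE", 0): 1.672350,
    ("MAE", 1): 0.979436,
    ("MSE", 1): 0.978000,
}


def slowdown(time_ms, baseline_ms):
    """Slowdown factor relative to the no-4over6 baseline."""
    return time_ms / baseline_ms


def overhead_removed(strict_ms, fast_ms, baseline_ms):
    """Fraction of the 4over6 overhead (time above baseline) removed by the fast path."""
    return 1.0 - (fast_ms - baseline_ms) / (strict_ms - baseline_ms)


for (metric, fast), time_ms in TIMINGS_MS.items():
    print(f"{metric} fast={fast}: {slowdown(time_ms, BASELINE_MS):.3f}x")

mae_removed = overhead_removed(TIMINGS_MS[("MAE", 0)], TIMINGS_MS[("MAE", 1)], BASELINE_MS)
print(f"MAE overhead removed by fast path: {100.0 * mae_removed:.1f}%")
```

The printed slowdown factors should match the first row of the slowdown table (2.133x, 2.176x, 1.275x, 1.273x).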

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zianglih added 2 commits May 31, 2026 23:20
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 1, 2026
@zianglih zianglih marked this pull request as ready for review June 1, 2026 07:19
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR adds a fast NVFP4 4over6 candidate-error path controlled by NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1. Instead of decoding candidates fully back to the original input domain in FP32, the new path uses Blackwell PTX (cvt.rn.f16x2.e2m1x2, cvt.rn.f16x2.e4m3x2, mul.rn.f16x2) to compute the E2M1 × E4M3 product in FP16 and compare candidates in the E4M3-scaled domain, roughly halving 4over6 overhead.

  • CUDA kernel (quantize_4over6_nvfp4.cuh): introduces FP16ErrorScalePair, compute_fp16_error_scales, f16x2_scaled_to_float2, and accumulate_fp16_scaled_error_pair; removes the old non-rn fast_math arithmetic branch from accumulate_dequant_error and folds it cleanly into a compile-time if constexpr (Cfg::err_use_fast_math) dispatch.
  • Python reference (quantization_ref_nvfp4.py): emulates the FP16 multiplication sequence via integer bit manipulation in _ref_nvfp4_4over6_fp16_candidate; adds _sum_4over6_2d_error with an assert tile_len_y == 16 guard; threads nvfp4_4over6_err_use_fast_math through the quantizer constructor and all call sites.
  • Tests (test_nvfp4_quantize_exact.py): consolidates parametrization into a NVFP44Over6TestConfig dataclass; covers MAE/MSE × e4m3_max=448/256 × fast_math on/off; uses an env-var context manager that correctly saves and restores NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH.

Confidence Score: 5/5

Safe to merge; all paths verified correct and covered by zero-tolerance bitwise tests.

The PTX register layout in compute_fp16_error_scales is verified: map4 is packed into bits [7:0] of fp8_pair, cvt.rn.f16x2.e4m3x2 converts lower→lower/upper→upper, and the two mov.b32 replications correctly broadcast each FP16 scale into an f16x2 for mul.rn.f16x2. The Python reference _ref_nvfp4_4over6_fp16_candidate correctly decomposes E2M1 and E4M3 bit fields, and the subnormal FP16 branch is dead code in practice (minimum product exponent is -10, well above the -14 threshold). All error accumulation differences and the tree-reduction ordering match between CUDA and Python. The 1824-passed, zero-tolerance test run provides strong bitwise correctness evidence.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh Adds FP16 fast-error path using Blackwell PTX (cvt.rn.f16x2.e4m3x2, cvt.rn.f16x2.e2m1x2, mul.rn.f16x2); removes the old non-rn fast_math branch from accumulate_dequant_error; bit-packing and PTX register layout verified correct.
transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py Adds _ref_nvfp4_4over6_fp16_candidate and _sum_4over6_2d_error helpers; threads nvfp4_4over6_err_use_fast_math through the reference quantizer; NaN E4M3 guard is dead code in practice (scale values are always ≤ 0x7E/448).
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py Replaces three separate parametrize decorators with a consolidated NVFP44Over6TestConfig dataclass; adds environment-variable context manager; correctly covers MAE/MSE × e4m3_max=448/256 × fast_math=on/off combinations.
docs/envvars.rst Updates NVTE_NVFP4_4OVER6_ERR_MODE and NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH descriptions to accurately reflect the refactored interface.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["cvt_fp32_to_fp4_8x_with_error()"] --> B{err_use_fast_math?}
    B -->|false| C["accumulate_dequant_error()\nFP32 decode: dequant x sf x global_amax / denom\n__fdiv_rn / __fmul_rn / __fsub_rn"]
    B -->|true| D["accumulate_fp16_scaled_error_pair()\nFP16: mul.rn.f16x2(q_h2, scale_h2)\nthen cvt.f32.f16 to FP32 diff"]
    D --> E["f16x2_scaled_to_float2()\nmul.rn.f16x2 then cvt.f32.f16 x2"]
    D --> F["original = x x global_encode_scale\n(__fmul_rn)"]
    E --> G["diff = candidate minus original"]
    F --> G
    G --> H["compute_error_rn: abs diff or diff squared"]
    C --> H
    H --> I["FP32 error accumulation\n(__fadd_rn)"]

Reviews (7): Last reviewed commit: "Drop scripts" | Re-trigger Greptile

Comment on lines +517 to +524
def _sum_4over6_2d_error(err: torch.Tensor, tile_len_y: int) -> torch.Tensor:
    """Reduce 16 row errors in the same tree order as the CUDA warp reduction."""
    rows = err.view(err.shape[0] // tile_len_y, tile_len_y, err.shape[1], 1)
    rows = rows.squeeze(-1)
    rows = rows[:, 0:8, :] + rows[:, 8:16, :]
    rows = rows[:, 0:4, :] + rows[:, 4:8, :]
    rows = rows[:, 0:2, :] + rows[:, 2:4, :]
    return (rows[:, 0, :] + rows[:, 1, :]).unsqueeze(-1)
Contributor


P2 Hardcoded 16-row tree reduction without enforcement

The function accepts tile_len_y as a parameter and uses it for the initial view, but the subsequent binary-tree steps unconditionally slice at [8:16], [4:8], [2:4], [0:2], which are only correct when tile_len_y == 16. A caller passing tile_len_y = 32 would silently sum only the first 16 rows and discard the remaining 16; a caller passing tile_len_y = 8 would produce a shape mismatch on the first add. Adding an assert tile_len_y == 16 at the top would make the contract explicit and catch future regressions immediately.

Contributor Author


Addressed in 5136b58. _sum_4over6_2d_error now asserts tile_len_y == 16 before the tree reduction, and I also updated the reference docstring/comment to distinguish original-domain MAE/MSE from E4M3-scaled FP16 product-domain MAE_FP16/MSE_FP16.
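To illustrate why the reference mirrors the reduction order instead of calling a generic sum, here is a small pure-Python sketch. `tree_sum` is a hypothetical generalization to any power-of-two length (the PR itself simply asserts `tile_len_y == 16`); it pairs element `i` with element `i + half`, the same pairing as the slices in `_sum_4over6_2d_error`.

```python
def tree_sum(values):
    """Pairwise tree reduction: at each level, element i is added to element i + half."""
    n = len(values)
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    level = list(values)
    while len(level) > 1:
        half = len(level) // 2
        level = [level[i] + level[i + half] for i in range(half)]
    return level[0]


def sequential_sum(values):
    """Left-to-right accumulation, for comparison."""
    acc = 0.0
    for v in values:
        acc += v
    return acc


# Floating-point addition is not associative, so the two orders can disagree.
values = [1e16, 1.0, -1e16, 1.0]
print(tree_sum(values))        # 2.0
print(sequential_sum(values))  # 1.0
```

A bitwise-exact reference therefore has to reproduce the kernel's pairing, which is why a silent mismatch in `tile_len_y` would be hard to diagnose.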

Signed-off-by: Ziang Li <ziangli@umich.edu>
Contributor Author

zianglih commented Jun 2, 2026

Since this is algebraically identical to the original mode, a better interface design may be to let this mode replace the original fast-math modes instead of extending the error modes.

zianglih added 4 commits June 1, 2026 22:08
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih changed the title Add FP16 error modes for NVFP4 4over6 Add fast FP16 error path for NVFP4 4over6 Jun 2, 2026
Contributor Author

zianglih commented Jun 2, 2026

Interface refactor update:

I refactored the core 4over6 error interface so the selection metric and implementation fast path are separate knobs:

  • NVTE_NVFP4_4OVER6_ERR_MODE remains the candidate-selection metric and only accepts MAE / MSE.
  • NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 enables the FP16 product-domain error path for the selected metric.
  • The C++ quantization config mirrors this split with NVTENVFP44Over6Mode for disabled / min-MAE / min-MSE, plus a separate nvfp4_4over6_err_use_fast_math boolean.

This replaces the earlier public MAE_FP16 / MSE_FP16 mode design. The backend FP16 error implementation and benchmark numbers did not change in this interface refactor; the PR body now relabels the commands and tables around ERR_MODE + ERR_USE_FAST_MATH instead.
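The two-knob split can be pictured with a small parser. This is a hypothetical sketch, not Transformer Engine's actual parsing code: `ErrConfig` and `parse_err_config` are illustrative names, and the defaults used when a variable is unset (`MAE` and fast math off) are assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ErrConfig:
    metric: str          # selection metric: "MAE" or "MSE"
    use_fast_math: bool  # True selects the FP16 product-domain error path


def parse_err_config(env):
    """Build an ErrConfig from a mapping of environment variables."""
    # Assumption: defaults are MAE and fast math disabled when unset.
    metric = env.get("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE")
    if metric not in ("MAE", "MSE"):
        # Implementation details such as "MAE_FP16" are no longer valid modes.
        raise ValueError(f"unsupported 4over6 error mode: {metric!r}")
    use_fast_math = env.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", "0") == "1"
    return ErrConfig(metric=metric, use_fast_math=use_fast_math)


config = parse_err_config(
    {
        "NVTE_NVFP4_4OVER6_ERR_MODE": "MSE",
        "NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH": "1",
    }
)
print(config)
```

Keeping the metric and the implementation choice in separate fields means a future fast path can be swapped in without adding new public mode names.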

@zianglih zianglih changed the title Add fast FP16 error path for NVFP4 4over6 Optimize NVFP4 4over6 candidate error path Jun 2, 2026
Contributor Author

zianglih commented Jun 2, 2026

@greptile-apps I have updated the PR body; please review again.

Contributor Author

zianglih commented Jun 2, 2026

FlashInfer PR that implements the same contract:

Member

ptrendx commented Jun 2, 2026

Do you have any data on how different the result is between the regular path and this new fast path? Maybe we could just make it a default path rather than introduce another env variable?

Contributor Author

zianglih commented Jun 2, 2026

Hi @ptrendx,

I have some results using this script:

"""Compare NVFP4 4over6 E4M3 scales with and without error fast math."""

import os
from contextlib import contextmanager

import torch
from transformer_engine.pytorch import NVFP4Quantizer
import transformer_engine_torch as tex


M, K = 98304, 7168


@contextmanager
def _error_fast_math(enabled: bool):
    old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH")
    os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0"
    try:
        yield
    finally:
        if old_value is None:
            os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None)
        else:
            os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value


def _quantize_scale_bytes(
    x: torch.Tensor,
    err_mode: str,
    err_fast_math: bool,
    row_scaled: bool,
    with_2d_quantization: bool,
    nvfp4_e4m3_max: int,
) -> torch.Tensor:
    quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=False,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=with_2d_quantization,
        row_scaled_nvfp4=row_scaled,
        nvfp4_use_4over6=True,
        nvfp4_e4m3_max=nvfp4_e4m3_max,
        nvfp4_4over6_err_mode=err_mode,
    )
    with _error_fast_math(err_fast_math):
        quantized = quantizer(x)
    assert quantized._rowwise_scale_inv is not None
    return quantized._rowwise_scale_inv.contiguous().view(torch.uint8)


def _compare_e4m3(
    x: torch.Tensor,
    dtype_name: str,
    scale_mode: str,
    row_scaled: bool,
    quant_mode: str,
    with_2d_quantization: bool,
    nvfp4_e4m3_max: int,
) -> None:
    for err_mode in ("MAE", "MSE"):
        regular = _quantize_scale_bytes(
            x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max
        )
        fast = _quantize_scale_bytes(
            x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max
        )
        same = torch.count_nonzero(regular == fast).item()
        total = regular.numel()
        print(
            f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} "
            f"{dtype_name:>5} {err_mode:>3} "
            f"{100.0 * same / total:12.6f} {total - same:15d} {total}"
        )


def main():
    torch.set_grad_enabled(False)
    print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}")
    print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3")
    for scale_mode, row_scaled, quant_mode, with_2d_quantization in (
        ("tensor", False, "1d", False),
        ("tensor", False, "2d", True),
        ("row", True, "1d", False),
    ):
        for nvfp4_e4m3_max in (256, 448):
            for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")):
                torch.manual_seed(1234)
                x = torch.randn((M, K), dtype=dtype, device="cuda")
                _compare_e4m3(
                    x,
                    dtype_name,
                    scale_mode,
                    row_scaled,
                    quant_mode,
                    with_2d_quantization,
                    nvfp4_e4m3_max,
                )
                del x
                torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

The agreement rate is consistently >99.9% on random M, K = 98304, 7168:

shape=(98304, 7168), 1d_e4m3_values=44040192
scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3
tensor    1d      256  bf16 MAE    99.980177            8730 44040192
tensor    1d      256  bf16 MSE    99.993304            2949 44040192
tensor    1d      256  fp16 MAE    99.995965            1777 44040192
tensor    1d      256  fp16 MSE    99.998940             467 44040192
tensor    1d      448  bf16 MAE    99.931292           30259 44040192
tensor    1d      448  bf16 MSE    99.977966            9704 44040192
tensor    1d      448  fp16 MAE    99.998719             564 44040192
tensor    1d      448  fp16 MSE    99.999782              96 44040192
tensor    2d      256  bf16 MAE    99.999964              16 44040192
tensor    2d      256  bf16 MSE    99.999782              96 44040192
tensor    2d      256  fp16 MAE   100.000000               0 44040192
tensor    2d      256  fp16 MSE    99.999927              32 44040192
tensor    2d      448  bf16 MAE    99.999964              16 44040192
tensor    2d      448  bf16 MSE    99.999455             240 44040192
tensor    2d      448  fp16 MAE   100.000000               0 44040192
tensor    2d      448  fp16 MSE   100.000000               0 44040192
   row    1d      256  bf16 MAE    99.977505            9907 44040192
   row    1d      256  bf16 MSE    99.993213            2989 44040192
   row    1d      256  fp16 MAE    99.996885            1372 44040192
   row    1d      256  fp16 MSE    99.999255             328 44040192
   row    1d      448  bf16 MAE    99.983867            7105 44040192
   row    1d      448  bf16 MSE    99.996428            1573 44040192
   row    1d      448  fp16 MAE    99.997593            1060 44040192
   row    1d      448  fp16 MSE    99.999505             218 44040192
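As a sanity check on how the table columns relate, the percentages can be recomputed from the counts. This sketch only reuses numbers printed above; `agreement_pct` is an illustrative helper, and the block size of 16 elements per E4M3 scale follows the script's `M * K // 16`.

```python
M, K = 98304, 7168
BLOCK_SIZE = 16  # one E4M3 scale per 16 elements, as in the script's header line


def agreement_pct(different, total):
    """Percentage of E4M3 scale bytes that are identical between the two paths."""
    same = total - different
    return 100.0 * same / total


total_scales = M * K // BLOCK_SIZE
print(total_scales)                                  # 44040192
print(f"{agreement_pct(8730, total_scales):.6f}")    # 99.980177 (tensor, 1d, 256, bf16, MAE)
print(f"{agreement_pct(30259, total_scales):.6f}")   # 99.931292 (tensor, 1d, 448, bf16, MAE)
```

The second row is the lowest agreement rate in the table, so every configuration stays above 99.9%.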

We did not introduce any extra env var in this PR. NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH already existed, but it previously controlled the rounding modifiers in the dequant arithmetic instructions (which did not lead to a noticeable speedup). We simply replace that backend with a new implementation.

I do think we can make NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 the default, but I am not sure we should make this specific contract the only remaining 4over6 implementation. In the future we may also land other performance improvements that change the contract numerics. I think it is better to keep supporting the canonical 4over6 contract from the original paper via NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=0, and to use NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 for this contract and potential future improvements.

template <typename Cfg, int E4M3_MAX>
__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8],
const float block_scale_inverse,
__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte,
Member


Why do we only compute 2 values in this function? Also, we don't use half of the scale_h2 here. Why don't we instead try to convert here values from both of the branches (so both 4 and 6 would be there, the scaling factors for both of these branches would be converted in a single instruction). Ideally we would then reuse those scaling factors rather than recasting them for every element in a block - considering we are math bound here, we need to make sure that we eliminate as many instructions as possible.

Contributor Author


I tried a refactored implementation in 54797b3 but it did not show meaningful speedup:

NVTE_NVFP4_4OVER6=activations \
NVTE_NVFP4_4OVER6_ERR_MODE=<MAE|MSE> \
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 \
NVTE_NVFP4_DISABLE_RHT=1 \
NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \
python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4

Extended Fast-Path Timing Table:

m k n recipe num_gemms baseline_ms old_MAE_fast1_ms refactor_MAE_fast1_ms MAE_refactor_speedup old_MSE_fast1_ms refactor_MSE_fast1_ms MSE_refactor_speedup
16384 7168 2048 nvfp4 4 0.768440 0.979436 0.984238 0.995x 0.978000 0.984061 0.994x
32768 7168 2048 nvfp4 4 1.246045 1.648099 1.643575 1.003x 1.644019 1.645252 0.999x
65536 7168 2048 nvfp4 4 2.226334 2.966817 2.977467 0.996x 2.974658 2.990481 0.995x
98304 7168 2048 nvfp4 4 3.220651 4.395422 4.396148 1.000x 4.401849 4.409521 0.998x
16384 7168 2048 nvfp4 8 0.999235 1.176896 1.182225 0.995x 1.176988 1.258703 0.935x
32768 7168 2048 nvfp4 8 1.428313 1.848563 1.853778 0.997x 1.846523 1.857125 0.994x
65536 7168 2048 nvfp4 8 2.400536 3.148987 3.152150 0.999x 3.150669 3.153690 0.999x
98304 7168 2048 nvfp4 8 3.387845 4.552140 4.549876 1.000x 4.554252 4.560287 0.999x

The refactor-vs-old geometric mean (old time divided by refactor time) was:

MAE:      0.9983x
MSE:      0.9889x
combined: 0.9936x
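The geometric means can be recomputed from the timing table above. The sketch below copies the old and refactored fast-path timings from that table and takes the geometric mean of the old/refactor ratios; the helper name `geomean` is illustrative.

```python
import math

# (old_ms, refactor_ms) pairs copied from the extended fast-path timing table.
MAE_PAIRS = [
    (0.979436, 0.984238), (1.648099, 1.643575),
    (2.966817, 2.977467), (4.395422, 4.396148),
    (1.176896, 1.182225), (1.848563, 1.853778),
    (3.148987, 3.152150), (4.552140, 4.549876),
]
MSE_PAIRS = [
    (0.978000, 0.984061), (1.644019, 1.645252),
    (2.974658, 2.990481), (4.401849, 4.409521),
    (1.176988, 1.258703), (1.846523, 1.857125),
    (3.150669, 3.153690), (4.554252, 4.560287),
]


def geomean(ratios):
    """Geometric mean computed via the mean of logarithms."""
    ratios = list(ratios)
    return math.exp(math.fsum(math.log(r) for r in ratios) / len(ratios))


mae = geomean(old / new for old, new in MAE_PAIRS)
mse = geomean(old / new for old, new in MSE_PAIRS)
combined = geomean(old / new for old, new in MAE_PAIRS + MSE_PAIRS)
print(f"MAE: {mae:.4f}x  MSE: {mse:.4f}x  combined: {combined:.4f}x")
```

A value below 1.0 means the refactored version was slightly slower on average; the MSE figure is pulled down mostly by the single 0.935x outlier at m=16384, num_gemms=8.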

This refactor is essentially common instruction lifting/reuse, and it keeps the same core PTX instructions (cvt.rn.f16x2.e4m3x2, cvt.rn.f16x2.e2m1x2, mul.rn.f16x2) rather than introducing a different PTX operation. I think the compiler can already do this in the old implementation but I am not sure.

Contributor Author

@zianglih zianglih Jun 3, 2026


The kernel-level benchmark does not show a speedup either:

| mode | kernel | err | metric | old (us) | refactor (us) | old/refactor |
|---|---|---|---|---|---|---|
| 1d | nvfp4 | - | strict | 103.260 | 103.259 | 1.000x |
| 1d | 4over6 | MAE | strict | 799.852 | 800.892 | 0.999x |
| 1d | 4over6 | MAE | fast | 288.834 | 289.100 | 0.999x |
| 1d | 4over6 | MSE | strict | 829.294 | 829.693 | 1.000x |
| 1d | 4over6 | MSE | fast | 287.426 | 289.148 | 0.994x |
| 2d | nvfp4 | - | strict | 126.692 | 126.737 | 1.000x |
| 2d | 4over6 | MAE | strict | 847.700 | 847.070 | 1.001x |
| 2d | 4over6 | MAE | fast | 306.286 | 306.012 | 1.001x |
| 2d | 4over6 | MSE | strict | 866.949 | 867.081 | 1.000x |
| 2d | 4over6 | MSE | fast | 306.748 | 305.842 | 1.003x |

Script: 83e2308, shape (16384, 6144), run with --warmup 20 --iters 2000.

Contributor Author

@zianglih zianglih Jun 3, 2026


I kept the refactoring in beaed67. The performance of the explicit pattern is more robust to compiler optimizations.

zianglih added 4 commits June 2, 2026 16:36
Signed-off-by: Ziang Li <ziangli@umich.edu>
This reverts commit 54797b3.

Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>