Skip to content

fix unfused padding causal sdpa#3063

Open
hungryGeek16 wants to merge 472 commits into
NVIDIA:mainfrom
hungryGeek16:fix-unfused-padding-causal-sdpa
Open

fix unfused padding causal sdpa#3063
hungryGeek16 wants to merge 472 commits into
NVIDIA:mainfrom
hungryGeek16:fix-unfused-padding-causal-sdpa

Conversation

@hungryGeek16
Copy link
Copy Markdown

@hungryGeek16 hungryGeek16 commented May 31, 2026

Adds a targeted PyTorch SDPA fallback for unfused THD padding_causal self-attention so TransformerEngine does not materialize the full quadratic padding/causal mask. Includes a regression test that fails if get_full_mask is called on this path.

ptrendx and others added 30 commits February 14, 2025 17:10
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…x FP8 related codes (NVIDIA#1468)

* add prob permute; fix fp8tensor

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert unnecessary changes in UT

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* remove unnecessary probs dtype convert

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* keep the output nums if probs is not provided

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine the doc string

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix lint

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* use fp32 compute type

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* style fix

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix empty input return

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* separate prob related functions out

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
flax module with compute dtype inferred from the inputs

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
* Fix issues for MCore DDP.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Remove force data release for CPU offloading.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Add preserved attributeds.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add main_grad to prevserved attributes.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Change prepare_for_saving to original tensor and add .data to CPU hook.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Fix for LayernormLinear in FP8.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

---------

Signed-off-by: Dennis Liu <denliu@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Fix typo

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update tests

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* Fix te sequential for older pytorch versions

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* FIxes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* commit some debug code

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more debug info

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* debug code commit and typo fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* a typo fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove debug info

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not return lse

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add amax_per_step for quantizers of CP

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FP8 + CP

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dtype fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@login-preos01.a51.clusters.nvidia.com>
…NVIDIA#1466)

Use same API in optimizer zero_grad as PyT optimizers

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reshape inp

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* minor fixes for attention

Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…IA#1502)

* Fix a crash with module._apply(lambda t: t.cpu())

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Add comments

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Make sure tensor is moved to dst device before quantizer quantizes

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

---------

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* add remove_caches api

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* Update transformer_engine/pytorch/tensor/float8_tensor.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* explicit delete

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Added TMA alignment check to cast_fp8_1D

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use tensor const-ref instead of tensor const-ptr

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
)

* Skip context parallelism tests if not enough GPUs

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Apply suggestions from code review

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* delete extra tensor objects after restoring float8 tensors

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit fix

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the leak in float8tensor and mxfloat8tensor classes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* uncomment the fix

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix quantized tensor shape

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add shape to make_like; add test for chunk

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo from suggestion

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…rt (NVIDIA#1528)

Set flag in norm modules for Mcore sequence-parallel support

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix wheel install after src install

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX imports

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* switch order of dirs for finding so

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use existing dir src build

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IA#1548)

Don't set data to null

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
update cudnn-frontend to its new 1.11.0-rc

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Enable fp8_primary_weights for current scaling

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use different cast_master_weights_to_fp8 functions depending on the type of quantizer

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* All amaxes of model_weights should participate in reduce-max

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Clear _high_precision_init_val automatically in cast_master_weights_to_fp8 function

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Merge all all-reduce on amaxes into one NCCL kernel

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit tests for multi_tensor_compute_scale_and_scale_inv and preserve_high_precision_init_val

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix conflicts

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit test for cast_master_weights_to_fp8

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use mock group to initialize fp8_autocast to avoid reduction of amax_history by fp8_autocast_exit

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove with_computing_amax and with_computing_scale

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move replace_raw_data from QuantizedTensor to utils.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove allow_empty_output argument from nvte_compute_amax and set it always be true

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename import guard of recipe_common.cuh to be align with other import guards

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add unit test for replace_raw_data

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add test_replace_raw_data into qa/L0_pytorch_unittest/test.sh

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Minor changes in comments

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add randomness to the unit test of replace_raw_data

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (Maybe need revert) Add tex.quantize_to_fragment

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (Maybe needsto rrevert) Use nvte_quantize_noop in quantize_to_fragment

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint error

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move high_precision_init_val test and replace_raw_data test to test_sanity.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove test_fp8_model_init.py and test_replace_raw_data.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove cast_master_weights_to_fp8 and replace_raw_data from __all__ of tensor.__init__.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move FP8 casting logic back from C++ tex funcs to Python

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unimplemented function from header

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
* fix dtypes of fused_attn_bwd in CP+A2A

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dtypes of fused_attn_bwd in CP+P2P

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix amax_per_step

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clone scaling factors of fwd quantizers

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix fwd quantizers of CP+P2P

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dequantize fp8 out in CP unit test

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* delete redundant None in FusedAttnFunc bwd

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Update usage of weightmat before saving for backward

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix for layernorm mlp

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

---------

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
import te before te_jax

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Do not suppress MXFP8 norm in Python wrapper func

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support FP8 current scaling in tex norm functions

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use single envvar to enable cuDNN MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix full-tile requirement for MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix mxfp8 columnwise data missing when switching from validation to training

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Fix when you interleave training and inference

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* refact

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm useless code

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
ptrendx and others added 19 commits April 24, 2026 15:13
…ed quantization kernel (NVIDIA#2921)

Fix the race in the dbias computation

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
fix fp8 and is_bwd_fp8 relationship

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Fix FA4 selection when FA3 is unavailable.

Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
…ed quantization kernel (NVIDIA#2921)

Fix the race in the dbias computation

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…A#2922)

* remove ctype to eliminate memory usage from the cudnn kernel

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Remove c_dtype from fusible ops test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [Common, PyTorch] Add triton mHC kernels & pytorch operators

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* make linter happy

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* ah OK

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* new configs to improve perf

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* add APIs to docs

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix typos, check deterministic, refactor

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reset rng for all tests

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* add docstring

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix api doc

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* whoops

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* grad_x doesn't have to zero

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* force pytorch to not use bf16 for reduction

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* use TE's general_gemm instead

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Looks like this is how to make TE use fp32 acc

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…True (NVIDIA#2936)

* fix

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zero_out should also be tested

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@gb-nvl-059-compute03.nvidia.com>
…NVIDIA#2924)

* Fix contiguous path for k=2880

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review suggestion from @Oleg-Goncharov

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for swizzle + padding fusion

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IDIA#2929)

* Avoid removing usages from quantized weight in linear op

Quantized weight tensor may be used across steps, so removing a usage is not safe.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test to catch bug when alternating train and infer steps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid removing usages from quantized weights in grouped linear op

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore pre-forward quantizer config in ops

Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Blindly preserve quantizer usages in quantized weight params.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…n expert (NVIDIA#2947)

Add workaround for cuteDSL stride requirement for zero token expert

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…pace for THD sequences (NVIDIA#2522)

* Get seqlens and offsets in O(N) space instead of O(N*N) space

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Re enable fast causal path

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix: seqoffsets calculation for THD

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Clean up code. Add new comments. Fix unecessary pasing of seg pos to the seqoffsets calculation API

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimize and fix the slow O(T*T) path for seqlens and seqoffsets calculation for THD non-cp and Cp p2p ring
    - Newer path is O(T*max_segments) per seq
    - Newer path works well with CP p2p ring

    Fix BRCM cross attn by routing to new slow path rather than fast causal path

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint failure

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij  Janardan Lakhani <klakhani@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
…NVIDIA#2948)

* Switch to cuDNN-FE min version 1.23.0 to enable fused grouped MLP

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IDIA#2942)

* accumulate bias in fp32 instead of bf16 in ref impl dbias to avoid accumulated numerical error

Signed-off-by: tdophung <tdophung@nvidia.com>
…VIDIA#2955)

* Better documentation for single param and envvar guard

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix doc

Signed-off-by: ksivamani <ksivamani@nvidia.com>

* Fix test envvar

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: ksivamani <ksivamani@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 31, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 31, 2026

Greptile Summary

This PR avoids materializing a full quadratic [b, 1, sq, sk] attention mask for padding_causal unfused attention by routing through a new per-sample F.scaled_dot_product_attention(is_causal=True) loop (_forward_varlen_sdpa) when the conditions are met. It also carries several unrelated improvements: MXFP8 attention backend filtering in get_attention_backend, FP8EmulationFunc null-quantizer guards, MXFP8 documentation updates, a new cp_size field in AttentionParams, and cuda_version added to the attention run-config log.

  • Unfused padding_causal fast path: _use_varlen_sdpa gates a new _forward_varlen_sdpa method that iterates over batch items and calls PyTorch SDPA with is_causal=True, entirely bypassing get_full_mask. A new test (test_unfused_thd_padding_causal_uses_sdpa_without_full_mask) patches get_full_mask to assert it is never called and checks numerical correctness.
  • MXFP8 FP8-emulation fixes: FP8EmulationFunc now handles None quantizers in the S_quantizer/O_quantizer/dO_quantizer/dP_quantizer branches and applies a BSHD→SBHD permute when the quantizer is an MXFP8Quantizer.

Confidence Score: 3/5

The new fast path produces correct results under normal configuration, but silently computes wrong attention scores when NVTE_APPLY_QK_LAYER_SCALING=1 is in use.

The new varlen-SDPA branch passes self.softmax_scale to _forward_varlen_sdpa instead of the locally-modified scale variable. Any deployment with NVTE_APPLY_QK_LAYER_SCALING=1 and padding_causal self-attention will silently receive incorrect attention output from the new fast path, with no error or warning.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically the scale argument at the _forward_varlen_sdpa call site.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/backends.py Adds _use_varlen_sdpa / _forward_varlen_sdpa fast path for padding_causal; passes self.softmax_scale instead of the locally-modified scale variable, silently breaking apply_qk_layer_scaling.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Adds MXFP8 backend filtering, cp_size field to AttentionParams, cuda_version to run_config, and refactors FA3/FA4 SM90 preference logic; changes appear correct.
tests/pytorch/attention/test_attention.py Adds test_unfused_thd_padding_causal_uses_sdpa_without_full_mask which verifies both correctness and that get_full_mask is not called on the new fast path.
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py Documentation updates for MXFP8 recipe combinations; no functional logic changes.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[UnfusedDotProductAttention.forward] --> B{Convert input to sbhd}
    B --> C{padding in mask type and mask is None?}
    C -->|Yes| D[get_padding_mask: build mask from cu_seqlens]
    C -->|No| E[use existing mask]
    D --> F[_use_varlen_sdpa?]
    E --> F
    F -->|YES: padding_causal, self-attn, no bias, no dropout, no fp8, no alibi| G[_forward_varlen_sdpa: batch loop with SDPA is_causal=True, no full mask allocation]
    F -->|NO| H[get_full_mask: materialize full b x sq x sk mask]
    G --> I[_format_context: convert b,h,sq,d to output layout]
    H --> J[FP8 emulation if needed]
    J --> K[baddbmm QK matmul]
    K --> L[FusedScaleMaskSoftmax]
    L --> M[dropout + AV bmm]
    M --> N[reshape to output format]
    I --> Z[output tensor]
    N --> Z
Loading

Reviews (1): Last reviewed commit: "Avoid full mask allocation in unfused pa..." | Re-trigger Greptile

Comment on lines +575 to +585
return self._forward_varlen_sdpa(
query_layer,
key_layer,
value_layer,
q_format,
batch_size,
max_seqlen_q,
cu_seqlens_q,
attention_mask,
self.softmax_scale,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Wrong scale passed to fast-path SDPA

self.softmax_scale is passed here, but the locally-modified scale variable (which already incorporates the apply_qk_layer_scaling division by layer_number) was computed just a few lines above and is the value that should be forwarded. When NVTE_APPLY_QK_LAYER_SCALING=1 is set, scale = self.softmax_scale / self.layer_number, but the new varlen-SDPA path silently uses the unscaled value, producing wrong attention scores for every call that hits this branch with that env var enabled.

Suggested change
return self._forward_varlen_sdpa(
query_layer,
key_layer,
value_layer,
q_format,
batch_size,
max_seqlen_q,
cu_seqlens_q,
attention_mask,
self.softmax_scale,
)
if self._use_varlen_sdpa(
attn_mask_type,
attention_mask,
window_size,
core_attention_bias_type,
alibi_slopes,
fp8,
):
return self._forward_varlen_sdpa(
query_layer,
key_layer,
value_layer,
q_format,
batch_size,
max_seqlen_q,
cu_seqlens_q,
attention_mask,
scale,
)

Comment on lines +357 to +362
if attn_mask_type != "padding_causal":
return False
if window_size not in [None, (-1, 0), (-1, -1)]:
return False
if attn_mask_type == "padding_causal" and attention_mask is None:
return False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The first attn_mask_type != "padding_causal" guard already returns False for any value other than "padding_causal", so the later if attn_mask_type == "padding_causal" and … condition is always true at that point — the attn_mask_type == part is dead code and creates a misleading impression that the branch could be False.

Suggested change
if attn_mask_type != "padding_causal":
return False
if window_size not in [None, (-1, 0), (-1, -1)]:
return False
if attn_mask_type == "padding_causal" and attention_mask is None:
return False
if attn_mask_type != "padding_causal":
return False
if window_size not in [None, (-1, 0), (-1, -1)]:
return False
if attention_mask is None:
return False

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@cyanguwa
Copy link
Copy Markdown
Collaborator

cyanguwa commented Jun 1, 2026

Thanks for the contribution @hungryGeek16, but it looks like your base branch may be out of date - could you rebase please? Thanks!

@jberchtold-nvidia jberchtold-nvidia removed their request for review June 1, 2026 15:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.