Skip to content

[PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture#3065

Open
LeSingh1 wants to merge 2 commits into
NVIDIA:mainfrom
LeSingh1:fix-grouped-linear-skip-fp8-weight-update-3051
Open

[PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture#3065
LeSingh1 wants to merge 2 commits into
NVIDIA:mainfrom
LeSingh1:fix-grouped-linear-skip-fp8-weight-update-3051

Conversation

@LeSingh1
Copy link
Copy Markdown
Contributor

Description

GroupedLinear.forward hardcoded None for skip_fp8_weight_update, so the FP8 graph-capture skip tensor was never forwarded during CUDA graph replay — the bug reported in #3051. Linear.forward handles this correctly: when FP8GlobalStateManager.fp8_graph_capturing() is true it reads quantization_state.skip_fp8_weight_update_tensor, forces is_first_microbatch = False, and threads the tensor into its forward args. GroupedLinear omitted that retrieval block and passed a literal None, even though the autograd _GroupedLinear.forward already unpacks and uses the flag (quantize_weight(..., skip_update_flag=skip_fp8_weight_update)).

Fix

Mirror the Linear reference path in GroupedLinear.forward: add the same fp8_graph_capturing() retrieval block and thread skip_fp8_weight_update into non_tensor_args (positionally aligned with the existing autograd unpack). FP8GlobalStateManager is already imported in the module.

Verification status

Runtime-unverified — developed without a CUDA GPU, so the FP8 / CUDA-graph numerics and the Nemotron-MoE gradient-suppression reproducer were not run. The change is a near-mechanical parity fix against the established Linear path (same accessor, same is_first_microbatch forcing, same threading), and the positional unpack lines up with _GroupedLinear.forward. A GPU regression test (comparing GroupedLinear FP8 grads under graph-replay-active vs delayed) belongs alongside the existing CUDA-graph tests but can't run on this platform — I'd ask a maintainer with a GPU to confirm numerics before merge, and I'm happy to add the test.

Developed with AI assistance.

Addresses #3051

@LeSingh1 LeSingh1 requested a review from ksivaman as a code owner May 31, 2026 23:21
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 31, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 31, 2026

Greptile Summary

This PR fixes a regression where GroupedLinear.forward hardcoded None for skip_fp8_weight_update during FP8 CUDA graph capture, preventing the weight-update skip tensor from being threaded into the autograd function. The fix is a near-mechanical parity change against Linear.forward.

  • grouped_linear.py: Adds the fp8_graph_capturing() retrieval block (identical to Linear.forward) to read quantization_state.skip_fp8_weight_update_tensor and force is_first_microbatch = False when the tensor is set; the tensor is then placed at position 18 in non_tensor_args, which aligns precisely with the existing unpack in _GroupedLinear.forward.
  • test_cuda_graphs.py: Introduces _GroupedLinearWrapper to adapt GroupedLinear's 2-D token API to the test harness's 3-D input, wires it into _test_cuda_graphs, and adds a dedicated parametrized test covering both DelayedScaling and Float8CurrentScaling recipes with fp8_weight_caching=True.

Confidence Score: 5/5

The change is a single-site, additive-only fix to GroupedLinear.forward that copies an already-proven code block from Linear.forward verbatim; the positional alignment with the existing non_tensor_args unpack has been confirmed, and no existing logic is modified.

The fix is mechanical and narrow: one new if/else block and one variable substitution in non_tensor_args. The new block is identical to the reference path in Linear.forward, the downstream unpack in _GroupedLinear.forward already expects the tensor at that position, and FP8GlobalStateManager is already imported. The author clearly flags that GPU numerics were not run, but the logic is straightforward enough that the missing GPU validation is a process note rather than a code concern.

No files require special attention. A reviewer with GPU access should run the new test_make_graphed_callables_grouped_linear_with_fp8_weight_caching test to confirm numeric correctness before merge, as the author requests.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds the fp8_graph_capturing() skip-tensor retrieval block that was missing from GroupedLinear.forward, mirroring the identical pattern in Linear.forward; positional placement in non_tensor_args matches the _GroupedLinear.forward unpack exactly.
tests/pytorch/test_cuda_graphs.py Adds _GroupedLinearWrapper, wires grouped_linear into _test_cuda_graphs, and introduces a dedicated test_make_graphed_callables_grouped_linear_with_fp8_weight_caching test; follows existing FP8 CUDA-graph test patterns correctly.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GroupedLinear.forward
    participant FP8GlobalStateManager
    participant _GroupedLinear.forward
    participant quantize_weight

    Caller->>GroupedLinear.forward: forward(inp, m_splits, is_first_microbatch)
    GroupedLinear.forward->>FP8GlobalStateManager: fp8_graph_capturing()?
    alt Graph capturing active
        FP8GlobalStateManager-->>GroupedLinear.forward: True
        GroupedLinear.forward->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
        FP8GlobalStateManager-->>GroupedLinear.forward: skip_tensor
        GroupedLinear.forward->>GroupedLinear.forward: "is_first_microbatch = False"
    else Not capturing
        FP8GlobalStateManager-->>GroupedLinear.forward: False
        GroupedLinear.forward->>GroupedLinear.forward: "skip_fp8_weight_update = None"
    end
    GroupedLinear.forward->>_GroupedLinear.forward: "non_tensor_args[18]=skip_fp8_weight_update"
    _GroupedLinear.forward->>quantize_weight: "skip_update_flag=skip_fp8_weight_update"
    quantize_weight-->>_GroupedLinear.forward: quantized weight (or cached)
    _GroupedLinear.forward-->>GroupedLinear.forward: output, new_workspaces
    GroupedLinear.forward-->>Caller: output
Loading

Reviews (5): Last reviewed commit: "[PyTorch] Add CUDA graph FP8 weight-cach..." | Re-trigger Greptile

@LeSingh1 LeSingh1 force-pushed the fix-grouped-linear-skip-fp8-weight-update-3051 branch 2 times, most recently from 852029a to f8e5daa Compare June 1, 2026 05:43
@LeSingh1 LeSingh1 force-pushed the fix-grouped-linear-skip-fp8-weight-update-3051 branch from f8e5daa to 70d505e Compare June 1, 2026 05:46
@jberchtold-nvidia jberchtold-nvidia removed their request for review June 1, 2026 15:37
@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented Jun 1, 2026

@LeSingh1 Could you add the test that would exercise this functionality?

@ptrendx ptrendx self-assigned this Jun 1, 2026
@LeSingh1
Copy link
Copy Markdown
Contributor Author

LeSingh1 commented Jun 1, 2026

@ptrendx Added a test in tests/pytorch/test_cuda_graphs.py: test_make_graphed_callables_grouped_linear_with_fp8_weight_caching.

It wraps GroupedLinear (flattening the [seqlen, batch, hidden] harness input to 2D and splitting evenly across 2 GEMMs) so it slots into the existing _test_cuda_graphs harness, then runs the full / individual / none graph modes with fp8_weight_caching=True and asserts the graphed outputs equal the eager reference. That equality only holds when is_first_microbatch is threaded into the weight-update skip tensor on every microbatch — on the pre-fix None hardcode the cached FP8 weights diverge, so the test fails without the change. Parametrized over DelayedScaling + Float8CurrentScaling, both fp8_params settings, and the fp32/fp16 dtypes already used in this file.

One disclosure: I don't have FP8-capable GPU hardware locally, so I verified the test by construction (mirroring test_make_graphed_callables_with_fp8_weight_caching) rather than by running it — it'll need a CI/GPU run to confirm. Happy to adjust the recipe coverage or wrapper approach if you'd prefer it structured differently.

LeSingh1 added 2 commits June 1, 2026 14:59
…8 CUDA graph capture

GroupedLinear.forward hardcoded None for skip_fp8_weight_update, so the
FP8 graph-capture skip tensor was never forwarded during CUDA graph
replay. Mirror Linear.forward: when fp8_graph_capturing() is true, read
quantization_state.skip_fp8_weight_update_tensor, force is_first_microbatch
to False, and thread the tensor into the forward call (the slot
_GroupedLinear.forward already unpacks).

Fixes NVIDIA#3051

Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
Exercises skip_fp8_weight_update propagation in GroupedLinear during FP8
CUDA graph capture. With fp8_weight_caching enabled, graphed and eager
runs only match when is_first_microbatch is threaded into the weight-
update skip tensor for every microbatch, which the prior None hardcode
prevented.

Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
@LeSingh1 LeSingh1 force-pushed the fix-grouped-linear-skip-fp8-weight-update-3051 branch from 31c60f2 to 6c75a8e Compare June 1, 2026 22:00
@LeSingh1
Copy link
Copy Markdown
Contributor Author

LeSingh1 commented Jun 1, 2026

Rebased onto main to resolve a conflict with #3038 (the GroupedLinear graph-safe refactor). Worth noting: #3038 plumbs skip_fp8_weight_update all the way through _GroupedLinear's autograd function, but the module-level GroupedLinear.forward still passes None, # skip_fp8_weight_update into non_tensor_args, so the value never reaches that plumbing during graph capture. This change wires it up the same way Linear.forward does (read the skip tensor from FP8GlobalStateManager.quantization_state when fp8_graph_capturing()). The test from the previous commit comes along with the rebase. Still GPU/CI-only on my end since I don't have FP8 hardware locally.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants