[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear #3052

Open
allenphilipj wants to merge 3 commits into NVIDIA:main from allenphilipj:codex-grouped-linear-fp8-cudagraph-skip

Conversation

@allenphilipj

Summary:

  • Propagate the FP8 graph-capture skip_fp8_weight_update tensor through GroupedLinear.
  • Align GroupedLinear graph-capture handling with Linear, LayerNormLinear, and LayerNormMLP.
  • Add a focused regression test for the forwarded skip tensor and graph-compatible is_first_microbatch behavior.

Validation:

  • git diff --check
  • python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.py tests/pytorch/test_cuda_graphs.py
  • Not run: focused pytest, because pytest is not installed in this local environment.

Fixes #3051

@allenphilipj allenphilipj requested a review from ksivaman as a code owner May 28, 2026 12:36
@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) May 28, 2026
@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from 937ef34 to 80304fa Compare May 28, 2026 12:40
@greptile-apps
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a missing propagation of the FP8 CUDA-graph weight-update skip flag (skip_fp8_weight_update_tensor) in GroupedLinear.forward, bringing it into alignment with Linear, LayerNormLinear, and LayerNormMLP, which already implemented this pattern.

  • grouped_linear.py: Mirrors the three-line block present in sibling modules — reads skip_fp8_weight_update_tensor from FP8GlobalStateManager.quantization_state when fp8_graph_capturing() is true, and forces is_first_microbatch = False when the tensor is set, then passes it (instead of a hard-coded None) into _GroupedLinear.forward.
  • test_cuda_graphs.py: Adds test_grouped_linear_forwards_fp8_graph_skip_tensor which monkeypatches the graph-capture flag and the state-manager tensor, captures non_tensor_args, and asserts both is_first_microbatch (index 1) and skip_fp8_weight_update (index 18) carry the expected values.
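The guard described in the first bullet can be sketched in isolation. This is a minimal illustration, not the code from the diff: `FakeFP8State` and `QuantizationState` are hypothetical stand-ins for `FP8GlobalStateManager` and its `quantization_state`, and `resolve_skip_flag` is an invented helper name. Only the attribute and method names `fp8_graph_capturing` and `skip_fp8_weight_update_tensor` come from the summary above.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class QuantizationState:
    """Hypothetical stand-in for FP8GlobalStateManager.quantization_state."""

    skip_fp8_weight_update_tensor: Optional[Any] = None


class FakeFP8State:
    """Hypothetical stand-in for FP8GlobalStateManager (not the real class)."""

    def __init__(self, capturing: bool, skip_tensor: Optional[Any] = None):
        self._capturing = capturing
        self.quantization_state = QuantizationState(skip_tensor)

    def fp8_graph_capturing(self) -> bool:
        return self._capturing


def resolve_skip_flag(
    state: FakeFP8State, is_first_microbatch: Optional[bool]
) -> Tuple[Optional[bool], Optional[Any]]:
    """Return the (is_first_microbatch, skip_fp8_weight_update) pair to forward."""
    skip_fp8_weight_update = None
    if state.fp8_graph_capturing():
        # Only read the skip tensor while a CUDA graph is being captured.
        skip_fp8_weight_update = state.quantization_state.skip_fp8_weight_update_tensor
    if skip_fp8_weight_update is not None:
        # With a skip tensor present, the caller's flag is overridden.
        is_first_microbatch = False
    return is_first_microbatch, skip_fp8_weight_update


if __name__ == "__main__":
    flag = object()  # placeholder for the real tensor
    print(resolve_skip_flag(FakeFP8State(True, flag), True))
    print(resolve_skip_flag(FakeFP8State(False, flag), True))
```

Outside graph capture the helper leaves `is_first_microbatch` untouched and returns `None` for the skip value, which matches the previously hard-coded behavior the summary describes.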

Confidence Score: 5/5

Safe to merge — the change is a narrow, targeted fix that replaces a hard-coded None with the tensor already managed by sibling modules using the identical three-line pattern.

The diff is a direct copy of the same guard block from Linear/LayerNormLinear/LayerNormMLP. The non_tensor_args tuple layout in GroupedLinear already had a slot for skip_fp8_weight_update (previously always None), so no downstream unpacking changes are needed. The regression test validates both the tensor propagation and the is_first_microbatch override in a single focused call path.
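The shape of such a regression test can be illustrated with a toy reproduction. This sketch makes several assumptions: it uses `unittest.mock` rather than pytest's `monkeypatch` fixture, `fp8_state`, `inner`, `module_forward`, and `run_check` are hypothetical names, and every tuple slot other than indices 1 and 18 (the two reported above) is a `None` placeholder because the real layout is not shown in this conversation.

```python
from types import SimpleNamespace
from unittest import mock

# Indices reported in the review summary; all other slots are placeholders.
IS_FIRST_MICROBATCH_IDX = 1
SKIP_FP8_WEIGHT_UPDATE_IDX = 18

# Hypothetical stand-ins for FP8GlobalStateManager and _GroupedLinear.
fp8_state = SimpleNamespace(
    fp8_graph_capturing=lambda: False,
    quantization_state=SimpleNamespace(skip_fp8_weight_update_tensor=None),
)
inner = SimpleNamespace(forward=lambda non_tensor_args: None)


def module_forward(is_first_microbatch):
    """Toy outer forward that packs the two values into a 19-slot tuple."""
    skip = None
    if fp8_state.fp8_graph_capturing():
        skip = fp8_state.quantization_state.skip_fp8_weight_update_tensor
    if skip is not None:
        is_first_microbatch = False
    args = [None] * (SKIP_FP8_WEIGHT_UPDATE_IDX + 1)
    args[IS_FIRST_MICROBATCH_IDX] = is_first_microbatch
    args[SKIP_FP8_WEIGHT_UPDATE_IDX] = skip
    return inner.forward(tuple(args))


def run_check():
    """Patch the capture flag and skip tensor, then record what is forwarded."""
    sentinel = object()  # placeholder for the real skip tensor
    captured = {}

    def record(non_tensor_args):
        captured["args"] = non_tensor_args

    with mock.patch.object(fp8_state, "fp8_graph_capturing", lambda: True), \
         mock.patch.object(
             fp8_state.quantization_state, "skip_fp8_weight_update_tensor", sentinel
         ), \
         mock.patch.object(inner, "forward", record):
        module_forward(is_first_microbatch=True)
    return captured["args"], sentinel


if __name__ == "__main__":
    args, sentinel = run_check()
    assert args[IS_FIRST_MICROBATCH_IDX] is False
    assert args[SKIP_FP8_WEIGHT_UPDATE_IDX] is sentinel
```

Recording the packed tuple instead of running a real forward pass keeps the check independent of GPU availability, which is presumably why the PR's test inspects `non_tensor_args` directly.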

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Adds FP8 graph-capture skip tensor propagation to GroupedLinear.forward(), overriding is_first_microbatch to False when the tensor is present; directly mirrors the pattern used in Linear, LayerNormLinear, and LayerNormMLP. |
| tests/pytorch/test_cuda_graphs.py | Adds a focused regression test that monkeypatches fp8_graph_capturing and skip_fp8_weight_update_tensor to confirm the tensor lands at the correct index (18) in non_tensor_args and that is_first_microbatch is forced to False. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GroupedLinear
    participant FP8GlobalStateManager
    participant _GroupedLinear

    Caller->>GroupedLinear: "forward(inp, m_splits, is_first_microbatch=True)"
    GroupedLinear->>FP8GlobalStateManager: fp8_graph_capturing()
    alt CUDA graph capture active
        FP8GlobalStateManager-->>GroupedLinear: True
        GroupedLinear->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
        FP8GlobalStateManager-->>GroupedLinear: Tensor (or None)
        alt tensor is not None
            GroupedLinear->>GroupedLinear: "is_first_microbatch = False"
        end
    else Normal execution
        FP8GlobalStateManager-->>GroupedLinear: False
        GroupedLinear->>GroupedLinear: "skip_fp8_weight_update = None"
    end
    GroupedLinear->>_GroupedLinear: "forward(..., non_tensor_args[18]=skip_fp8_weight_update, ...)"
    _GroupedLinear-->>GroupedLinear: out, new_workspaces
    GroupedLinear-->>Caller: out

Reviews (6): Last reviewed commit: "Merge branch 'main' into codex-grouped-l..."

Comment thread tests/pytorch/test_cuda_graphs.py Outdated
@ksivaman
Member

/te-ci pytorch

Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from d7a4caa to 1890acf Compare June 2, 2026 16:38
@allenphilipj
Author

@ksivaman I've rebased on the latest main and resolved the conflicts. I would much appreciate a follow-up review.



Development

Successfully merging this pull request may close these issues.

[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture

2 participants