[PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture#3065
Conversation
Greptile SummaryThis PR fixes a regression where
Confidence Score: 5/5The change is a single-site, additive-only fix to The fix is mechanical and narrow: one new if/else block and one variable substitution in No files require special attention. A reviewer with GPU access should run the new Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant GroupedLinear.forward
participant FP8GlobalStateManager
participant _GroupedLinear.forward
participant quantize_weight
Caller->>GroupedLinear.forward: forward(inp, m_splits, is_first_microbatch)
GroupedLinear.forward->>FP8GlobalStateManager: fp8_graph_capturing()?
alt Graph capturing active
FP8GlobalStateManager-->>GroupedLinear.forward: True
GroupedLinear.forward->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
FP8GlobalStateManager-->>GroupedLinear.forward: skip_tensor
GroupedLinear.forward->>GroupedLinear.forward: "is_first_microbatch = False"
else Not capturing
FP8GlobalStateManager-->>GroupedLinear.forward: False
GroupedLinear.forward->>GroupedLinear.forward: "skip_fp8_weight_update = None"
end
GroupedLinear.forward->>_GroupedLinear.forward: "non_tensor_args[18]=skip_fp8_weight_update"
_GroupedLinear.forward->>quantize_weight: "skip_update_flag=skip_fp8_weight_update"
quantize_weight-->>_GroupedLinear.forward: quantized weight (or cached)
_GroupedLinear.forward-->>GroupedLinear.forward: output, new_workspaces
GroupedLinear.forward-->>Caller: output
Reviews (5): Last reviewed commit: "[PyTorch] Add CUDA graph FP8 weight-cach..." | Re-trigger Greptile |
852029a to
f8e5daa
Compare
f8e5daa to
70d505e
Compare
|
@LeSingh1 Could you add the test that would exercise this functionality? |
|
@ptrendx Added a test in It wraps One disclosure: I don't have FP8-capable GPU hardware locally, so I verified the test by construction (mirroring |
…8 CUDA graph capture GroupedLinear.forward hardcoded None for skip_fp8_weight_update, so the FP8 graph-capture skip tensor was never forwarded during CUDA graph replay. Mirror Linear.forward: when fp8_graph_capturing() is true, read quantization_state.skip_fp8_weight_update_tensor, force is_first_microbatch to False, and thread the tensor into the forward call (the slot _GroupedLinear.forward already unpacks). Fixes NVIDIA#3051 Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
Exercises skip_fp8_weight_update propagation in GroupedLinear during FP8 CUDA graph capture. With fp8_weight_caching enabled, graphed and eager runs only match when is_first_microbatch is threaded into the weight- update skip tensor for every microbatch, which the prior None hardcode prevented. Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
31c60f2 to
6c75a8e
Compare
|
Rebased onto main to resolve a conflict with #3038 (the GroupedLinear graph-safe refactor). Worth noting: #3038 plumbs skip_fp8_weight_update all the way through _GroupedLinear's autograd function, but the module-level GroupedLinear.forward still passes |
Description
GroupedLinear.forwardhardcodedNoneforskip_fp8_weight_update, so the FP8 graph-capture skip tensor was never forwarded during CUDA graph replay — the bug reported in #3051.Linear.forwardhandles this correctly: whenFP8GlobalStateManager.fp8_graph_capturing()is true it readsquantization_state.skip_fp8_weight_update_tensor, forcesis_first_microbatch = False, and threads the tensor into its forward args.GroupedLinearomitted that retrieval block and passed a literalNone, even though the autograd_GroupedLinear.forwardalready unpacks and uses the flag (quantize_weight(..., skip_update_flag=skip_fp8_weight_update)).Fix
Mirror the
Linearreference path inGroupedLinear.forward: add the samefp8_graph_capturing()retrieval block and threadskip_fp8_weight_updateintonon_tensor_args(positionally aligned with the existing autograd unpack).FP8GlobalStateManageris already imported in the module.Verification status
Addresses #3051