Feat/selective offload on srelu fuser#3047
Force-pushed from dba3531 to 53e6511.
Greptile Summary
This PR adds selective CPU-offload control to the SReLU fuser path of
Confidence Score: 3/5
Not ready to merge without clarification on two forward_grouped_mlp.py changes.
The no_offload_activation attribute controls whether critical tensors are kept on GPU, but it is never set anywhere in the codebase and has no public API, making the selective-offload opt-out path unreachable without undocumented monkey-patching.
Separately, fc1_input_quantizer.internal = True is set unconditionally on every forward pass, changing grouped_fc1_x from a GroupedTensor to a GroupedTensorStorage regardless of whether offloading is active, which could silently break downstream consumers expecting a torch.Tensor subclass.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: specifically the no_offload_activation gating logic (lines 680-681) and the unconditional fc1_input_quantizer.internal = True assignment (line 302).
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller as Megatron / Caller
participant FwdMLP as _ForwardGroupedMLP_CuTeGEMMBase.forward
participant Offload as cpu_offload (V1 or V2)
participant Ctx as fc1_ctx / activation_ctx / fc2_ctx
Caller->>FwdMLP: forward(input_, ...)
FwdMLP->>FwdMLP: Build grouped_fc1_x (GroupedTensorStorage)
FwdMLP->>FwdMLP: GEMM to activation_in, grouped_fc2_x (GroupedTensorStorage)
FwdMLP->>Offload: is_cpu_offload_enabled()?
alt "cpu_offloading = True"
FwdMLP->>Offload: mark_not_offload fc1_weight_tensors
opt "no_offload_fc1_activation=True"
FwdMLP->>Offload: mark_not_offload grouped_fc1_x
end
FwdMLP->>Ctx: fc1_ctx.save_for_backward
opt "no_offload_moe_activation=True"
FwdMLP->>Offload: mark_not_offload activation_in scales
end
FwdMLP->>Ctx: activation_ctx.save_for_backward
opt saved_grouped_fc2_x not None
FwdMLP->>Offload: mark_not_offload saved_grouped_fc2_x
end
FwdMLP->>Offload: mark_not_offload fc2_weight_tensors
FwdMLP->>Ctx: fc2_ctx.save_for_backward
end
FwdMLP-->>Caller: fc2_out
Reviews (9): Last reviewed commit: "Support selective offload for fused grou..."
selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
    activation_op, "activation_offloading"
)
offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False))
offload_activation_input = bool(getattr(activation_op, "activation_offloading", False))
Selective-offload gate never activates unless callers set activation_offloading on op objects
hasattr(fc1_op, "activation_offloading") checks for a dynamic attribute on the op module. mark_activation_offload (both V1 and non-V1) sets activation_offloading on tensors, not on op instances, and neither GroupedLinear nor ScaledSReLU declare this attribute. As written, selective_offload will always be False and none of the new marking logic will execute unless callers set fc1_op.activation_offloading = True externally. If this is intentional, the attribute name, type, and expected caller pattern should be documented; if not, the gate condition needs to match how the attribute is actually assigned.
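The failure mode described in this comment can be reproduced in isolation. The sketch below is a minimal illustration, not Transformer Engine code: FakeOp, FakeTensor, and this version of mark_activation_offload are hypothetical stand-ins, and it assumes (as the comment states) that marking sets the attribute on tensors rather than on ops.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor."""


class FakeOp:
    """Hypothetical stand-in for an op; it declares no offload attribute."""


def mark_activation_offload(*tensors):
    # Assumption taken from the review comment: the marker lands on tensors.
    for t in tensors:
        t.activation_offloading = True


def selective_offload_enabled(fc1_op, activation_op):
    # Same shape as the gate under review: it inspects the ops, not the tensors.
    return hasattr(fc1_op, "activation_offloading") or hasattr(
        activation_op, "activation_offloading"
    )


fc1_op, activation_op = FakeOp(), FakeOp()
x = FakeTensor()
mark_activation_offload(x)

# Marking a tensor leaves the ops untouched, so the gate stays closed.
gate_before = selective_offload_enabled(fc1_op, activation_op)

# Only an external assignment on the op itself opens the gate.
fc1_op.activation_offloading = True
gate_after = selective_offload_enabled(fc1_op, activation_op)

print(gate_before, gate_after)
```

Under these assumptions the gate opens only after a caller assigns the attribute on the op directly, which is the undocumented pattern the comment asks to have either documented or replaced.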
Force-pushed from 2c59510 to 6e01d0a.
Overall looks good, but with one design suggestion.
Followup tasks after merging this PR:
- Enable activation checkpointing in the unfused grouped linear op.
- Update activation checkpointing to support v2 infrastructure from #1762, which is opt-out rather than opt-in.
offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False))
offload_activation_input = bool(
    getattr(activation_op, "fine_grained_activation_offloading", False)
)
- Do these options give us value? The dense linear op and activation ops don't expose this fine-grained control:
  TransformerEngine/transformer_engine/pytorch/ops/basic/basic_linear.py, lines 1052 to 1053 in ace2a96
  TransformerEngine/transformer_engine/pytorch/ops/basic/activation.py, lines 116 to 117 in ace2a96
- For consistency with the rest of the CPU offloading behavior, shouldn't the default be to enable offloading? Disabling offloading should be the explicit path.
- These secret undocumented attrs are delicate and unmaintainable. Better to make them arguments in the unfused ops. However, this also means we should update the unfused impls so that they disable activation checkpointing if the option is set.
Easiest just to not make this configurable.
Thanks Tim, I may not fully get your points, but let me try to clarify.
MCore's fine-grained activation offloading captures tensors by adding a hook to torch's save_for_backward(), and it has no knowledge of where the tensors come from, so the default strategy is to offload all captured tensors and filter out the ones that are marked as not-offload.
As for the weight tensors, they are marked as not-offload by TE's offloading strategy, so for the unfused grouped linear we reuse the attr from TE's offloading. For the fused grouped MLP, I followed the grouped linear's style and marked the weight tensors as not-offload.
TE's v1 offloading strategy marks tensors as not-offload by default, while the v2 offloading strategy marks tensors as offload by default, which is consistent with MCore's offloading strategy.
For the unfused impls, we didn't insert much code for MCore's offloading because we only need to filter out the weight tensors, which can reuse TE's strategy.
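The "offload everything captured, except what is marked" strategy described above can be sketched in plain Python. This is an illustration only: FakeTensor, OffloadHook, the not_offload attribute name, and this mark_not_offload helper are hypothetical stand-ins and do not reproduce the MCore or Transformer Engine implementations.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor passed to save_for_backward."""

    def __init__(self, name):
        self.name = name


def mark_not_offload(*tensors):
    # Hypothetical marker: the attribute name is an assumption for this sketch.
    for t in tensors:
        t.not_offload = True


class OffloadHook:
    """Mimics a save-for-backward hook that cannot tell where a tensor came from."""

    def __init__(self):
        self.offloaded = []
        self.kept_on_device = []

    def on_save(self, *tensors):
        for t in tensors:
            # Default is to offload; only explicitly marked tensors are filtered out.
            if getattr(t, "not_offload", False):
                self.kept_on_device.append(t.name)
            else:
                self.offloaded.append(t.name)


weight = FakeTensor("fc1_weight")
activation = FakeTensor("grouped_fc1_x")

# The op marks its weights before saving, as described for grouped linear.
mark_not_offload(weight)

hook = OffloadHook()
hook.on_save(activation, weight)
print(hook.offloaded, hook.kept_on_device)
```

In this sketch the hook never needs to know which op produced a tensor; the op communicates its intent purely through the marker.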
It makes sense that we need to mark the activations so that the v2 offloading strategy can access them. However, my complaint:
- te.Linear offloads all activations no matter what.
- Forward grouped MLP only offloads activations if you set a secret attr.
Why not just make grouped MLP have the same behavior as te.Linear? If I were a user, there's no way I could figure this out without first hitting an unexpected OOM and then digging into the code. We want our APIs to be simple, general, and obvious. Having hyperspecific hacky interfaces might be needed if we're scrambling against a deadline, but it's sloppy software development.
I see, so your proposal is to offload all activations of the grouped MLP by default unless we specify "do not offload ***". I agree it's more consistent and will push a fix.
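The agreed direction (offload by default, opt out explicitly) might look roughly like the sketch below. It is a simplified illustration, not the code in this PR: select_not_offload and its keep_activations_on_device parameter are hypothetical names, and tensors are represented by their names as strings.

```python
def select_not_offload(weights, activations, keep_activations_on_device=False):
    """Return the tensors that should be marked as not-offload.

    Weights always stay on the device. Activations stay only when the caller
    explicitly opts out of offloading (hypothetical flag for this sketch).
    """
    keep = list(weights)
    if keep_activations_on_device:
        keep.extend(activations)
    return keep


weights = ["fc1_weight", "fc2_weight"]
activations = ["grouped_fc1_x", "activation_in"]

# Default path: nothing special is requested, so activations get offloaded.
default_keep = select_not_offload(weights, activations)

# Explicit opt-out: the caller asks to keep activations on the device.
opt_out_keep = select_not_offload(
    weights, activations, keep_activations_on_device=True
)

print(default_keep)
print(opt_out_keep)
```

With this shape, a user who sets nothing gets the same behavior the thread attributes to te.Linear, and keeping activations on the device becomes the explicit path.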
cherry-pick NVIDIA#3047
Force-pushed from 6f5ef0a to c14deb7.
/te-ci pytorch L1
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Force-pushed from c14deb7 to f25a1f5.