Feat/selective offload on srelu fuser #3047

Open
lhb8125 wants to merge 1 commit into NVIDIA:main from lhb8125:feat/selective-offload-on-srelu-fuser
Conversation

@lhb8125
Contributor

@lhb8125 lhb8125 commented May 27, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from dba3531 to 53e6511 Compare May 27, 2026 07:26
@greptile-apps
Contributor

greptile-apps Bot commented May 27, 2026

Greptile Summary

This PR adds selective CPU-offload control to the SReLU fuser path of _ForwardGroupedMLP_CuTeGEMMBase. When is_cpu_offload_enabled() returns true, weights and saved FC2 intermediate tensors are unconditionally kept resident on GPU, while the FC1 input activation and MoE activation tensors can optionally be kept resident via a per-op flag.

  • GroupedTensorStorage.get_data_tensors() is introduced to expose component tensors to the V1 offload path's mark_activation_offload helper.
  • grouped_fc2_x in the SFD kernel path is promoted from GroupedTensor to GroupedTensorStorage, and fc1_input_quantizer.internal = True is set unconditionally so both input branches produce GroupedTensorStorage.
  • mark_not_offload calls are wired around save_for_backward for weights, FC2 input, FC1 input, and MoE activation tensors, gated on new per-op no_offload_activation attributes.
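
As a rough illustration of the first bullet, a storage object can expose its component buffers through one accessor so that an offload helper can tag each of them. This is a minimal pure-Python sketch, not the Transformer Engine implementation: `ToyGroupedStorage`, its field names, `FakeTensor`, and `mark_all` are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import List, Optional


class FakeTensor:
    """Hypothetical stand-in for a tensor; accepts arbitrary attributes."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass
class ToyGroupedStorage:
    """Hypothetical grouped storage with optional component buffers."""

    data: Optional[FakeTensor] = None
    scale_inv: Optional[FakeTensor] = None
    columnwise_data: Optional[FakeTensor] = None

    def get_data_tensors(self) -> List[FakeTensor]:
        # Return only the buffers that are actually allocated.
        candidates = [self.data, self.scale_inv, self.columnwise_data]
        return [t for t in candidates if t is not None]


def mark_all(storage: ToyGroupedStorage, flag: str, value: bool) -> None:
    # Tag every component buffer, mimicking a per-tensor marking helper.
    for tensor in storage.get_data_tensors():
        setattr(tensor, flag, value)


storage = ToyGroupedStorage(data=FakeTensor("data"), scale_inv=FakeTensor("scale_inv"))
mark_all(storage, "activation_offloading", True)
marked_names = [t.name for t in storage.get_data_tensors()]
```

Keeping the enumeration in one method means the marking helper does not need to know which fields a given storage type carries.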

Confidence Score: 3/5

Not ready to merge without clarification on two forward_grouped_mlp.py changes.

The no_offload_activation attribute controls whether critical tensors are kept on GPU, but it is never set anywhere in the codebase and has no public API, making the selective-offload opt-out path unreachable without undocumented monkey-patching. Separately, fc1_input_quantizer.internal = True is set unconditionally on every forward pass, changing grouped_fc1_x from a GroupedTensor to a GroupedTensorStorage regardless of whether offloading is active, which could silently break downstream consumers expecting a torch.Tensor subclass.

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py — specifically the no_offload_activation gating logic (lines 680-681) and the unconditional fc1_input_quantizer.internal = True assignment (line 302).

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Adds selective offload logic: fc1_input_quantizer.internal is set unconditionally (not just during offloading), grouped_fc2_x is changed from GroupedTensor to GroupedTensorStorage in the SFD path, mark_not_offload is called for weights/activations, and a new no_offload_activation attribute controls per-tensor opt-out — but this attribute is never set or documented anywhere in the repo.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Adds get_data_tensors() that mirrors the field list in prepare_for_saving/restore_from_saved, enabling the V1 offload path to enumerate component tensors for marking.

Sequence Diagram

sequenceDiagram
    participant Caller as Megatron / Caller
    participant FwdMLP as _ForwardGroupedMLP_CuTeGEMMBase.forward
    participant Offload as cpu_offload (V1 or V2)
    participant Ctx as fc1_ctx / activation_ctx / fc2_ctx

    Caller->>FwdMLP: forward(input_, ...)
    FwdMLP->>FwdMLP: Build grouped_fc1_x (GroupedTensorStorage)
    FwdMLP->>FwdMLP: GEMM to activation_in, grouped_fc2_x (GroupedTensorStorage)
    FwdMLP->>Offload: is_cpu_offload_enabled()?
    alt "cpu_offloading = True"
        FwdMLP->>Offload: mark_not_offload fc1_weight_tensors
        opt "no_offload_fc1_activation=True"
            FwdMLP->>Offload: mark_not_offload grouped_fc1_x
        end
        FwdMLP->>Ctx: fc1_ctx.save_for_backward
        opt "no_offload_moe_activation=True"
            FwdMLP->>Offload: mark_not_offload activation_in scales
        end
        FwdMLP->>Ctx: activation_ctx.save_for_backward
        opt saved_grouped_fc2_x not None
            FwdMLP->>Offload: mark_not_offload saved_grouped_fc2_x
        end
        FwdMLP->>Offload: mark_not_offload fc2_weight_tensors
        FwdMLP->>Ctx: fc2_ctx.save_for_backward
    end
    FwdMLP-->>Caller: fc2_out

Reviews (9): Last reviewed commit: "Support selective offload for fused grou..." | Re-trigger Greptile

Comment on lines +542 to +546
selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
    activation_op, "activation_offloading"
)
offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False))
offload_activation_input = bool(getattr(activation_op, "activation_offloading", False))

P1 Selective-offload gate never activates unless callers set activation_offloading on op objects

hasattr(fc1_op, "activation_offloading") checks for a dynamic attribute on the op module. mark_activation_offload (both V1 and non-V1) sets activation_offloading on tensors, not on op instances, and neither GroupedLinear nor ScaledSReLU declare this attribute. As written, selective_offload will always be False and none of the new marking logic will execute unless callers set fc1_op.activation_offloading = True externally. If this is intentional, the attribute name, type, and expected caller pattern should be documented; if not, the gate condition needs to match how the attribute is actually assigned.
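
The mismatch described here can be reproduced with plain Python objects. In this sketch `ToyOp`, `ToyTensor`, and `mark_tensor` are hypothetical stand-ins, not Transformer Engine classes; the only point is that tagging a tensor does not make `hasattr` succeed on the op.

```python
class ToyTensor:
    """Hypothetical tensor stand-in."""


class ToyOp:
    """Hypothetical op stand-in that declares no offloading attribute."""


def mark_tensor(tensor: ToyTensor) -> None:
    # Stand-in for a helper that sets the flag on the tensor itself.
    tensor.activation_offloading = True


fc1_op, activation_op = ToyOp(), ToyOp()
saved_input = ToyTensor()
mark_tensor(saved_input)

# Gate written the way the reviewed snippet writes it: it inspects the ops.
selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
    activation_op, "activation_offloading"
)

# The flag exists, but only on the tensor.
tensor_is_marked = getattr(saved_input, "activation_offloading", False)
```

In this toy setup the gate only turns on if a caller assigns the attribute to the op object directly, which is the undocumented pattern the comment asks about.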

@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 2c59510 to 6e01d0a Compare May 27, 2026 07:49
Member

@timmoon10 timmoon10 left a comment

Overall looks good, but with one design suggestion.

Followup tasks after merging this PR:

  • Enable activation checkpointing in the unfused grouped linear op.
  • Update activation checkpointing to support v2 infrastructure from #1762, which is opt-out rather than opt-in.

Comment on lines +528 to +531
offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False))
offload_activation_input = bool(
    getattr(activation_op, "fine_grained_activation_offloading", False)
)
Member

@timmoon10 timmoon10 May 28, 2026

  • Do these options give us value? The dense linear op and activation ops don't expose this fine-grained control:

if is_cpu_offload_enabled():
    mark_activation_offload(saved_input)

if is_cpu_offload_enabled():
    mark_activation_offload(x)

  • For consistency with the rest of the CPU offloading behavior, shouldn't the default be to enable offloading? Disabling offloading should be the explicit path.
  • These secret undocumented attrs are delicate and unmaintainable. Better to make them arguments in the unfused ops. However, this also means we should update the unfused impls so that they disable activation checkpointing if the option is set.

Easiest just to not make this configurable.

Contributor Author

@lhb8125 lhb8125 Jun 1, 2026

Thanks Tim, I may not fully get your points, but let me try to clarify.

MCore's fine-grained activation offloading captures tensors by adding a hook to torch's save_for_backward(). The hook has no knowledge of where a tensor comes from, so the default strategy is to offload all captured tensors and filter out those marked as not-offload.

As for the weight tensors, they are marked as not offload by TE's offloading strategy, so for the unfused group linear, we reuse the attr from TE's offloading. For fused group mlp, I followed the group linear's style to mark weight tensors as not offload.

TE's v1 offloading strategy marks tensors as not offload by default, while v2 offloading strategy marks tensors as offload by default, which is consistent with MCore's offloading strategy.

For the unfused impls, we didn't add much code for MCore's offloading because we only need to filter out the weight tensors, which can reuse TE's strategy.
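
The capture-then-filter strategy described above can be sketched without PyTorch. Everything here is an assumption for illustration: `ToyTensor`, `capture_saved`, and the `not_offload` flag name are hypothetical, and the real hook is attached to `save_for_backward` rather than called by hand.

```python
from typing import Iterable, List, Tuple


class ToyTensor:
    """Hypothetical tensor stand-in carrying an optional opt-out flag."""

    def __init__(self, name: str, not_offload: bool = False) -> None:
        self.name = name
        if not_offload:
            # Producer-side marking, e.g. for weights.
            self.not_offload = True


def capture_saved(tensors: Iterable[ToyTensor]) -> Tuple[List[str], List[str]]:
    """Split captured tensors into (offloaded, kept_on_gpu) name lists.

    The hook cannot tell where a tensor came from, so it offloads
    everything unless the tensor carries an explicit opt-out mark.
    """
    offloaded, kept = [], []
    for tensor in tensors:
        if getattr(tensor, "not_offload", False):
            kept.append(tensor.name)
        else:
            offloaded.append(tensor.name)
    return offloaded, kept


saved = [
    ToyTensor("fc1_input"),
    ToyTensor("fc1_weight", not_offload=True),
    ToyTensor("activation_in"),
]
offloaded, kept = capture_saved(saved)
```

Under this model the producer of a tensor is the only place that can keep it on the GPU, which is why the fused op has to mark its weights explicitly.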

Member

@timmoon10 timmoon10 Jun 1, 2026

It makes sense that we need to mark the activations so that the v2 offloading strategy can access them. However, my complaint:

  • te.Linear offloads all activations no matter what.
  • Forward grouped MLP only offloads activations if you set a secret attr.

Why not just make grouped MLP have the same behavior as te.Linear? If I were a user, there's no way I could figure this out without first hitting an unexpected OOM and then digging into the code. We want our APIs to be simple, general, and obvious. Having hyperspecific hacky interfaces might be needed if we're scrambling against a deadline, but it's sloppy software development.

Contributor Author

I see, so your proposal is to offload all grouped MLP activations by default unless we specify "do not offload ***". I agree that's more consistent and will push a fix.
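
One way to read the agreed direction is an opt-out gate: activations are offloaded by default, and only an explicit setting keeps them on the GPU. The sketch below is hypothetical; `ToyGroupedMLPOp` and its `keep_activations_on_gpu` argument are illustrative names, not part of this PR or of Transformer Engine.

```python
class ToyGroupedMLPOp:
    """Hypothetical op with a documented, explicit opt-out argument."""

    def __init__(self, keep_activations_on_gpu: bool = False) -> None:
        self.keep_activations_on_gpu = keep_activations_on_gpu

    def should_offload_activations(self, cpu_offload_enabled: bool) -> bool:
        # Offload whenever offloading is on, unless the user opted out.
        return cpu_offload_enabled and not self.keep_activations_on_gpu


default_op = ToyGroupedMLPOp()
opt_out_op = ToyGroupedMLPOp(keep_activations_on_gpu=True)

decisions = {
    "default, offload on": default_op.should_offload_activations(True),
    "default, offload off": default_op.should_offload_activations(False),
    "opt-out, offload on": opt_out_op.should_offload_activations(True),
}
```

Making the opt-out a constructor argument, rather than an attribute probed with `getattr`, keeps the behavior discoverable from the op's signature.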

fanshiqing added a commit to fanshiqing/TransformerEngine that referenced this pull request Jun 2, 2026
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 6f5ef0a to c14deb7 Compare June 2, 2026 10:58
@lhb8125
Contributor Author

lhb8125 commented Jun 2, 2026

/te-ci pytorch L1

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Signed-off-by: hongbinl <hongbinl@nvidia.com>
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from c14deb7 to f25a1f5 Compare June 2, 2026 11:09

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

2 participants