Skip to content

[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops#3078

Open
ksivaman wants to merge 3 commits into
NVIDIA:mainfrom
ksivaman:interleave_deinterleave_glu_tensor
Open

[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops#3078
ksivaman wants to merge 3 commits into
NVIDIA:mainfrom
ksivaman:interleave_deinterleave_glu_tensor

Conversation

@ksivaman
Copy link
Copy Markdown
Member

@ksivaman ksivaman commented Jun 3, 2026

Description

Add utilities to be used to convert FC1 weight format when loading checkpoints. The fused grouped MLP op uses interleaved format for FC1 weight. This functionality was originally added to NVIDIA-NeMo/Megatron-Bridge#2841.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add utility functions to interleave and de-interleave FC1 weight.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: ksivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from vthumbe1503 June 3, 2026 01:10
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 3, 2026

Greptile Summary

Adds two public utility functions — interleave_glu_tensor and deinterleave_glu_tensor — to convert FC1 weight tensors between the contiguous [W_all, V_all] layout used by most checkpoints and the block-interleaved layout expected by the fused grouped MLP SwiGLU kernels.

  • Both functions use a reshape → transpose → contiguous → reshape pipeline; the logic is correct and the two are verified inverses of each other.
  • Both functions are properly added to __all__, re-exported from transformer_engine.pytorch, and registered in the Sphinx API reference.

Confidence Score: 5/5

Safe to merge — the change adds two pure-tensor utility functions with no side effects on existing functionality.

Both functions are logically correct (verified as true inverses via manual trace-through), well-documented, properly exported, and isolated from the rest of the codebase. No existing behaviour is modified.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/utils.py Adds deinterleave_glu_tensor and interleave_glu_tensor utility functions; reshape/transpose logic is correct and the two functions are true inverses of each other.
transformer_engine/pytorch/init.py Exports the two new GLU tensor utility functions from the package's public API.
docs/api/pytorch.rst Registers both new functions in the Sphinx autodoc API reference; note interleave/deinterleave listing order here is reversed relative to init.py and all, though this is cosmetic.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Checkpoint tensor\n[W_all, V_all]\nshape: [N, d]"] -->|"interleave_glu_tensor(t, k)"| B

    B["reshape([2, N/2k, k, d])"]
    B --> C["transpose(0, 1)\n→ [N/2k, 2, k, d]"]
    C --> D["contiguous()"]
    D --> E["reshape([N, d])"]
    E --> F["Fused kernel tensor\n[W0:k, V0:k, W1:k, V1:k, ...]\nshape: [N, d]"]

    F -->|"deinterleave_glu_tensor(t, k)"| G

    G["reshape([N/2k, 2, k, d])"]
    G --> H["transpose(0, 1)\n→ [2, N/2k, k, d]"]
    H --> I["contiguous()"]
    I --> J["reshape([N, d])"]
    J --> A
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into interleave_dein..." | Re-trigger Greptile

Comment on lines +128 to +136
shape = tensor.shape
x = tensor.reshape(
shape[0] // (2 * interleave_size),
2,
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing input validation means users get a cryptic PyTorch reshape error instead of an actionable message. If tensor.shape[0] is not divisible by 2 * interleave_size, the call to reshape raises something like "shape '[N, 2, k, ...]' is invalid for input of size X" with no indication of which argument is wrong. An upfront check surfaces the real constraint immediately.

Suggested change
shape = tensor.shape
x = tensor.reshape(
shape[0] // (2 * interleave_size),
2,
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
if interleave_size <= 0:
raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
if tensor.shape[0] % (2 * interleave_size) != 0:
raise ValueError(
f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
f"2 * interleave_size ({2 * interleave_size})"
)
shape = tensor.shape
x = tensor.reshape(
shape[0] // (2 * interleave_size),
2,
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)

Comment on lines +170 to +178
shape = tensor.shape
x = tensor.reshape(
2,
shape[0] // (2 * interleave_size),
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Same missing guard in interleave_glu_tensor — if interleave_size is zero or shape[0] is not a multiple of 2 * interleave_size, the user gets an opaque reshape error rather than a clear diagnostic.

Suggested change
shape = tensor.shape
x = tensor.reshape(
2,
shape[0] // (2 * interleave_size),
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
if interleave_size <= 0:
raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
if tensor.shape[0] % (2 * interleave_size) != 0:
raise ValueError(
f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
f"2 * interleave_size ({2 * interleave_size})"
)
shape = tensor.shape
x = tensor.reshape(
2,
shape[0] // (2 * interleave_size),
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)

Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Left few minor comments.

"""Convert a block-interleaved SwiGLU fc1 tensor to contiguous gate/linear layout.

Fused SwiGLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU`
with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isnt it activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? Clarification in the comment would be helpful.

Comment on lines +171 to +178
x = tensor.reshape(
2,
shape[0] // (2 * interleave_size),
interleave_size,
*shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious if we have done perf analysis of this function with and without torch compile and whether it is worth jitting it with torch.compile. I would assume that would generate a triton kernel(not sure if it is a performant one?). Although I wouldnt block this PR for that

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably improve perf with tex.swap_first_dims. That said, I'd expect perf is not critical if this is primarily for checkpointing.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking whether we could reuse this for activations as well. But seems like not. The logic in terms of dim ordering is a bit different

Comment thread transformer_engine/pytorch/utils.py
Comment thread transformer_engine/pytorch/utils.py Outdated
Comment thread transformer_engine/pytorch/utils.py Outdated
Comment thread transformer_engine/pytorch/utils.py Outdated
Comment thread transformer_engine/pytorch/utils.py Outdated
ksivaman added 2 commits June 3, 2026 22:46
Signed-off-by: ksivamani <ksivamani@nvidia.com>

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants