[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops#3078
Signed-off-by: ksivamani <ksivamani@nvidia.com>
Greptile Summary
Adds two public utility functions —
Confidence Score: 5/5
Safe to merge: the change adds two pure-tensor utility functions with no side effects on existing functionality. Both functions are logically correct (verified as true inverses via manual trace-through), well-documented, properly exported, and isolated from the rest of the codebase. No existing behaviour is modified. No files require special attention.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Checkpoint tensor\n[W_all, V_all]\nshape: [N, d]"] -->|"interleave_glu_tensor(t, k)"| B
B["reshape([2, N/2k, k, d])"]
B --> C["transpose(0, 1)\n→ [N/2k, 2, k, d]"]
C --> D["contiguous()"]
D --> E["reshape([N, d])"]
E --> F["Fused kernel tensor\n[W0:k, V0:k, W1:k, V1:k, ...]\nshape: [N, d]"]
F -->|"deinterleave_glu_tensor(t, k)"| G
G["reshape([N/2k, 2, k, d])"]
G --> H["transpose(0, 1)\n→ [2, N/2k, k, d]"]
H --> I["contiguous()"]
I --> J["reshape([N, d])"]
J --> A
Reviews (2): Last reviewed commit: "Merge branch 'main' into interleave_dein..."
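The flowchart above can be sketched as runnable PyTorch. The function bodies follow the reshape/transpose steps shown in the flowchart and in the diff; the toy sizes and the row labels are assumptions chosen only to make the layout visible.

```python
import torch


def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    """[W_all, V_all] -> [W0:k, V0:k, W1:k, V1:k, ...] along dim 0."""
    shape = tensor.shape
    x = tensor.reshape(2, shape[0] // (2 * interleave_size), interleave_size, *shape[1:])
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)


def deinterleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    """Inverse of interleave_glu_tensor."""
    shape = tensor.shape
    x = tensor.reshape(shape[0] // (2 * interleave_size), 2, interleave_size, *shape[1:])
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)


# Toy example: N = 8 rows, k = 2. Rows 0-3 play the role of W, rows 4-7 of V.
rows = torch.arange(8).reshape(8, 1)
interleaved = interleave_glu_tensor(rows, 2)
print(interleaved.flatten().tolist())  # [0, 1, 4, 5, 2, 3, 6, 7]

restored = deinterleave_glu_tensor(interleaved, 2)
print(torch.equal(restored, rows))  # True
```

Reading the printed list in blocks of two shows the pattern from the flowchart: W rows 0-1, V rows 4-5, W rows 2-3, V rows 6-7.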
    shape = tensor.shape
    x = tensor.reshape(
        shape[0] // (2 * interleave_size),
        2,
        interleave_size,
        *shape[1:],
    )
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)
Missing input validation means users get a cryptic PyTorch reshape error instead of an actionable message. If tensor.shape[0] is not divisible by 2 * interleave_size, the call to reshape raises something like "shape '[N, 2, k, ...]' is invalid for input of size X" with no indication of which argument is wrong. An upfront check surfaces the real constraint immediately.
-    shape = tensor.shape
-    x = tensor.reshape(
-        shape[0] // (2 * interleave_size),
-        2,
-        interleave_size,
-        *shape[1:],
-    )
-    x = x.transpose(0, 1).contiguous()
-    return x.reshape(shape)
+    if interleave_size <= 0:
+        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+    if tensor.shape[0] % (2 * interleave_size) != 0:
+        raise ValueError(
+            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+            f"2 * interleave_size ({2 * interleave_size})"
+        )
+    shape = tensor.shape
+    x = tensor.reshape(
+        shape[0] // (2 * interleave_size),
+        2,
+        interleave_size,
+        *shape[1:],
+    )
+    x = x.transpose(0, 1).contiguous()
+    return x.reshape(shape)
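To see what the guard buys, the sketch below feeds the unguarded reshape a dim-0 size that is not a multiple of 2 * interleave_size, then does the same with the checks from the suggestion in place. The names `_check_interleave_args`, `deinterleave_unguarded`, and `deinterleave_guarded` are hypothetical and exist only for this comparison; they are not part of the PR.

```python
import torch


def _check_interleave_args(tensor: torch.Tensor, interleave_size: int) -> None:
    # Hypothetical helper; the checks mirror the suggested change.
    if interleave_size <= 0:
        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
    if tensor.shape[0] % (2 * interleave_size) != 0:
        raise ValueError(
            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
            f"2 * interleave_size ({2 * interleave_size})"
        )


def deinterleave_unguarded(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = tensor.shape
    x = tensor.reshape(shape[0] // (2 * interleave_size), 2, interleave_size, *shape[1:])
    return x.transpose(0, 1).contiguous().reshape(shape)


def deinterleave_guarded(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    _check_interleave_args(tensor, interleave_size)
    return deinterleave_unguarded(tensor, interleave_size)


bad = torch.zeros(6, 3)  # 6 is not a multiple of 2 * 2

try:
    deinterleave_unguarded(bad, 2)
except RuntimeError as err:
    unguarded_message = str(err)  # PyTorch's generic reshape complaint

try:
    deinterleave_guarded(bad, 2)
except ValueError as err:
    guarded_message = str(err)  # names the dimension and the divisor

print(unguarded_message)
print(guarded_message)
```

The guarded message names both the offending dimension and the required divisor, which the reshape error does not.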
    shape = tensor.shape
    x = tensor.reshape(
        2,
        shape[0] // (2 * interleave_size),
        interleave_size,
        *shape[1:],
    )
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)
Same missing guard in interleave_glu_tensor: if interleave_size is zero or shape[0] is not a multiple of 2 * interleave_size, the user gets an opaque error rather than a clear diagnostic.
-    shape = tensor.shape
-    x = tensor.reshape(
-        2,
-        shape[0] // (2 * interleave_size),
-        interleave_size,
-        *shape[1:],
-    )
-    x = x.transpose(0, 1).contiguous()
-    return x.reshape(shape)
+    if interleave_size <= 0:
+        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+    if tensor.shape[0] % (2 * interleave_size) != 0:
+        raise ValueError(
+            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+            f"2 * interleave_size ({2 * interleave_size})"
+        )
+    shape = tensor.shape
+    x = tensor.reshape(
+        2,
+        shape[0] // (2 * interleave_size),
+        interleave_size,
+        *shape[1:],
+    )
+    x = x.transpose(0, 1).contiguous()
+    return x.reshape(shape)
vthumbe1503 left a comment
LGTM. Left a few minor comments.
    """Convert a block-interleaved SwiGLU fc1 tensor to contiguous gate/linear layout.

    Fused SwiGLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU`
    with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved
Isn't it the activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? A clarification in the comment would be helpful.
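One way to read the relationship this question raises: fc1 computes `y = x @ W.T`, so each weight row produces one output feature, and permuting the weight rows permutes the activation columns in the same way. The sketch below checks that numerically. The function body matches the diff; the shapes are arbitrary, and both the gate-first ordering and the per-block SwiGLU computation are assumptions about the fused kernel's convention, not something stated in the PR.

```python
import torch
import torch.nn.functional as F


def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = tensor.shape
    x = tensor.reshape(2, shape[0] // (2 * interleave_size), interleave_size, *shape[1:])
    return x.transpose(0, 1).contiguous().reshape(shape)


torch.manual_seed(0)
k, hidden, ffn = 2, 5, 4  # fc1 weight has 2 * ffn rows: [W_all, V_all]
weight = torch.randn(2 * ffn, hidden, dtype=torch.float64)
x = torch.randn(3, hidden, dtype=torch.float64)

# Reference: contiguous layout, SwiGLU over the two halves of the last dim.
gate, linear = (x @ weight.T).chunk(2, dim=-1)
reference = F.silu(gate) * linear

# Interleaved weight rows give interleaved activation columns.
y = x @ interleave_glu_tensor(weight, k).T  # shape [3, 2 * ffn]
blocks = y.reshape(3, ffn // k, 2, k)       # [token, block, gate/linear, k]
fused_like = (F.silu(blocks[:, :, 0]) * blocks[:, :, 1]).reshape(3, ffn)

print(torch.allclose(fused_like, reference))  # True
```

Under these assumptions the weight only needs converting once, at load time, and the activations then arrive in the block-interleaved order without any per-step shuffle.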
    x = tensor.reshape(
        2,
        shape[0] // (2 * interleave_size),
        interleave_size,
        *shape[1:],
    )
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)
I am curious whether we have done a perf analysis of this function with and without torch.compile, and whether it is worth JIT-compiling it with torch.compile. I would assume that would generate a Triton kernel (not sure if it would be a performant one). That said, I wouldn't block this PR on that.
We could probably improve perf with tex.swap_first_dims. That said, I'd expect perf is not critical if this is primarily for checkpointing.
I was wondering whether we could reuse this for activations as well, but it seems not: the logic in terms of dim ordering is a bit different.
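To make the dim-ordering difference concrete: the utility permutes dim 0 (weight rows), while activations carry the gate/linear split on the last dim. A minimal sketch of bridging the two follows, assuming a hypothetical wrapper `interleave_last_dim` that is not part of the PR. It moves the feature dim to the front, reuses the dim-0 function, and moves it back, which costs extra copies.

```python
import torch


def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = tensor.shape
    x = tensor.reshape(2, shape[0] // (2 * interleave_size), interleave_size, *shape[1:])
    return x.transpose(0, 1).contiguous().reshape(shape)


def interleave_last_dim(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    # Hypothetical wrapper: bring the feature dim to the front, apply the
    # dim-0 utility, then restore the original dim order.
    moved = tensor.movedim(-1, 0)
    return interleave_glu_tensor(moved, interleave_size).movedim(0, -1)


# 2 tokens, 8 features each: features 0-3 stand in for gate, 4-7 for linear.
acts = torch.arange(8).repeat(2, 1)
interleaved_acts = interleave_last_dim(acts, 2)
print(interleaved_acts[0].tolist())  # [0, 1, 4, 5, 2, 3, 6, 7]
```

The returned tensor is a non-contiguous view after the final `movedim`, so a caller that needs contiguous memory would have to call `.contiguous()` on it.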
Signed-off-by: ksivamani <ksivamani@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
Add utilities for converting the FC1 weight format when loading checkpoints. The fused grouped MLP op uses an interleaved format for the FC1 weight. This functionality was originally added in NVIDIA-NeMo/Megatron-Bridge#2841.
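As an illustration of the checkpoint-loading use case, the sketch below converts FC1 weights in a state dict before handing it to a model that expects the interleaved layout. The helper `convert_fc1_weights`, the `fc1.weight` key suffix, and the toy checkpoint are all hypothetical; real checkpoint key names depend on the model.

```python
import torch


def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = tensor.shape
    x = tensor.reshape(2, shape[0] // (2 * interleave_size), interleave_size, *shape[1:])
    return x.transpose(0, 1).contiguous().reshape(shape)


def convert_fc1_weights(state_dict, interleave_size, suffix="fc1.weight"):
    """Return a new dict with matching FC1 weights interleaved.

    Hypothetical helper: the key suffix is an assumption, not a name from the PR.
    """
    return {
        name: interleave_glu_tensor(param, interleave_size) if name.endswith(suffix) else param
        for name, param in state_dict.items()
    }


# Toy checkpoint: FC1 weight rows 0-3 stand in for W, rows 4-7 for V.
checkpoint = {
    "mlp.fc1.weight": torch.arange(8.0).reshape(8, 1),
    "mlp.fc2.weight": torch.ones(1, 4),
}
converted = convert_fc1_weights(checkpoint, interleave_size=2)
print(converted["mlp.fc1.weight"].flatten().tolist())
# [0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0]
```

Saving a checkpoint in the original layout would apply the de-interleave function to the same keys.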