Skip to content

Optimize grouped split metadata preparation#3075

Open
zhongbozhu wants to merge 12 commits into
NVIDIA:mainfrom
zhongbozhu:main_prepare_grouped_splits
Open

Optimize grouped split metadata preparation#3075
zhongbozhu wants to merge 12 commits into
NVIDIA:mainfrom
zhongbozhu:main_prepare_grouped_splits

Conversation

@zhongbozhu
Copy link
Copy Markdown
Collaborator

@zhongbozhu zhongbozhu commented Jun 2, 2026

Description

Refactored #2991

Progress toward #2897

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 2, 2026
@zhongbozhu zhongbozhu force-pushed the main_prepare_grouped_splits branch from 2f3c665 to 12185bf Compare June 2, 2026 00:23
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu force-pushed the main_prepare_grouped_splits branch from 1c48191 to df8b4bb Compare June 2, 2026 01:02
}

const int64_t offsets_length = num_groups + 1;
auto split_sizes_i64 = split_sizes_for_kernel.scalar_type() == at::kLong
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of a clean API nvte_multi_splits_to_offsets, we launch this int32 to int64 convert separately in libtorch C++

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu force-pushed the main_prepare_grouped_splits branch from 1cde346 to 09e7c06 Compare June 2, 2026 03:59
@zhongbozhu zhongbozhu marked this pull request as ready for review June 2, 2026 04:45
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 2, 2026

Greptile Summary

This PR refactors grouped split metadata preparation by introducing nvte_splits_to_offsets_multi, a new CUDA kernel that computes multiple scaled prefix-sum offset arrays in a single (or chunked) kernel launch. The splits_to_offsets_kernel is moved from common.cu into a new dedicated util/splits_to_offsets.cu source file with a cleaner, generalized implementation.

  • New multi-output kernel: splits_to_offsets_multi accepts up to kMaxNumOutputs = 8 outputs per kernel invocation (with transparent chunking for more), supporting mixed int32/int64 dtypes and per-output strides/leading-zero flags, replacing three separate Python-side tensor operations in forward_grouped_mlp.py with one call.
  • Pre-computed offsets threading: All six quantizer create_grouped_tensor overrides now accept an optional precomputed_tensor_offsets tensor, letting callers skip a redundant GPU kernel when offsets are already available from splits_to_offsets_multi.
  • Alignment-aware bulk allocation: When bulk_allocate=True, output tensors are packed with 16-byte alignment to satisfy cuDNN CuTe-DSL grouped GEMM requirements; new tests verify both correctness and alignment invariants across dtype and device combinations.

Confidence Score: 5/5

Safe to merge; the new kernel is logically correct, all five recomputed offset arrays match the previous Python-side scalar multiplications, and the bulk-allocate alignment path is exercised by the new tests.

The core kernel implements a standard blocked inclusive scan that handles chunking, leading zeros, mixed dtypes, and multi-output writes without observable correctness issues. The C++ wrapper always ensures the input is int64 and CUDA-resident before the kernel is launched, and the quantizer path falls back correctly when precomputed offsets are absent. The two observations raised are about hardening the raw NVTE API against misuse from C++ callers and catching int32 overflow at extreme scale — neither represents a reachable bug in the training path exercised today.

transformer_engine/common/util/splits_to_offsets.cu — the store_output int32 path and the missing device guard in nvte_splits_to_offsets_multi are worth a second look if the API is ever called directly from C++.

Important Files Changed

Filename Overview
transformer_engine/common/util/splits_to_offsets.cu New CUDA kernel implementing prefix-sum for multiple outputs. Logic is correct; silent int64→int32 truncation in store_output is a latent risk for very large sequence counts.
transformer_engine/pytorch/csrc/extensions/misc.cpp New splits_to_offsets_multi C++ wrapper; always converts split_sizes to int64/CUDA before calling kernel — correct and synchronous. Missing CUDA-device validation in the public NVTE API layer below.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Replaces three scalar tensor multiplications with a single splits_to_offsets_multi call; computed values match the previous output exactly for all five offset arrays.
transformer_engine/pytorch/csrc/quantizer.cpp Six quantizer overrides updated identically to accept precomputed_tensor_offsets; fallback to build_grouped_tensor_offsets when not provided is correctly guarded.
transformer_engine/common/include/transformer_engine/transformer_engine.h Added nvte_splits_to_offsets_multi declaration; nvte_splits_to_offsets signature preserved with renamed parameters only.
transformer_engine/pytorch/csrc/extensions/cast.cpp Signature extension to group_quantize and bgrad_group_quantize to thread tensor_offsets through to create_grouped_tensor; backward-compatible with py::none() default.
tests/pytorch/test_grouped_tensor.py Good coverage across dtypes, devices, and bulk-allocate modes; alignment invariants and zero-stride edge cases are tested.

Sequence Diagram

sequenceDiagram
    participant Py as forward_grouped_mlp.py
    participant W as splits_to_offsets_multi (C++)
    participant K as CUDA kernel
    participant Q as Quantizer::create_grouped_tensor

    Py->>W: "split_sizes, strides=[1,1,k1,k2,n2], include_leading_zero=[F,T,T,T,T]"
    W->>W: cast split_sizes to int64, move to CUDA
    W->>W: bulk_allocate output tensors (16-byte aligned)
    W->>K: nvte_splits_to_offsets_multi(split_sizes_nvte, outputs[5], strides, ...)
    K-->>W: [split_points(i32), base_offsets(i64), fc1_offsets(i64), fc2_x_offsets(i64), fc2_out_offsets(i64)]
    W-->>Py: (split_sizes_i64, [split_points, base_offsets, fc1_offsets, fc2_x_offsets, fc2_out_offsets])
    Py->>Q: "group_quantize(fc1_x, quantizer, num_groups, split_sizes, tensor_offsets=fc1_offsets)"
    Q->>Q: precomputed_tensor_offsets.has_value() → skip build_grouped_tensor_offsets
    Q-->>Py: grouped_fc1_x
    Py->>Q: "GroupedTensor(data=fc2_out_buf, first_dims=split_sizes, tensor_offsets=fc2_out_offsets)"
    Q-->>Py: grouped_fc2_out
Loading

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Comment thread transformer_engine/pytorch/csrc/extensions/misc.cpp Outdated
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. We should put this new API in a grouped_mlp specific namespace similar to Tim's idea in this PR #3076

Comment thread transformer_engine/pytorch/csrc/quantizer.cpp Outdated
py::arg("logical_last_dim"), py::call_guard<py::gil_scoped_release>());
// Returns split_points plus one int64 offsets tensor per logical_last_dims
// entry. Passing logical_last_dim=1 gives the unscaled base offsets.
m.def("prepare_grouped_splits",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably wait for this PR #3076
and define the API under grouped_mlp_experimental namespace.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With some additional arguments, this actually becomes a very general function that can be used for grouped tensors outside the grouped MLP.

timmoon10 and others added 6 commits June 3, 2026 00:11
Replace nvte_multi_splits_to_offsets with nvte_splits_to_offsets_multi,
which takes parallel arrays of output NVTETensors, strides, and a
per-output include_leading_zero flag. The dedicated cumsum slot is gone:
outputs are now a uniform inclusive scan list where each output's length
is either N or N+1 depending on its leading-zero flag. The per-launch
output cap is internal; the public function loops kernel launches when
num_outputs exceeds it, so callers see no hard limit.

v1 nvte_splits_to_offsets now goes through the same shared kernel, and
the prior duplicated kernel is removed. The PyTorch tex wrapper switches
to MultiTensorWrapper to batch NVTETensor allocation for outputs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Relocate the inclusive-scan kernel, helpers, and launch struct into a
dedicated transformer_engine/common/util/splits_to_offsets.cu, wrapped
in namespace transformer_engine::splits_to_offsets. Dtype dispatch in
load_split_size and store_output now uses a switch with NVTE_DEVICE_ERROR
on unsupported dtypes instead of silently treating unknown dtypes as
int64. nvte_splits_to_offsets_multi with num_outputs == 0 is now a noop
rather than an error, and the internal args struct uses bool for the
per-output include_leading_zero flag (C API stays int* for portability).
v1 parameters renamed to split_sizes/num_splits/stride to align with the
multi variant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Rename tex.prepare_grouped_splits to tex.splits_to_offsets_multi and
turn it into a general inclusive-scan utility instead of a grouped-MLP
helper. Outputs are configured per-entry via parallel arrays of strides,
include_leading_zero flags, and dtypes; bulk_allocate is opt-in so
general callers get separate per-output at::empty buffers and only the
grouped-MLP hot path takes the shared-storage / 16-byte-aligned route
that cuDNN needs. The wrapper now takes an explicit device, coerces
split_sizes to int64 / CUDA centrally, drops the redundant per-tensor
checks (the core lib validates), and drops non_blocking=true on the
host->device migration to avoid the race Greptile flagged. Update the
grouped MLP caller and tests to the new shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Rename the internal create_grouped_tensor parameter from
provided_tensor_offsets to precomputed_tensor_offsets in the quantizer
subclass impls, and document the optional tensor_offsets contract on
the abstract method declaration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified the C API function and tex function so that they are more general and not specific to the grouped MLP cuDNN kernels. I've also taken the opportunity to reuse the same kernel for the single-output and multi-output splits-to-offsets functions and to do other miscellaneous housecleaning.

@timmoon10
Copy link
Copy Markdown
Member

/te-ci amd64 arm64

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants