Optimize grouped split metadata preparation #3075
Force-pushed from 2f3c665 to 12185bf
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Force-pushed from 1c48191 to df8b4bb
}

const int64_t offsets_length = num_groups + 1;
auto split_sizes_i64 = split_sizes_for_kernel.scalar_type() == at::kLong
To keep the nvte_multi_splits_to_offsets API clean, we launch the int32-to-int64 conversion separately in libtorch C++.
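A minimal pure-Python sketch of the widening step described in this comment, assuming the truncated ternary in the diff reuses the input when it is already 64-bit and converts otherwise. `ensure_int64` is a hypothetical helper name and the `array` module stands in for a tensor; none of this is the libtorch code from the PR.

```python
from array import array

def ensure_int64(split_sizes: array) -> array:
    """Return split sizes as a 64-bit signed array, copying only if needed."""
    if split_sizes.typecode == "q":        # already 64-bit signed: reuse as-is
        return split_sizes
    return array("q", list(split_sizes))   # widen, e.g. from 32-bit "i"

sizes_i32 = array("i", [3, 0, 5])
sizes_i64 = ensure_int64(sizes_i32)
```

Doing the conversion before the scan means the scan routine itself only ever has to handle one input width in this sketch.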
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Force-pushed from 1cde346 to 09e7c06
for more information, see https://pre-commit.ci
Greptile Summary
This PR refactors grouped split metadata preparation by introducing
Confidence Score: 5/5
Safe to merge; the new kernel is logically correct, all five recomputed offset arrays match the previous Python-side scalar multiplications, and the bulk-allocate alignment path is exercised by the new tests.
The core kernel implements a standard blocked inclusive scan that handles chunking, leading zeros, mixed dtypes, and multi-output writes without observable correctness issues. The C++ wrapper always ensures the input is int64 and CUDA-resident before the kernel is launched, and the quantizer path falls back correctly when precomputed offsets are absent. The two observations raised are about hardening the raw NVTE API against misuse from C++ callers and catching int32 overflow at extreme scale; neither represents a reachable bug in the training path exercised today.
transformer_engine/common/util/splits_to_offsets.cu - the
Important Files Changed
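For readers unfamiliar with the term, a blocked inclusive scan can be sketched on the host as below. This is an illustrative Python model only, not the CUDA kernel from the PR: the function name and `block_size` parameter are assumptions, and a real kernel would run each chunk's scan in parallel across threads.

```python
from itertools import accumulate

def blocked_inclusive_scan(values, block_size=4):
    """Inclusive prefix sum computed chunk by chunk with a running carry."""
    out = []
    carry = 0  # sum of every element in the chunks already processed
    for start in range(0, len(values), block_size):
        chunk = values[start:start + block_size]
        local = list(accumulate(chunk))        # scan within this chunk only
        out.extend(carry + x for x in local)   # shift by the earlier chunks' total
        carry += local[-1]
    return out

scanned = blocked_inclusive_scan([3, 0, 5, 2, 4], block_size=2)
```

The carry is what lets an input longer than one block be processed in pieces while still producing the same result as a single pass.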
Sequence Diagram
sequenceDiagram
participant Py as forward_grouped_mlp.py
participant W as splits_to_offsets_multi (C++)
participant K as CUDA kernel
participant Q as Quantizer::create_grouped_tensor
Py->>W: "split_sizes, strides=[1,1,k1,k2,n2], include_leading_zero=[F,T,T,T,T]"
W->>W: cast split_sizes to int64, move to CUDA
W->>W: bulk_allocate output tensors (16-byte aligned)
W->>K: nvte_splits_to_offsets_multi(split_sizes_nvte, outputs[5], strides, ...)
K-->>W: [split_points(i32), base_offsets(i64), fc1_offsets(i64), fc2_x_offsets(i64), fc2_out_offsets(i64)]
W-->>Py: (split_sizes_i64, [split_points, base_offsets, fc1_offsets, fc2_x_offsets, fc2_out_offsets])
Py->>Q: "group_quantize(fc1_x, quantizer, num_groups, split_sizes, tensor_offsets=fc1_offsets)"
Q->>Q: precomputed_tensor_offsets.has_value() → skip build_grouped_tensor_offsets
Q-->>Py: grouped_fc1_x
Py->>Q: "GroupedTensor(data=fc2_out_buf, first_dims=split_sizes, tensor_offsets=fc2_out_offsets)"
Q-->>Py: grouped_fc2_out
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
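To make the call in the diagram concrete, here is a small pure-Python reference model of what the five outputs would contain, assuming each output is the inclusive scan of `split_sizes` multiplied by its stride, with an optional leading zero. The function name `reference_offsets` and the values `k1 = 4`, `k2 = 8`, `n2 = 16` are made up for illustration; the real work happens in the CUDA kernel.

```python
from itertools import accumulate

def reference_offsets(split_sizes, strides, include_leading_zero):
    """One scaled inclusive scan per (stride, leading-zero flag) pair."""
    cumsum = list(accumulate(split_sizes))
    outputs = []
    for stride, lead in zip(strides, include_leading_zero):
        scaled = [stride * c for c in cumsum]
        outputs.append(([0] if lead else []) + scaled)  # length N or N+1
    return outputs

k1, k2, n2 = 4, 8, 16  # hypothetical logical last dims
split_points, base, fc1, fc2_x, fc2_out = reference_offsets(
    [2, 0, 3],
    strides=[1, 1, k1, k2, n2],
    include_leading_zero=[False, True, True, True, True],
)
```

With these inputs, `split_points` has one entry per group while the four offset arrays have one extra leading entry, matching the N versus N+1 lengths described for the new API.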
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
/te-ci pytorch L1
vthumbe1503 left a comment
Overall LGTM. We should put this new API in a grouped_mlp-specific namespace, similar to Tim's idea in PR #3076.
py::arg("logical_last_dim"), py::call_guard<py::gil_scoped_release>());
// Returns split_points plus one int64 offsets tensor per logical_last_dims
// entry. Passing logical_last_dim=1 gives the unscaled base offsets.
m.def("prepare_grouped_splits",
We should probably wait for PR #3076 and define the API under the grouped_mlp_experimental namespace.
With some additional arguments, this actually becomes a very general function that can be used for grouped tensors outside the grouped MLP.
Replace nvte_multi_splits_to_offsets with nvte_splits_to_offsets_multi, which takes parallel arrays of output NVTETensors, strides, and a per-output include_leading_zero flag. The dedicated cumsum slot is gone: outputs are now a uniform inclusive scan list where each output's length is either N or N+1 depending on its leading-zero flag.
The per-launch output cap is internal; the public function loops kernel launches when num_outputs exceeds it, so callers see no hard limit. v1 nvte_splits_to_offsets now goes through the same shared kernel, and the prior duplicated kernel is removed. The PyTorch tex wrapper switches to MultiTensorWrapper to batch NVTETensor allocation for outputs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
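The launch-looping behaviour in this commit message can be pictured with the sketch below. It assumes a hypothetical cap argument `max_per_launch` (the real cap is internal and its value is not stated here) and only computes which output ranges each launch would cover; `output_length` restates the N or N+1 rule.

```python
def launch_ranges(num_outputs, max_per_launch):
    """Split the outputs into (start, count) ranges, one per kernel launch."""
    if max_per_launch <= 0:
        raise ValueError("max_per_launch must be positive")
    return [
        (start, min(max_per_launch, num_outputs - start))
        for start in range(0, num_outputs, max_per_launch)
    ]

def output_length(num_splits, include_leading_zero):
    """Each output holds N entries, or N+1 when it starts with a zero."""
    return num_splits + (1 if include_leading_zero else 0)

ranges = launch_ranges(num_outputs=5, max_per_launch=2)
```

Because the loop lives inside the public function, a caller asking for five outputs in this model never needs to know that three launches were issued.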
Relocate the inclusive-scan kernel, helpers, and launch struct into a dedicated transformer_engine/common/util/splits_to_offsets.cu, wrapped in namespace transformer_engine::splits_to_offsets. Dtype dispatch in load_split_size and store_output now uses a switch with NVTE_DEVICE_ERROR on unsupported dtypes instead of silently treating unknown dtypes as int64.
nvte_splits_to_offsets_multi with num_outputs == 0 is now a no-op rather than an error, and the internal args struct uses bool for the per-output include_leading_zero flag (C API stays int* for portability). v1 parameters renamed to split_sizes/num_splits/stride to align with the multi variant.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
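The two behaviour changes in this commit (explicit failure on unknown dtypes, and an empty request being a no-op) can be illustrated with a small host-side analogy. The names `SUPPORTED_DTYPES`, `check_dtype`, and `validate_outputs` are hypothetical, and a Python exception stands in for NVTE_DEVICE_ERROR.

```python
SUPPORTED_DTYPES = ("int32", "int64")

def check_dtype(dtype: str) -> str:
    """Explicit dispatch: reject anything that is not a known integer dtype."""
    if dtype in SUPPORTED_DTYPES:
        return dtype
    # Previously an unknown dtype would have been silently treated as int64.
    raise TypeError(f"unsupported dtype: {dtype!r}")

def validate_outputs(output_dtypes):
    """Return the validated dtypes; an empty request is a no-op, not an error."""
    if len(output_dtypes) == 0:
        return []
    return [check_dtype(d) for d in output_dtypes]
```

Failing loudly on an unexpected dtype trades a silent misinterpretation of memory for an error that points at the offending argument.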
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Rename tex.prepare_grouped_splits to tex.splits_to_offsets_multi and turn it into a general inclusive-scan utility instead of a grouped-MLP helper. Outputs are configured per-entry via parallel arrays of strides, include_leading_zero flags, and dtypes; bulk_allocate is opt-in so general callers get separate per-output at::empty buffers and only the grouped-MLP hot path takes the shared-storage / 16-byte-aligned route that cuDNN needs.
The wrapper now takes an explicit device, coerces split_sizes to int64 / CUDA centrally, drops the redundant per-tensor checks (the core lib validates), and drops non_blocking=true on the host->device migration to avoid the race Greptile flagged. Update the grouped MLP caller and tests to the new shape.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
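As a rough picture of the shared-storage, 16-byte-aligned route, the sketch below computes where each output would start inside one buffer. It is not the bulk_allocate implementation; `aligned_layout` is a hypothetical helper, and the example sizes assume three groups, one int32 output without a leading zero, and four int64 outputs with one.

```python
def aligned_layout(nbytes_per_output, alignment=16):
    """Return (byte offsets, total bytes) for outputs packed in one buffer."""
    offsets = []
    cursor = 0
    for nbytes in nbytes_per_output:
        # Round the cursor up to the next multiple of the alignment.
        cursor = (cursor + alignment - 1) // alignment * alignment
        offsets.append(cursor)
        cursor += nbytes
    return offsets, cursor

num_groups = 3
sizes = [num_groups * 4] + [(num_groups + 1) * 8] * 4  # bytes per output
offsets, total = aligned_layout(sizes)
```

Here the 12-byte int32 output is followed by 4 bytes of padding so that the first int64 output begins on a 16-byte boundary.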
Rename the internal create_grouped_tensor parameter from provided_tensor_offsets to precomputed_tensor_offsets in the quantizer subclass impls, and document the optional tensor_offsets contract on the abstract method declaration.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
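The optional-offsets contract can be modelled as follows: use the caller's precomputed offsets when present, otherwise build them from the split sizes. This is a sketch under assumptions, not the quantizer code; `resolve_tensor_offsets` and its `last_dim` parameter are hypothetical, and the fallback assumes offsets are the leading-zero scan of `split_sizes` scaled by the last dimension.

```python
from itertools import accumulate
from typing import List, Optional, Sequence

def resolve_tensor_offsets(
    split_sizes: Sequence[int],
    last_dim: int,
    precomputed: Optional[Sequence[int]] = None,
) -> List[int]:
    """Prefer precomputed offsets; otherwise compute them on the spot."""
    if precomputed is not None:
        return list(precomputed)  # skip the rebuild entirely
    return [0] + [last_dim * c for c in accumulate(split_sizes)]

built = resolve_tensor_offsets([2, 0, 3], last_dim=4)
reused = resolve_tensor_offsets([2, 0, 3], last_dim=4, precomputed=[0, 8, 8, 20])
```

Both paths yield the same offsets in this example, which is the property that lets the hot path skip the rebuild safely.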
for more information, see https://pre-commit.ci
timmoon10 left a comment
I've modified the C API function and tex function so that they are more general and not specific to the grouped MLP cuDNN kernels. I've also taken the opportunity to reuse the same kernel for the single-output and multi-output splits-to-offsets functions and to do other miscellaneous housecleaning.
/te-ci amd64 arm64
Description
Refactored #2991
Progress toward #2897