[JAX] Grouped quant+GEMM custom partitioning rules #3058
jberchtold-nvidia wants to merge 12 commits into
…mm-custom-partition-rules
Greptile Summary
This PR adds JAX custom partitioning rules (
Confidence Score: 3/5
The core partitioning logic is new and complex; the backward FSDP path in dense.py has a silent EP-axis overwrite that could produce an incorrect sharding constraint in an edge case, and the 1-D size heuristic in gemm.py assumes a specific sharding orientation without enforcing it. The EP-axis overwrite in
Important Files Changed
Sequence Diagram
sequenceDiagram
participant JAX as JAX Compiler
participant QP as GroupedQuantizePrimitive.partition
participant GP as GroupedGemmPrimitive.partition
participant SI as sharded_impl (per device)
JAX->>QP: (x, scale, group_sizes) with EP+FSDP specs
QP->>QP: _parse_partition_specs() derive flat/group output specs
QP-->>JAX: arg_shardings, out_shardings, sharded_impl
JAX->>SI: local x shard
SI->>SI: GroupedQuantizePrimitive.impl()
SI->>SI: _pad_or_slice_to_shape(scale_inv, local_shape)
SI->>SI: pmax amax over dp/fsdp
SI-->>JAX: rowwise_out, colwise_out, scale_invs, amax
JAX->>GP: lhs, rhs_weight with EP+FSDP specs
GP->>GP: inject EP into rhs dim-0 if missing
GP->>GP: strip FSDP from rhs gather_rhs_fsdp
GP-->>JAX: arg_shardings, out_sharding, sharded_impl
JAX->>SI: local lhs/rhs shards
SI->>SI: GroupedGemmPrimitive.impl() local GEMM
SI->>SI: psum(out, reduce_axis) if needed
SI-->>JAX: out
Note over JAX,SI: Backward pass inside shard_map only
JAX->>SI: _grouped_dense_bwd_rule
SI->>SI: _is_manual_mesh_axis check
alt FSDP axis is manual
SI->>SI: allgather or psum dgrad
SI->>SI: psum_scatter or psum wgrad
end
SI->>SI: with_sharding_constraint wgrad EP+FSDP spec
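
The two spec rewrites shown for `GroupedGemmPrimitive.partition` in the diagram ("inject EP into rhs dim-0 if missing" and "strip FSDP from rhs") can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: partition specs are modeled as plain tuples rather than `jax.sharding.PartitionSpec`, the axis names `"ep"` and `"fsdp"` are placeholders, the helper names are hypothetical, and "if missing" is read as "no dimension of the spec already uses the EP axis".

```python
# Illustrative sketch only. A spec is a tuple with one entry per array
# dimension; each entry is None, a mesh-axis name, or a tuple of names.
# EP_AXIS, FSDP_AXIS and every helper below are hypothetical names.

EP_AXIS = "ep"
FSDP_AXIS = "fsdp"


def _axes(entry):
    """Normalize one spec entry to a tuple of mesh-axis names."""
    if entry is None:
        return ()
    if isinstance(entry, str):
        return (entry,)
    return tuple(entry)


def _entry(axes):
    """Inverse of _axes: collapse a tuple of axis names to a spec entry."""
    if not axes:
        return None
    return axes[0] if len(axes) == 1 else tuple(axes)


def inject_ep_into_dim0(spec, ep_axis=EP_AXIS):
    """Shard dim 0 (the group dimension) over ep_axis unless the spec
    already uses ep_axis on some dimension."""
    if any(ep_axis in _axes(e) for e in spec):
        return tuple(spec)
    return (_entry((ep_axis,) + _axes(spec[0])),) + tuple(spec[1:])


def strip_axis(spec, axis=FSDP_AXIS):
    """Remove axis from every dimension, so the operand ends up
    replicated (gathered) over that mesh axis."""
    return tuple(_entry(tuple(a for a in _axes(e) if a != axis)) for e in spec)


# Example: a weight spec sharded only over FSDP on its second dimension.
rhs_spec = (None, "fsdp", None)
rhs_spec = inject_ep_into_dim0(rhs_spec)  # EP is added on the group dimension
rhs_spec = strip_axis(rhs_spec)           # FSDP is dropped everywhere
```

Keeping the rewrite as a pure function over the spec makes the edge case flagged in the summary easier to reason about: whether an existing EP placement is preserved or overwritten is decided in one place and can be unit-tested without a device mesh.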
Force-pushed from 027b3e6 to ff0407d
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Force-pushed from 70893b7 to 1bd6b54
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Force-pushed from 59ff8e0 to 3c30c9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci L1 jax
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: