strided BLAS DGEMM path for ToT einsum contractions#559
Open
zhihao-deng wants to merge 18 commits into
Open
Conversation
…A_STRIDED_DGEMM_COUNT option
…atch==1 assert) + tile tests
…ntEngine + e2e (nbatch>1)
…ContEngine + e2e (left-external Mo>1)
…left-external + nbatch>1) + benches
Route the regime-A hc+e einsum (outer Hadamard + outer contraction, inner outer-product) through the landed arena_strided_dgemm_ce_e core (M=N=1, K=tile volume) in run_regime_a_arena, replacing the per-cell rank-1 dger loop with one strided DGEMM per outer-contraction tile. Gated to view+double arena ToT contraction with num_contract_ranks()==0; all other kinds keep the per-cell path. Adds a regime_a_strided_disabled() kill switch, tile/e2e/differential/edge tests, and a strided-vs-per-cell benchmark (~7.3x on a C6H14-like shape).
… (either-side hce+ce)
… timing probe, bench & tests
… ranks The einsum_tot arena-matches-owning tests iterate over all result tile ordinals but only inspect tiles local to the calling rank, then assert the per-rank elements_compared / result_outer_cells_seen counts (and the fatal BOOST_REQUIRE_GT(elements_compared, 0u)) against the global expected totals. That holds under np=1 (all tiles local) but fails under np=2: each rank sees only its share, and a rank owning no result tiles trips the REQUIRE_GT. All-reduce the accumulators (gop.sum on the counts, gop.max on max_abs_diff) before the assertions so every rank checks the true global totals. Fixes the 14 np=2 einsum_tot failures.
Tensor::conj() is scale(conj_op()), which multiplies each element by a ComplexConjugate operator and thus calls detail::conj() on each element. For a tensor-of-tensors the element is itself a TA::Tensor, and detail::conj() only had scalar (real/std::complex) overloads, so conj() of a Tensor<Tensor<...>> (and DistArray<Tensor<Tensor<...>>>::operator()(...).conj()) failed to compile. Add a detail::conj() overload for non-numeric types that forwards to the element's own conj(), recursing until the scalar overloads terminate it. SFINAE'd on a non-numeric type with a conj() member so it never competes with the scalar overloads. Add a Tensor<Tensor<complex>> conj test.
…renaTensor::conj_to)
The complex-ToT conj recursion (prior commit) handled the value-returning path,
but the out-of-place permuted path threw for arena/view inner: Tensor::scale(
factor, perm) had only a view-TA_EXCEPTION branch and a value-based unary branch.
The DistArray .conj() expression lowers to scale(factor, perm) (and scale_to),
so adjoint of a complex ArenaTensor-backed tensor-of-tensors hit that throw.
- scale(factor, perm): add an arena branch mirroring add(right, perm) — scale via
the arena kernel (manages the slab), then permute the result if non-trivial
(arena_perm_is_trivial). Precedes the view branch since ArenaTensor is a view.
- ArenaTensor::conj_to(): in-place conjugation via the free scale_to kernel with a
ComplexConjugate factor (no-op for real T); mirrors neg_to(). Include complex.h.
- tests/tot_construction: conj_tot_{tensor,arena}_inner exercise conj(),
conj(perm), and conj_to() on complex tensor-of-tensors for both inner kinds.
scale_to needed no arena branch: res *= conj_op routes through the free
operator*= -> scale_to kernel, conjugating arena scalars in place.
…nsor-conj tensor: recurse conj() into nested tiles (tensor-of-tensors)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Lift per-cell ToT (ArenaTensor) einsum work to BLAS-3 GEMM wherever possible,
instead of looping per-cell ops.
Recast the following ArenaTensor einsum cases as strided GEMM:
ce+e, inner outer-product): ride the outer-contractionindex into BLAS K
ce+ce, inner contraction — guarded subset, not the generalcase): ride the outer-external index into BLAS M.
Everything outside these guarded regimes keeps the existing per-cell path, so
behavior is unchanged elsewhere.
Guards
A strided GEMM fires only when the cell run is "clean": all cells present,
uniform inner size, and a single constant inter-cell stride. Empty inners punch
holes that break contiguity, so we fall to segmented kernels: walk each run
and emit one strided GEMM per maximal contiguous segment of present cells,
skipping the holes (accumulating with β=1 across segments).
Notes
Still carries env-gated diagnostics (
TA_GEMM_TIMING,TA_STRIDED_DGEMM_VERBOSE,and the
TA_STRIDED_DGEMM_COUNTbuild counters)