diff --git a/src/TiledArray/tensor/arena_tensor.h b/src/TiledArray/tensor/arena_tensor.h index 473c342125..783e8d5956 100644 --- a/src/TiledArray/tensor/arena_tensor.h +++ b/src/TiledArray/tensor/arena_tensor.h @@ -12,6 +12,7 @@ #include "TiledArray/error.h" #include "TiledArray/math/blas.h" #include "TiledArray/math/gemm_helper.h" +#include "TiledArray/tensor/complex.h" #include "TiledArray/tensor/type_traits.h" #include @@ -250,6 +251,13 @@ class ArenaTensor { ::TiledArray::scale_to(*this, -T(1)); return *this; } + /// In-place complex conjugation. Routes through the free `scale_to` kernel + /// with a ComplexConjugate factor, which conjugates each arena-backed scalar + /// in place (a no-op for real `T`). Mirrors `neg_to()`. + ArenaTensor& conj_to() { + ::TiledArray::scale_to(*this, ::TiledArray::detail::conj_op()); + return *this; + } /// axpy: *this += other * factor (axpy semantics; factor scales /// only the added operand). Delegates to the free `axpy` CPO that the diff --git a/src/TiledArray/tensor/complex.h b/src/TiledArray/tensor/complex.h index a7a25787bb..fa60a2c39b 100644 --- a/src/TiledArray/tensor/complex.h +++ b/src/TiledArray/tensor/complex.h @@ -57,6 +57,24 @@ TILEDARRAY_FORCE_INLINE std::complex conj(const std::complex z) { return std::conj(z); } +/// Conjugate a (possibly nested) tensor element + +/// Enables `conj` of a tensor whose elements are themselves tensors +/// (tensor-of-tensors): `Tensor::conj()` is `scale(conj_op())`, which +/// multiplies each element by a ComplexConjugate operator and thus dispatches +/// back here for each inner tensor. This overload forwards to the element's own +/// `conj()`, recursing until the leaf scalar overloads above terminate it. +/// SFINAE'd on a non-numeric type that provides a `conj()` member so it never +/// competes with the scalar overloads. +/// \tparam T A (nested) tensor type +/// \param t The tensor to conjugate +/// \return The elementwise complex conjugate of `t` +template >::type* = nullptr> +TILEDARRAY_FORCE_INLINE auto conj(const T& t) -> decltype(t.conj()) { + return t.conj(); +} + /// Inner product of a real value and a numeric value /// \tparam L A real scalar type diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index d7d28220f3..c623eaccb1 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -111,19 +111,22 @@ inline bool scale_gemm_timing_enabled() { /// Counters for one scale regime. `{0}` member-init gives well-defined zero. struct ScaleRegimeCounters { - std::atomic gemm_ns{0}; // wall ns inside the strided gemm - std::atomic fb_ns{0}; // wall ns inside the AXPY fallback - std::atomic gemm_runs{0}; // clean rows/cols (one strided GEMM) - std::atomic fb_runs{0}; // rows/cols that fell back to AXPY - std::atomic gemm_flop{0}; // 2*K*N*A (clean), summed - std::atomic fb_flop{0}; // exact 2*K*Sum(cellsize) (fallback) - std::atomic fb_absent{0}; // fallback reason: an empty cell - std::atomic fb_ragged{0}; // fallback reason: ragged inner size - std::atomic fb_stride{0}; // fallback reason: multi-page stride + std::atomic gemm_ns{0}; // wall ns inside the strided gemm + std::atomic fb_ns{0}; // wall ns inside the AXPY fallback + std::atomic gemm_runs{ + 0}; // clean rows/cols (one strided GEMM) + std::atomic fb_runs{0}; // rows/cols that fell back to AXPY + std::atomic gemm_flop{0}; // 2*K*N*A (clean), summed + std::atomic fb_flop{0}; // exact 2*K*Sum(cellsize) (fallback) + std::atomic fb_absent{0}; // fallback reason: an empty cell + std::atomic fb_ragged{ + 0}; // fallback reason: ragged inner size + std::atomic fb_stride{ + 0}; // fallback reason: multi-page stride // --- phase breakdown of the per-(b,m) loop (Amdahl of the 75% overhead) --- - std::atomic kernel_ns{0}; // whole for-b/for-m loop body - std::atomic check_pres_ns{0};// per-row presence + size scan - std::atomic check_str_ns{0}; // per-row constant-stride walk + std::atomic kernel_ns{0}; // whole for-b/for-m loop body + std::atomic check_pres_ns{0}; // per-row presence + size scan + std::atomic check_str_ns{0}; // per-row constant-stride walk // beta-eligibility: how many Tensor::gemm CALLS land on a freshly-allocated // (this->empty()) output tile -- where beta=0 would be valid -- vs an // accumulation into an existing tile (beta=1 required for correctness). @@ -189,7 +192,10 @@ struct ScaleGemmTimingDumper { const auto gns = L(g_scale[r].gemm_ns), fns = L(g_scale[r].fb_ns); const auto gr = L(g_scale[r].gemm_runs), fr = L(g_scale[r].fb_runs); const auto gf = L(g_scale[r].gemm_flop), ff = L(g_scale[r].fb_flop); - tg_ns += gns; tf_ns += fns; tg_fl += gf; tf_fl += ff; + tg_ns += gns; + tf_ns += fns; + tg_fl += gf; + tf_fl += ff; const double tt = static_cast(gns + fns); const double ftot = static_cast(gf + ff); std::cerr << "[scale-timing] " << names[r] << ":\n"; @@ -200,16 +206,18 @@ struct ScaleGemmTimingDumper { std::cerr << "[scale-timing] time coverage (GEMM / total) : " << (tt > 0 ? 100.0 * gns / tt : 0.0) << "%\n"; std::cerr << "[scale-timing] FLOP coverage (GEMM / total) : " - << (ftot > 0 ? 100.0 * gf / ftot : 0.0) << "% (" - << gf / 1e9 << " GFLOP gemm / " << ftot / 1e9 << " GFLOP)\n"; + << (ftot > 0 ? 100.0 * gf / ftot : 0.0) << "% (" << gf / 1e9 + << " GFLOP gemm / " << ftot / 1e9 << " GFLOP)\n"; std::cerr << "[scale-timing] GFLOP/s strided=" << (gns > 0 ? gf / static_cast(gns) : 0.0) << " fallback=" << (fns > 0 ? ff / static_cast(fns) : 0.0) << "\n"; std::cerr << "[scale-timing] fallback runs by reason: absent=" - << L(g_scale[r].fb_absent) << " ragged=" << L(g_scale[r].fb_ragged) + << L(g_scale[r].fb_absent) + << " ragged=" << L(g_scale[r].fb_ragged) << " multipage-stride=" << L(g_scale[r].fb_stride) << "\n"; - // Phase breakdown of the per-(b,m) loop = where the non-GEMM overhead goes. + // Phase breakdown of the per-(b,m) loop = where the non-GEMM overhead + // goes. const auto kn = L(g_scale[r].kernel_ns); const auto cp = L(g_scale[r].check_pres_ns); const auto cs = L(g_scale[r].check_str_ns); @@ -2009,7 +2017,14 @@ class Tensor { // early exit for empty this if (empty()) return {}; - if constexpr (is_tensor_view_v) { + if constexpr (is_arena_tensor_v) { + // Arena inner cells: scale via the arena kernel (which manages the slab), + // then apply the result permutation if non-trivial. Mirrors the arena + // add(right, perm) overload above. ArenaTensor is also a view, so this + // branch must precede the view branch below. + auto result = scale(factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); + } else if constexpr (is_tensor_view_v) { TA_EXCEPTION( "Tensor::scale(factor, perm): permutation is not " "supported for view inner cells"); @@ -2035,8 +2050,18 @@ class Tensor { // early exit for empty this if (empty()) return *this; - return inplace_unary( - [factor](value_type& MADNESS_RESTRICT res) { res *= factor; }); + if constexpr (is_arena_tensor_v) { + // Arena inner cells: route through each cell's own in-place scale_to (the + // free arena kernel), which handles a ComplexConjugate factor by + // conjugating each arena scalar in place. Going through `cell *= factor` + // would instead select the generic operator*=(.., ComplexConjugate) -> + // detail::conj(cell), which has no value-returning conj for ArenaTensor. + return inplace_unary( + [factor](value_type& MADNESS_RESTRICT c) { c.scale_to(factor); }); + } else { + return inplace_unary( + [factor](value_type& MADNESS_RESTRICT res) { res *= factor; }); + } } // Addition operations @@ -3382,16 +3407,28 @@ class Tensor { long a0 = -1; for (integer k = 0; k != K; ++k) { const auto& c = lc0[k]; - if (c.empty()) { absent = true; break; } + if (c.empty()) { + absent = true; + break; + } long s = static_cast(c.size()); - if (a0 < 0) a0 = s; else if (a0 != s) ragged = true; + if (a0 < 0) + a0 = s; + else if (a0 != s) + ragged = true; } if (!absent) for (integer n = 0; n != N; ++n) { const auto& c = rc0[n]; - if (c.empty()) { absent = true; break; } + if (c.empty()) { + absent = true; + break; + } long s = static_cast(c.size()); - if (a0 < 0) a0 = s; else if (a0 != s) ragged = true; + if (a0 < 0) + a0 = s; + else if (a0 != s) + ragged = true; } std::uint64_t fl = 0; for (integer n = 0; n != N; ++n) @@ -3401,9 +3438,9 @@ class Tensor { 1, std::memory_order_relaxed); detail::g_scale[0].fb_flop.fetch_add( fl, std::memory_order_relaxed); - (absent ? detail::g_scale[0].fb_absent - : ragged ? detail::g_scale[0].fb_ragged - : detail::g_scale[0].fb_stride) + (absent ? detail::g_scale[0].fb_absent + : ragged ? detail::g_scale[0].fb_ragged + : detail::g_scale[0].fb_stride) .fetch_add(1, std::memory_order_relaxed); } detail::ScopedScaleTimer _scale_fb(detail::g_scale[0].fb_ns); @@ -3525,28 +3562,41 @@ class Tensor { long a0 = -1; for (integer k = 0; k != K; ++k) { const auto& c = right_data[k * N + n]; - if (c.empty()) { absent = true; break; } + if (c.empty()) { + absent = true; + break; + } long s = static_cast(c.size()); - if (a0 < 0) a0 = s; else if (a0 != s) ragged = true; + if (a0 < 0) + a0 = s; + else if (a0 != s) + ragged = true; } if (!absent) for (integer m = 0; m != M; ++m) { const auto& c = this_data[m * N + n]; - if (c.empty()) { absent = true; break; } + if (c.empty()) { + absent = true; + break; + } long s = static_cast(c.size()); - if (a0 < 0) a0 = s; else if (a0 != s) ragged = true; + if (a0 < 0) + a0 = s; + else if (a0 != s) + ragged = true; } std::uint64_t fl = 0; for (integer m = 0; m != M; ++m) - fl += 2ull * static_cast(K) * - static_cast(this_data[m * N + n].size()); + fl += + 2ull * static_cast(K) * + static_cast(this_data[m * N + n].size()); detail::g_scale[1].fb_runs.fetch_add( 1, std::memory_order_relaxed); detail::g_scale[1].fb_flop.fetch_add( fl, std::memory_order_relaxed); - (absent ? detail::g_scale[1].fb_absent - : ragged ? detail::g_scale[1].fb_ragged - : detail::g_scale[1].fb_stride) + (absent ? detail::g_scale[1].fb_absent + : ragged ? detail::g_scale[1].fb_ragged + : detail::g_scale[1].fb_stride) .fetch_add(1, std::memory_order_relaxed); } detail::ScopedScaleTimer _scale_fb(detail::g_scale[1].fb_ns); diff --git a/tests/tensor.cpp b/tests/tensor.cpp index 3e67f89469..206ebad985 100644 --- a/tests/tensor.cpp +++ b/tests/tensor.cpp @@ -714,6 +714,31 @@ BOOST_AUTO_TEST_CASE(inplace_conj_scal_op) { } } +BOOST_AUTO_TEST_CASE(conj_op_tensor_of_tensor) { + // conj() of a tensor-of-tensors must conjugate every inner element. + // Regression: detail::conj had no overload for a tensor-valued element, so + // Tensor>::conj() (== scale(conj_op())) failed to compile. + using TensorOfTensorZ = Tensor; + TensorOfTensorZ s(r); + for (std::size_t i = 0ul; i < s.size(); ++i) { + TensorZ inner(r); + rand_fill(static_cast(431 + i), inner.size(), inner.data()); + s[i] = inner; + } + + TensorOfTensorZ t; + BOOST_REQUIRE_NO_THROW(t = s.conj()); + + BOOST_CHECK_EQUAL(t.range(), s.range()); + for (std::size_t i = 0ul; i < t.size(); ++i) { + BOOST_CHECK_EQUAL(t[i].range(), s[i].range()); + for (std::size_t j = 0ul; j < t[i].size(); ++j) { + BOOST_CHECK_EQUAL(t[i][j].real(), s[i][j].real()); + BOOST_CHECK_EQUAL(t[i][j].imag(), -s[i][j].imag()); + } + } +} + BOOST_AUTO_TEST_CASE(block) { TensorZ s(r); auto lobound = r.lobound(); diff --git a/tests/tot_construction.cpp b/tests/tot_construction.cpp index bfa3a9ee75..3084f7e875 100644 --- a/tests/tot_construction.cpp +++ b/tests/tot_construction.cpp @@ -10,6 +10,7 @@ #include "global_fixture.h" #include "unit_test_config.h" +#include #include #include @@ -574,6 +575,59 @@ void test_arena_tile_permute() { } } +/// conj()/conj(perm)/conj_to() on a tensor-of-tensors with complex inner cells. +/// Exercises the arena-aware scale(factor) and scale(factor, perm) paths (the +/// latter previously threw for view/arena inner) and in-place conj_to. Real +/// conj would be a no-op, so inner cells carry a nonzero imaginary part. +template +void test_conj_tot() { + using OuterTile = TA::Tensor; + using cd = typename InnerTile::value_type; // std::complex + + // rank-2 outer so a non-identity outer permutation is meaningful + const TA::Range outer{2, 2}; + auto rng = [](const auto&) { return typename InnerTile::range_type{3}; }; + auto fill = [](auto& cell, const auto& idx) { + const long e = static_cast(idx[0]) * 10 + static_cast(idx[1]); + for (std::size_t i = 0; i < cell.size(); ++i) + cell.data()[i] = cd(100.0 * e + i, 1.0 + e - static_cast(i)); + }; + OuterTile src = TA::detail::make_nested_tile(outer, rng, fill); + + auto inner_is_conj = [](const InnerTile& got, const InnerTile& s) { + BOOST_REQUIRE_EQUAL(got.size(), s.size()); + for (std::size_t i = 0; i < got.size(); ++i) { + BOOST_CHECK_EQUAL(got.data()[i].real(), s.data()[i].real()); + BOOST_CHECK_EQUAL(got.data()[i].imag(), -s.data()[i].imag()); + } + }; + + // out-of-place conj() (scale(conj_op) -> arena_trivial_unary) + OuterTile c = src.conj(); + for (std::size_t o = 0; o < src.range().volume(); ++o) + inner_is_conj(c.data()[o], src.data()[o]); + + // permuted conj() (scale(conj_op, perm) -- the arena branch under test); + // must agree with conj-then-permute. + TA::Permutation perm({1, 0}); // swap the two outer modes + OuterTile cp = src.conj(perm); + OuterTile cp_ref = src.conj().permute(perm); + BOOST_REQUIRE_EQUAL(cp.range(), cp_ref.range()); + for (std::size_t o = 0; o < cp.range().volume(); ++o) { + const InnerTile& a = cp.data()[o]; + const InnerTile& b = cp_ref.data()[o]; + BOOST_REQUIRE_EQUAL(a.size(), b.size()); + for (std::size_t i = 0; i < a.size(); ++i) + BOOST_CHECK_EQUAL(a.data()[i], b.data()[i]); + } + + // in-place conj_to() (scale_to(conj_op) -> free arena operator*=) + OuterTile t2 = TA::detail::make_nested_tile(outer, rng, fill); + t2.conj_to(); + for (std::size_t o = 0; o < src.range().volume(); ++o) + inner_is_conj(t2.data()[o], src.data()[o]); +} + } // namespace BOOST_AUTO_TEST_SUITE(tot_construction_suite, TA_UT_LABEL_SERIAL) @@ -673,6 +727,13 @@ BOOST_AUTO_TEST_CASE(neg_arena_inner) { test_tot_neg, TA::DensePolicy>(); } +BOOST_AUTO_TEST_CASE(conj_tot_tensor_inner) { + test_conj_tot>>(); +} +BOOST_AUTO_TEST_CASE(conj_tot_arena_inner) { + test_conj_tot>>(); +} + // canonical inner contraction: c(ij;mn) = sum_k sum_o a(ijk;mo) b(ijk;on) BOOST_AUTO_TEST_CASE(einsum_contraction_tensor_inner) { test_tot_einsum_contraction, TA::DensePolicy>(