Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/TiledArray/tensor/arena_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "TiledArray/error.h"
#include "TiledArray/math/blas.h"
#include "TiledArray/math/gemm_helper.h"
#include "TiledArray/tensor/complex.h"
#include "TiledArray/tensor/type_traits.h"

#include <btas/zb/range.h>
Expand Down Expand Up @@ -250,6 +251,13 @@ class ArenaTensor {
::TiledArray::scale_to(*this, -T(1));
return *this;
}
/// In-place complex conjugation of this tensor.
///
/// Routes through the free `scale_to` kernel with a ComplexConjugate factor
/// (the object returned by `detail::conj_op()`), which conjugates each
/// arena-backed scalar in place (a no-op for real `T`). Mirrors `neg_to()`.
/// \return A reference to this tensor
ArenaTensor& conj_to() {
::TiledArray::scale_to(*this, ::TiledArray::detail::conj_op());
return *this;
}

/// axpy: <tt>*this += other * factor</tt> (axpy semantics; factor scales
/// only the added operand). Delegates to the free `axpy` CPO that the
Expand Down
18 changes: 18 additions & 0 deletions src/TiledArray/tensor/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ TILEDARRAY_FORCE_INLINE std::complex<R> conj(const std::complex<R> z) {
return std::conj(z);
}

/// Conjugate a (possibly nested) tensor element

/// Enables `conj` of a tensor whose elements are themselves tensors
/// (tensor-of-tensors): `Tensor::conj()` is `scale(conj_op())`, which
/// multiplies each element by a ComplexConjugate operator and thus dispatches
/// back here for each inner tensor. This overload forwards to the element's own
/// `conj()`, recursing until the leaf scalar overloads above terminate it.
/// SFINAE'd on a non-numeric type that provides a `conj()` member so it never
/// competes with the scalar overloads.
/// \tparam T A (nested) tensor type
/// \param t The tensor to conjugate
/// \return The elementwise complex conjugate of `t`
template <typename T,
typename std::enable_if<!is_numeric_v<T>>::type* = nullptr>
TILEDARRAY_FORCE_INLINE auto conj(const T& t) -> decltype(t.conj()) {
return t.conj();
}

/// Inner product of a real value and a numeric value

/// \tparam L A real scalar type
Expand Down
122 changes: 86 additions & 36 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@ inline bool scale_gemm_timing_enabled() {

/// Counters for one scale regime. `{0}` member-init gives well-defined zero.
struct ScaleRegimeCounters {
std::atomic<std::uint64_t> gemm_ns{0}; // wall ns inside the strided gemm
std::atomic<std::uint64_t> fb_ns{0}; // wall ns inside the AXPY fallback
std::atomic<std::uint64_t> gemm_runs{0}; // clean rows/cols (one strided GEMM)
std::atomic<std::uint64_t> fb_runs{0}; // rows/cols that fell back to AXPY
std::atomic<std::uint64_t> gemm_flop{0}; // 2*K*N*A (clean), summed
std::atomic<std::uint64_t> fb_flop{0}; // exact 2*K*Sum(cellsize) (fallback)
std::atomic<std::uint64_t> fb_absent{0}; // fallback reason: an empty cell
std::atomic<std::uint64_t> fb_ragged{0}; // fallback reason: ragged inner size
std::atomic<std::uint64_t> fb_stride{0}; // fallback reason: multi-page stride
std::atomic<std::uint64_t> gemm_ns{0}; // wall ns inside the strided gemm
std::atomic<std::uint64_t> fb_ns{0}; // wall ns inside the AXPY fallback
std::atomic<std::uint64_t> gemm_runs{
0}; // clean rows/cols (one strided GEMM)
std::atomic<std::uint64_t> fb_runs{0}; // rows/cols that fell back to AXPY
std::atomic<std::uint64_t> gemm_flop{0}; // 2*K*N*A (clean), summed
std::atomic<std::uint64_t> fb_flop{0}; // exact 2*K*Sum(cellsize) (fallback)
std::atomic<std::uint64_t> fb_absent{0}; // fallback reason: an empty cell
std::atomic<std::uint64_t> fb_ragged{
0}; // fallback reason: ragged inner size
std::atomic<std::uint64_t> fb_stride{
0}; // fallback reason: multi-page stride
// --- phase breakdown of the per-(b,m) loop (Amdahl of the 75% overhead) ---
std::atomic<std::uint64_t> kernel_ns{0}; // whole for-b/for-m loop body
std::atomic<std::uint64_t> check_pres_ns{0};// per-row presence + size scan
std::atomic<std::uint64_t> check_str_ns{0}; // per-row constant-stride walk
std::atomic<std::uint64_t> kernel_ns{0}; // whole for-b/for-m loop body
std::atomic<std::uint64_t> check_pres_ns{0}; // per-row presence + size scan
std::atomic<std::uint64_t> check_str_ns{0}; // per-row constant-stride walk
// beta-eligibility: how many Tensor::gemm CALLS land on a freshly-allocated
// (this->empty()) output tile -- where beta=0 would be valid -- vs an
// accumulation into an existing tile (beta=1 required for correctness).
Expand Down Expand Up @@ -189,7 +192,10 @@ struct ScaleGemmTimingDumper {
const auto gns = L(g_scale[r].gemm_ns), fns = L(g_scale[r].fb_ns);
const auto gr = L(g_scale[r].gemm_runs), fr = L(g_scale[r].fb_runs);
const auto gf = L(g_scale[r].gemm_flop), ff = L(g_scale[r].fb_flop);
tg_ns += gns; tf_ns += fns; tg_fl += gf; tf_fl += ff;
tg_ns += gns;
tf_ns += fns;
tg_fl += gf;
tf_fl += ff;
const double tt = static_cast<double>(gns + fns);
const double ftot = static_cast<double>(gf + ff);
std::cerr << "[scale-timing] " << names[r] << ":\n";
Expand All @@ -200,16 +206,18 @@ struct ScaleGemmTimingDumper {
std::cerr << "[scale-timing] time coverage (GEMM / total) : "
<< (tt > 0 ? 100.0 * gns / tt : 0.0) << "%\n";
std::cerr << "[scale-timing] FLOP coverage (GEMM / total) : "
<< (ftot > 0 ? 100.0 * gf / ftot : 0.0) << "% ("
<< gf / 1e9 << " GFLOP gemm / " << ftot / 1e9 << " GFLOP)\n";
<< (ftot > 0 ? 100.0 * gf / ftot : 0.0) << "% (" << gf / 1e9
<< " GFLOP gemm / " << ftot / 1e9 << " GFLOP)\n";
std::cerr << "[scale-timing] GFLOP/s strided="
<< (gns > 0 ? gf / static_cast<double>(gns) : 0.0)
<< " fallback="
<< (fns > 0 ? ff / static_cast<double>(fns) : 0.0) << "\n";
std::cerr << "[scale-timing] fallback runs by reason: absent="
<< L(g_scale[r].fb_absent) << " ragged=" << L(g_scale[r].fb_ragged)
<< L(g_scale[r].fb_absent)
<< " ragged=" << L(g_scale[r].fb_ragged)
<< " multipage-stride=" << L(g_scale[r].fb_stride) << "\n";
// Phase breakdown of the per-(b,m) loop = where the non-GEMM overhead goes.
// Phase breakdown of the per-(b,m) loop = where the non-GEMM overhead
// goes.
const auto kn = L(g_scale[r].kernel_ns);
const auto cp = L(g_scale[r].check_pres_ns);
const auto cs = L(g_scale[r].check_str_ns);
Expand Down Expand Up @@ -2009,7 +2017,14 @@ class Tensor {
// early exit for empty this
if (empty()) return {};

if constexpr (is_tensor_view_v<value_type>) {
if constexpr (is_arena_tensor_v<value_type>) {
// Arena inner cells: scale via the arena kernel (which manages the slab),
// then apply the result permutation if non-trivial. Mirrors the arena
// add(right, perm) overload above. ArenaTensor is also a view, so this
// branch must precede the view branch below.
auto result = scale(factor);
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
} else if constexpr (is_tensor_view_v<value_type>) {
TA_EXCEPTION(
"Tensor<View>::scale(factor, perm): permutation is not "
"supported for view inner cells");
Expand All @@ -2035,8 +2050,18 @@ class Tensor {
// early exit for empty this
if (empty()) return *this;

return inplace_unary(
[factor](value_type& MADNESS_RESTRICT res) { res *= factor; });
if constexpr (is_arena_tensor_v<value_type>) {
// Arena inner cells: route through each cell's own in-place scale_to (the
// free arena kernel), which handles a ComplexConjugate factor by
// conjugating each arena scalar in place. Going through `cell *= factor`
// would instead select the generic operator*=(.., ComplexConjugate) ->
// detail::conj(cell), which has no value-returning conj for ArenaTensor.
return inplace_unary(
[factor](value_type& MADNESS_RESTRICT c) { c.scale_to(factor); });
} else {
return inplace_unary(
[factor](value_type& MADNESS_RESTRICT res) { res *= factor; });
}
}

// Addition operations
Expand Down Expand Up @@ -3382,16 +3407,28 @@ class Tensor {
long a0 = -1;
for (integer k = 0; k != K; ++k) {
const auto& c = lc0[k];
if (c.empty()) { absent = true; break; }
if (c.empty()) {
absent = true;
break;
}
long s = static_cast<long>(c.size());
if (a0 < 0) a0 = s; else if (a0 != s) ragged = true;
if (a0 < 0)
a0 = s;
else if (a0 != s)
ragged = true;
}
if (!absent)
for (integer n = 0; n != N; ++n) {
const auto& c = rc0[n];
if (c.empty()) { absent = true; break; }
if (c.empty()) {
absent = true;
break;
}
long s = static_cast<long>(c.size());
if (a0 < 0) a0 = s; else if (a0 != s) ragged = true;
if (a0 < 0)
a0 = s;
else if (a0 != s)
ragged = true;
}
std::uint64_t fl = 0;
for (integer n = 0; n != N; ++n)
Expand All @@ -3401,9 +3438,9 @@ class Tensor {
1, std::memory_order_relaxed);
detail::g_scale[0].fb_flop.fetch_add(
fl, std::memory_order_relaxed);
(absent ? detail::g_scale[0].fb_absent
: ragged ? detail::g_scale[0].fb_ragged
: detail::g_scale[0].fb_stride)
(absent ? detail::g_scale[0].fb_absent
: ragged ? detail::g_scale[0].fb_ragged
: detail::g_scale[0].fb_stride)
.fetch_add(1, std::memory_order_relaxed);
}
detail::ScopedScaleTimer _scale_fb(detail::g_scale[0].fb_ns);
Expand Down Expand Up @@ -3525,28 +3562,41 @@ class Tensor {
long a0 = -1;
for (integer k = 0; k != K; ++k) {
const auto& c = right_data[k * N + n];
if (c.empty()) { absent = true; break; }
if (c.empty()) {
absent = true;
break;
}
long s = static_cast<long>(c.size());
if (a0 < 0) a0 = s; else if (a0 != s) ragged = true;
if (a0 < 0)
a0 = s;
else if (a0 != s)
ragged = true;
}
if (!absent)
for (integer m = 0; m != M; ++m) {
const auto& c = this_data[m * N + n];
if (c.empty()) { absent = true; break; }
if (c.empty()) {
absent = true;
break;
}
long s = static_cast<long>(c.size());
if (a0 < 0) a0 = s; else if (a0 != s) ragged = true;
if (a0 < 0)
a0 = s;
else if (a0 != s)
ragged = true;
}
std::uint64_t fl = 0;
for (integer m = 0; m != M; ++m)
fl += 2ull * static_cast<std::uint64_t>(K) *
static_cast<std::uint64_t>(this_data[m * N + n].size());
fl +=
2ull * static_cast<std::uint64_t>(K) *
static_cast<std::uint64_t>(this_data[m * N + n].size());
detail::g_scale[1].fb_runs.fetch_add(
1, std::memory_order_relaxed);
detail::g_scale[1].fb_flop.fetch_add(
fl, std::memory_order_relaxed);
(absent ? detail::g_scale[1].fb_absent
: ragged ? detail::g_scale[1].fb_ragged
: detail::g_scale[1].fb_stride)
(absent ? detail::g_scale[1].fb_absent
: ragged ? detail::g_scale[1].fb_ragged
: detail::g_scale[1].fb_stride)
.fetch_add(1, std::memory_order_relaxed);
}
detail::ScopedScaleTimer _scale_fb(detail::g_scale[1].fb_ns);
Expand Down
25 changes: 25 additions & 0 deletions tests/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,31 @@ BOOST_AUTO_TEST_CASE(inplace_conj_scal_op) {
}
}

BOOST_AUTO_TEST_CASE(conj_op_tensor_of_tensor) {
  // Conjugating a tensor whose elements are tensors has to reach every scalar
  // of every inner tensor.
  // Regression guard: Tensor<Tensor<complex>>::conj() (== scale(conj_op()))
  // used to fail to compile because detail::conj lacked an overload for a
  // tensor-valued element.
  using NestedZ = Tensor<TensorZ>;

  // Build the nested source: one randomly filled inner tensor per outer slot,
  // each seeded differently.
  NestedZ source(r);
  const std::size_t n_outer = source.size();
  for (std::size_t o = 0ul; o != n_outer; ++o) {
    TensorZ cell(r);
    rand_fill(static_cast<int>(431 + o), cell.size(), cell.data());
    source[o] = cell;
  }

  NestedZ conjugated;
  BOOST_REQUIRE_NO_THROW(conjugated = source.conj());

  // Outer shape is preserved ...
  BOOST_CHECK_EQUAL(conjugated.range(), source.range());
  for (std::size_t o = 0ul; o != conjugated.size(); ++o) {
    const auto& got = conjugated[o];
    const auto& ref = source[o];
    // ... and so is every inner shape; each scalar keeps its real part and
    // flips the sign of its imaginary part.
    BOOST_CHECK_EQUAL(got.range(), ref.range());
    for (std::size_t e = 0ul; e != got.size(); ++e) {
      BOOST_CHECK_EQUAL(got[e].real(), ref[e].real());
      BOOST_CHECK_EQUAL(got[e].imag(), -ref[e].imag());
    }
  }
}

BOOST_AUTO_TEST_CASE(block) {
TensorZ s(r);
auto lobound = r.lobound();
Expand Down
61 changes: 61 additions & 0 deletions tests/tot_construction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "global_fixture.h"
#include "unit_test_config.h"

#include <complex>
#include <cstddef>
#include <vector>

Expand Down Expand Up @@ -574,6 +575,59 @@ void test_arena_tile_permute() {
}
}

/// conj()/conj(perm)/conj_to() on a tensor-of-tensors with complex inner cells.
/// Exercises the arena-aware scale(factor) and scale(factor, perm) paths (the
/// latter previously threw for view/arena inner) and in-place conj_to. Real
/// conj would be a no-op, so inner cells carry a nonzero imaginary part.
template <typename InnerTile>
void test_conj_tot() {
using OuterTile = TA::Tensor<InnerTile>;
using cd = typename InnerTile::value_type; // std::complex<double>

// rank-2 outer so a non-identity outer permutation is meaningful
const TA::Range outer{2, 2};
auto rng = [](const auto&) { return typename InnerTile::range_type{3}; };
auto fill = [](auto& cell, const auto& idx) {
const long e = static_cast<long>(idx[0]) * 10 + static_cast<long>(idx[1]);
for (std::size_t i = 0; i < cell.size(); ++i)
cell.data()[i] = cd(100.0 * e + i, 1.0 + e - static_cast<double>(i));
};
OuterTile src = TA::detail::make_nested_tile<OuterTile>(outer, rng, fill);

auto inner_is_conj = [](const InnerTile& got, const InnerTile& s) {
BOOST_REQUIRE_EQUAL(got.size(), s.size());
for (std::size_t i = 0; i < got.size(); ++i) {
BOOST_CHECK_EQUAL(got.data()[i].real(), s.data()[i].real());
BOOST_CHECK_EQUAL(got.data()[i].imag(), -s.data()[i].imag());
}
};

// out-of-place conj() (scale(conj_op) -> arena_trivial_unary)
OuterTile c = src.conj();
for (std::size_t o = 0; o < src.range().volume(); ++o)
inner_is_conj(c.data()[o], src.data()[o]);

// permuted conj() (scale(conj_op, perm) -- the arena branch under test);
// must agree with conj-then-permute.
TA::Permutation perm({1, 0}); // swap the two outer modes
OuterTile cp = src.conj(perm);
OuterTile cp_ref = src.conj().permute(perm);
BOOST_REQUIRE_EQUAL(cp.range(), cp_ref.range());
for (std::size_t o = 0; o < cp.range().volume(); ++o) {
const InnerTile& a = cp.data()[o];
const InnerTile& b = cp_ref.data()[o];
BOOST_REQUIRE_EQUAL(a.size(), b.size());
for (std::size_t i = 0; i < a.size(); ++i)
BOOST_CHECK_EQUAL(a.data()[i], b.data()[i]);
}

// in-place conj_to() (scale_to(conj_op) -> free arena operator*=)
OuterTile t2 = TA::detail::make_nested_tile<OuterTile>(outer, rng, fill);
t2.conj_to();
for (std::size_t o = 0; o < src.range().volume(); ++o)
inner_is_conj(t2.data()[o], src.data()[o]);
}

} // namespace

BOOST_AUTO_TEST_SUITE(tot_construction_suite, TA_UT_LABEL_SERIAL)
Expand Down Expand Up @@ -673,6 +727,13 @@ BOOST_AUTO_TEST_CASE(neg_arena_inner) {
test_tot_neg<TA::ArenaTensor<double>, TA::DensePolicy>();
}

// conj()/conj(perm)/conj_to() on a tensor-of-tensors whose inner cells are
// TA::Tensor<std::complex<double>>.
BOOST_AUTO_TEST_CASE(conj_tot_tensor_inner) {
test_conj_tot<TA::Tensor<std::complex<double>>>();
}
// Same checks as the Tensor-inner case, with inner cells of type
// TA::ArenaTensor<std::complex<double>>.
BOOST_AUTO_TEST_CASE(conj_tot_arena_inner) {
test_conj_tot<TA::ArenaTensor<std::complex<double>>>();
}

// canonical inner contraction: c(ij;mn) = sum_k sum_o a(ijk;mo) b(ijk;on)
BOOST_AUTO_TEST_CASE(einsum_contraction_tensor_inner) {
test_tot_einsum_contraction<TA::Tensor<double>, TA::DensePolicy>(
Expand Down
Loading