Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion NAM/wavenet/a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ A2FastModel<Channels>::A2FastModel(std::vector<float> weights, double expected_s

_load_weights(weights);

int prewarm = 0;
// Receptive field = 1 (the sample being produced) + sum of per-layer lookbacks +
// (head kernel - 1). The leading 1 matches the generic WaveNet's prewarm count
// (model.cpp: mPrewarmSamples starts at 1 when there's no condition DSP), so the
// fast path warms up by exactly the same number of samples as the model it replaces.
int prewarm = 1;
for (int i = 0; i < kNumLayers; i++)
prewarm += _layers[i].max_lookback;
prewarm += kHeadKernelSize - 1;
Expand Down
2 changes: 2 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ int main()
test_a2_fast::test_detector_rejects_gating();
test_a2_fast::test_matches_generic_nano();
test_a2_fast::test_matches_generic_standard();
test_a2_fast::test_prewarm_matches_generic_nano();
test_a2_fast::test_prewarm_matches_generic_standard();
test_a2_fast::test_process_realtime_safe_nano();
test_a2_fast::test_process_realtime_safe_standard();
#endif
Expand Down
32 changes: 32 additions & 0 deletions tools/test/test_a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,38 @@ void test_matches_generic_standard()
test_matches_generic(8);
}

// The fast path must report the same prewarm count as the generic WaveNet it
// replaces; otherwise Reset() warms the two by a different number of samples and
// their first post-Reset output can diverge (regression guard for the A2 prewarm
// off-by-one). Builds both from the identical config via the same dual path as
// test_matches_generic.
void test_prewarm_matches_generic(int channels)
{
const auto cfg = build_a2_config(channels);
const int weight_count = a2_weight_count(channels);
const auto weights = make_deterministic_weights(weight_count, /*seed=*/0xA2FA500u + channels);

auto fast_cfg = nam::wavenet::a2_fast::create_a2_fast_config(cfg, 48000.0);
std::vector<float> w_fast = weights;
auto fast_dsp = fast_cfg->create(std::move(w_fast), 48000.0);

auto generic_cfg = nam::wavenet::parse_config_json(cfg, 48000.0);
std::vector<float> w_gen = weights;
auto generic_dsp = generic_cfg.create(std::move(w_gen), 48000.0);

assert(fast_dsp->GetPrewarmSamples() == generic_dsp->GetPrewarmSamples());
}

void test_prewarm_matches_generic_nano()
{
test_prewarm_matches_generic(3);
}

void test_prewarm_matches_generic_standard()
{
test_prewarm_matches_generic(8);
}

// Real-time safety: once the DSP has been Reset (buffers sized, prewarmed),
// subsequent process() calls must not allocate or free heap memory. Uses the
// same allocation-tracking infrastructure as the generic WaveNet RT-safety
Expand Down
Loading