diff --git a/NAM/wavenet/a2_fast.cpp b/NAM/wavenet/a2_fast.cpp index 67093c88..9b89c735 100644 --- a/NAM/wavenet/a2_fast.cpp +++ b/NAM/wavenet/a2_fast.cpp @@ -172,7 +172,11 @@ A2FastModel::A2FastModel(std::vector weights, double expected_s _load_weights(weights); - int prewarm = 0; + // Receptive field = 1 (the sample being produced) + sum of per-layer lookbacks + + // (head kernel - 1). The leading 1 matches the generic WaveNet's prewarm count + // (model.cpp: mPrewarmSamples starts at 1 when there's no condition DSP), so the + // fast path warms up by exactly the same number of samples as the model it replaces. + int prewarm = 1; for (int i = 0; i < kNumLayers; i++) prewarm += _layers[i].max_lookback; prewarm += kHeadKernelSize - 1; diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 61c7ba4a..72839e1e 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -353,6 +353,8 @@ int main() test_a2_fast::test_detector_rejects_gating(); test_a2_fast::test_matches_generic_nano(); test_a2_fast::test_matches_generic_standard(); + test_a2_fast::test_prewarm_matches_generic_nano(); + test_a2_fast::test_prewarm_matches_generic_standard(); test_a2_fast::test_process_realtime_safe_nano(); test_a2_fast::test_process_realtime_safe_standard(); #endif diff --git a/tools/test/test_a2_fast.cpp b/tools/test/test_a2_fast.cpp index 7761f7eb..2652cd7e 100644 --- a/tools/test/test_a2_fast.cpp +++ b/tools/test/test_a2_fast.cpp @@ -268,6 +268,38 @@ void test_matches_generic_standard() test_matches_generic(8); } +// The fast path must report the same prewarm count as the generic WaveNet it +// replaces; otherwise Reset() warms the two by a different number of samples and +// their first post-Reset output can diverge (regression guard for the A2 prewarm +// off-by-one). Builds both from the identical config via the same dual path as +// test_matches_generic. +void test_prewarm_matches_generic(int channels) +{ + const auto cfg = build_a2_config(channels); + const int weight_count = a2_weight_count(channels); + const auto weights = make_deterministic_weights(weight_count, /*seed=*/0xA2FA500u + channels); + + auto fast_cfg = nam::wavenet::a2_fast::create_a2_fast_config(cfg, 48000.0); + std::vector w_fast = weights; + auto fast_dsp = fast_cfg->create(std::move(w_fast), 48000.0); + + auto generic_cfg = nam::wavenet::parse_config_json(cfg, 48000.0); + std::vector w_gen = weights; + auto generic_dsp = generic_cfg.create(std::move(w_gen), 48000.0); + + assert(fast_dsp->GetPrewarmSamples() == generic_dsp->GetPrewarmSamples()); +} + +void test_prewarm_matches_generic_nano() +{ + test_prewarm_matches_generic(3); +} + +void test_prewarm_matches_generic_standard() +{ + test_prewarm_matches_generic(8); +} + // Real-time safety: once the DSP has been Reset (buffers sized, prewarmed), // subsequent process() calls must not allocate or free heap memory. Uses the // same allocation-tracking infrastructure as the generic WaveNet RT-safety