From 43bcbcdac40f95a9d7675333943b6747bb6c8789 Mon Sep 17 00:00:00 2001 From: Robert Haist Date: Thu, 25 Jun 2026 11:58:16 +0200 Subject: [PATCH] Fix A2 fast-path prewarm count to match the generic WaveNet A2FastModel computed its prewarm sample count as the layer-stack lookback distance: sum of per-layer (kernel_size-1)*dilation, plus (head_kernel-1). The generic WaveNet it replaces computes the receptive field, which is one greater: it seeds mPrewarmSamples at 1 (the sample being produced) before adding the same lookback terms (model.cpp). So the fast path warmed up by one fewer sample than the model it is meant to be a drop-in for. prewarm() runs process() in whole maxBufferSize blocks until the count is reached, so the off-by-one only changes the block count when the total (6346 for the A2 shape) is an exact multiple of the buffer size. Power-of-2 buffers (64/256/4096) mask it, which is why the existing equivalence test (test_matches_generic, blocks 64/256) never caught it; on a buffer size that divides 6346 the two paths warm up by a different number of blocks and their first post-Reset output diverges. Seed prewarm at 1 to match the generic receptive-field formula. Add test_prewarm_matches_generic_{nano,standard}, which builds both the fast and generic model from the same config and asserts equal GetPrewarmSamples(). Verified the test catches the regression: reverting the seed to 0 aborts the new assertion; with the fix the full suite passes. 
--- NAM/wavenet/a2_fast.cpp | 6 +++++- tools/run_tests.cpp | 2 ++ tools/test/test_a2_fast.cpp | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/NAM/wavenet/a2_fast.cpp b/NAM/wavenet/a2_fast.cpp index 67093c88..9b89c735 100644 --- a/NAM/wavenet/a2_fast.cpp +++ b/NAM/wavenet/a2_fast.cpp @@ -172,7 +172,11 @@ A2FastModel::A2FastModel(std::vector weights, double expected_s _load_weights(weights); - int prewarm = 0; + // Receptive field = 1 (the sample being produced) + sum of per-layer lookbacks + + // (head kernel - 1). The leading 1 matches the generic WaveNet's prewarm count + // (model.cpp: mPrewarmSamples starts at 1 when there's no condition DSP), so the + // fast path warms up by exactly the same number of samples as the model it replaces. + int prewarm = 1; for (int i = 0; i < kNumLayers; i++) prewarm += _layers[i].max_lookback; prewarm += kHeadKernelSize - 1; diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 61c7ba4a..72839e1e 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -353,6 +353,8 @@ int main() test_a2_fast::test_detector_rejects_gating(); test_a2_fast::test_matches_generic_nano(); test_a2_fast::test_matches_generic_standard(); + test_a2_fast::test_prewarm_matches_generic_nano(); + test_a2_fast::test_prewarm_matches_generic_standard(); test_a2_fast::test_process_realtime_safe_nano(); test_a2_fast::test_process_realtime_safe_standard(); #endif diff --git a/tools/test/test_a2_fast.cpp b/tools/test/test_a2_fast.cpp index 7761f7eb..2652cd7e 100644 --- a/tools/test/test_a2_fast.cpp +++ b/tools/test/test_a2_fast.cpp @@ -268,6 +268,38 @@ void test_matches_generic_standard() test_matches_generic(8); } +// The fast path must report the same prewarm count as the generic WaveNet it +// replaces; otherwise Reset() warms the two by a different number of samples and +// their first post-Reset output can diverge (regression guard for the A2 prewarm +// off-by-one). 
Builds both from the identical config via the same dual path as
+// test_matches_generic. `channels` is forwarded to build_a2_config() and a2_weight_count().
+void test_prewarm_matches_generic(int channels)
+{
+  const auto cfg = build_a2_config(channels);
+  const int weight_count = a2_weight_count(channels);
+  const auto weights = make_deterministic_weights(weight_count, /*seed=*/0xA2FA500u + channels);
+
+  auto fast_cfg = nam::wavenet::a2_fast::create_a2_fast_config(cfg, 48000.0); // fast-path side
+  std::vector w_fast = weights; // own copy: create() receives it via std::move, `weights` is reused below
+  auto fast_dsp = fast_cfg->create(std::move(w_fast), 48000.0);
+
+  auto generic_cfg = nam::wavenet::parse_config_json(cfg, 48000.0); // generic side, same cfg
+  std::vector w_gen = weights; // second copy of the same weights
+  auto generic_dsp = generic_cfg.create(std::move(w_gen), 48000.0);
+
+  assert(fast_dsp->GetPrewarmSamples() == generic_dsp->GetPrewarmSamples()); // the regression check itself
+}
+
+void test_prewarm_matches_generic_nano()
+{
+  test_prewarm_matches_generic(3); // channels = 3
+}
+
+void test_prewarm_matches_generic_standard()
+{
+  test_prewarm_matches_generic(8); // channels = 8
+}
+
 // Real-time safety: once the DSP has been Reset (buffers sized, prewarmed),
 // subsequent process() calls must not allocate or free heap memory. Uses the
 // same allocation-tracking infrastructure as the generic WaveNet RT-safety