Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions NAM/wavenet/a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,17 @@ bool is_a2_shape(const nlohmann::json& config, int* channels)
if (head_it != config.end() && !head_it->is_null())
return false;

// No conditioning DSP. When given a non-null condition_dsp the generic WaveNet
// builds a nested model and routes the conditioning signal through it before the
// layer stack; the fast path has no such stage and feeds the raw input as the
// condition. The condition DSP carries its own weights, so the parent weight
// stream is identical with or without it and the loader cannot detect the
// difference -- the detector must reject it here, or the fast path would silently
// produce different audio than the model it replaces.
auto cond_it = config.find("condition_dsp");
if (cond_it != config.end() && !cond_it->is_null())
return false;

// head_scale is loaded from the trailing weight, but require the field to
// stay schema-compatible with the generic WaveNet parser.
auto hs_it = config.find("head_scale");
Expand Down Expand Up @@ -827,6 +838,14 @@ bool is_a2_shape(const nlohmann::json& config, int* channels)
return false;
}

// Legacy boolean `gated` (the pre-gating_mode schema): the generic parser maps
// gated==true to GATED layers, which the fast path does not implement. A genuinely
// gated model has a larger weight stream and the loader would throw, but reject it
// here so the boundary is enforced by the detector rather than a downstream error.
auto gated_it = la.find("gated");
if (gated_it != la.end() && gated_it->is_boolean() && gated_it->get<bool>())
return false;

// secondary_activation: all null (or field absent)
auto sa_it = la.find("secondary_activation");
if (sa_it != la.end() && !sa_it->is_null())
Expand Down
2 changes: 2 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ int main()
test_a2_fast::test_detector_rejects_wrong_kernel_sizes();
test_a2_fast::test_detector_rejects_wrong_activation();
test_a2_fast::test_detector_rejects_gating();
test_a2_fast::test_detector_rejects_condition_dsp();
test_a2_fast::test_detector_rejects_legacy_gated();
test_a2_fast::test_matches_generic_nano();
test_a2_fast::test_matches_generic_standard();
test_a2_fast::test_process_realtime_safe_nano();
Expand Down
21 changes: 21 additions & 0 deletions tools/test/test_a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ void test_detector_rejects_gating()
assert(!nam::wavenet::a2_fast::is_a2_shape(cfg, nullptr));
}

// A condition DSP routes the conditioning signal through a nested model; the fast
// path has no such stage. The nested model holds its own weights, so the parent
// weight stream is unchanged and only the detector can catch this -- otherwise the
// fast path would silently produce different audio than the generic WaveNet.
void test_detector_rejects_condition_dsp()
{
auto cfg = build_a2_config(8);
cfg["condition_dsp"] = {{"version", "0.5.0"}, {"architecture", "Linear"},
{"config", nlohmann::json::object()}, {"weights", nlohmann::json::array()}};
assert(!nam::wavenet::a2_fast::is_a2_shape(cfg, nullptr));
}

// Legacy boolean `gated` (pre-gating_mode schema) maps to GATED layers in the
// generic parser, which the fast path does not implement.
void test_detector_rejects_legacy_gated()
{
auto cfg = build_a2_config(3);
cfg["layers"][0]["gated"] = true;
assert(!nam::wavenet::a2_fast::is_a2_shape(cfg, nullptr));
}

void test_matches_generic(int channels)
{
const auto cfg = build_a2_config(channels);
Expand Down
Loading