From c41d066d77c215a708bdbeee631e1627441cff87 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 23 Jun 2026 15:01:48 -0700 Subject: [PATCH 1/4] Add slimmable breakpoint introspection --- NAM/container.cpp | 28 ++++++++++++- NAM/container.h | 4 ++ NAM/slimmable.h | 15 +++++++ NAM/wavenet/slimmable.cpp | 60 ++++++++++++++++++++++----- NAM/wavenet/slimmable.h | 3 ++ tools/run_tests.cpp | 2 + tools/test/test_container.cpp | 26 ++++++++++++ tools/test/test_slimmable_wavenet.cpp | 39 +++++++++++++++++ 8 files changed, 165 insertions(+), 12 deletions(-) diff --git a/NAM/container.cpp b/NAM/container.cpp index 7d497fd4..80fd85e3 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -82,7 +82,7 @@ void ContainerModel::Reset(const double sampleRate, const int maxBufferSize) _submodels[active_index].model->Reset(sampleRate, maxBufferSize); } -void ContainerModel::SetSlimmableSize(const double val) +size_t ContainerModel::_get_index_for_slimmable_size(const double val) const { size_t active_index = _submodels.size() - 1; for (size_t i = 0; i < _submodels.size(); ++i) @@ -93,6 +93,12 @@ void ContainerModel::SetSlimmableSize(const double val) break; } } + return active_index; +} + +void ContainerModel::SetSlimmableSize(const double val) +{ + const size_t active_index = _get_index_for_slimmable_size(val); // Fast path: no change to active model. if (active_index == _active_index.load(std::memory_order_acquire)) @@ -115,6 +121,26 @@ void ContainerModel::SetSlimmableSize(const double val) _active_index.store(active_index, std::memory_order_release); } +std::vector ContainerModel::GetSlimmableSizeBreakpoints() const +{ + std::vector breakpoints; + breakpoints.reserve(_submodels.size()); + breakpoints.push_back(0.0); + + for (size_t i = 0; i + 1 < _submodels.size(); ++i) + { + if (_submodels[i].max_value != breakpoints.back()) + breakpoints.push_back(_submodels[i].max_value); + } + + return breakpoints; +} + +bool ContainerModel::WillSlimmableSizeChange(const double val) const +{ + return _get_index_for_slimmable_size(val) != _active_index.load(std::memory_order_acquire); +} + int ContainerModel::GetPrewarmSamples() { const size_t active_index = _active_index.load(std::memory_order_acquire); diff --git a/NAM/container.h b/NAM/container.h index 2c82baed..dbe0030f 100644 --- a/NAM/container.h +++ b/NAM/container.h @@ -39,9 +39,13 @@ class ContainerModel : public DSP, public SlimmableModel void Reset(const double sampleRate, const int maxBufferSize) override; void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; + std::vector GetSlimmableSizeBreakpoints() const override; + bool WillSlimmableSizeChange(const double val) const override; int GetPrewarmSamples() override; private: + size_t _get_index_for_slimmable_size(const double val) const; + std::vector _submodels; std::atomic _active_index{0}; std::mutex _slim_set_mutex; diff --git a/NAM/slimmable.h b/NAM/slimmable.h index 804bc185..5105b69c 100644 --- a/NAM/slimmable.h +++ b/NAM/slimmable.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace nam { @@ -19,6 +21,19 @@ class SlimmableModel /// Thread-safe /// Not real-time safe virtual void SetSlimmableSize(const double val) = 0; + + /// \brief Get normalized size-control values where the selected slimmed model can change + /// \return Sorted breakpoints in [0.0, 1.0], including 0.0 when known + virtual std::vector GetSlimmableSizeBreakpoints() const { return {}; } + + /// \brief Check whether SetSlimmableSize(val) would require selecting or staging a different slimmed model + /// \param val Value between 0.0 (minimum size) and 1.0 (maximum size) + /// \return true if val maps to a slimmed model that is not already active or staged + virtual bool WillSlimmableSizeChange(const double val) const + { + (void)val; + return true; + } }; } // namespace nam diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 6248618b..361655de 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -105,6 +105,21 @@ int ratio_to_channels(double ratio, const std::vector& allowed) return allowed[idx]; } +std::vector get_ratio_breakpoints(const std::vector>& per_array_allowed_channels) +{ + std::vector breakpoints{0.0}; + + for (const auto& allowed : per_array_allowed_channels) + { + for (size_t i = 1; i < allowed.size(); ++i) + breakpoints.push_back((double)i / (double)allowed.size()); + } + + std::sort(breakpoints.begin(), breakpoints.end()); + breakpoints.erase(std::unique(breakpoints.begin(), breakpoints.end()), breakpoints.end()); + return breakpoints; +} + // ============================================================================ // Extract slimmed weights by walking the full weight vector in set_weights_ // order, using typed LayerArrayParams for dimensions. @@ -386,6 +401,23 @@ SlimmableWavenet::SlimmableWavenet(std::vector origin _rebuild_model(full_channels); } +std::vector SlimmableWavenet::_get_channels_for_slimmable_size(const double val) const +{ + const size_t num_arrays = _original_params.size(); + std::vector target(num_arrays); + + for (size_t i = 0; i < num_arrays; i++) + { + const auto& allowed = _per_array_allowed_channels[i]; + if (allowed.empty()) + target[i] = _original_params[i].channels; // Non-slimmable: keep full + else + target[i] = ratio_to_channels(val, allowed); + } + + return target; +} + std::unique_ptr SlimmableWavenet::_create_wavenet_for_channels(const std::vector& target_channels) { std::vector weights; @@ -494,19 +526,25 @@ void SlimmableWavenet::SetPrewarmOnReset(const bool prewarmOnReset) void SlimmableWavenet::SetSlimmableSize(const double val) { - const size_t num_arrays = _original_params.size(); - std::vector target(num_arrays); + _stage_rebuild_model(_get_channels_for_slimmable_size(val)); +} - for (size_t i = 0; i < num_arrays; i++) - { - const auto& allowed = _per_array_allowed_channels[i]; - if (allowed.empty()) - target[i] = _original_params[i].channels; // Non-slimmable: keep full - else - target[i] = ratio_to_channels(val, allowed); - } +std::vector SlimmableWavenet::GetSlimmableSizeBreakpoints() const +{ + return get_ratio_breakpoints(_per_array_allowed_channels); +} + +bool SlimmableWavenet::WillSlimmableSizeChange(const double val) const +{ + const auto target = _get_channels_for_slimmable_size(val); + + if (target == _current_channels && _active_model) + return false; - _stage_rebuild_model(target); + if (auto pending = _pending_load_acquire()) + return pending->channels != target; + + return true; } // ============================================================================ diff --git a/NAM/wavenet/slimmable.h b/NAM/wavenet/slimmable.h index df57b052..4eb8cb89 100644 --- a/NAM/wavenet/slimmable.h +++ b/NAM/wavenet/slimmable.h @@ -60,6 +60,8 @@ class SlimmableWavenet : public DSP, public SlimmableModel void Reset(const double sampleRate, const int maxBufferSize) override; void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; + std::vector GetSlimmableSizeBreakpoints() const override; + bool WillSlimmableSizeChange(const double val) const override; protected: int GetPrewarmSamples() override { return 0; } @@ -92,6 +94,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel int _current_buffer_size = 0; double _current_sample_rate = 0.0; + std::vector _get_channels_for_slimmable_size(const double val) const; std::unique_ptr _create_wavenet_for_channels(const std::vector& target_channels); void _rebuild_model(const std::vector& target_channels); void _stage_rebuild_model(const std::vector& target_channels); diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index fab71f8e..a8c2ecf1 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -307,6 +307,7 @@ int main() test_container::test_container_processes_audio(); test_container::test_container_slimmable_selects_submodel(); test_container::test_container_boundary_values(); + test_container::test_container_slimmable_breakpoints_and_change_check(); test_container::test_container_empty_submodels_throws(); test_container::test_container_last_max_value_must_cover_one(); test_container::test_container_unsorted_submodels_throws(); @@ -330,6 +331,7 @@ int main() test_slimmable_wavenet::test_boundary_values(); test_slimmable_wavenet::test_default_is_max_size(); test_slimmable_wavenet::test_ratio_mapping(); + test_slimmable_wavenet::test_slimmable_breakpoints_and_change_check(); test_slimmable_wavenet::test_from_json(); test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular(); test_slimmable_wavenet::test_unsupported_method_throws(); diff --git a/tools/test/test_container.cpp b/tools/test/test_container.cpp index 46b17527..0b160a12 100644 --- a/tools/test/test_container.cpp +++ b/tools/test/test_container.cpp @@ -240,6 +240,32 @@ void test_container_boundary_values() assert(std::isfinite(output[i])); } +void test_container_slimmable_breakpoints_and_change_check() +{ + CountingDSP* small = nullptr; + CountingDSP* large = nullptr; + auto dsp = build_counting_container(small, large); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); + assert(breakpoints.size() == 2); + assert(breakpoints[0] == 0.0); + assert(breakpoints[1] == 0.5); + + // Defaults to the largest submodel. + assert(!slimmable->WillSlimmableSizeChange(1.0)); + assert(!slimmable->WillSlimmableSizeChange(0.5)); + assert(slimmable->WillSlimmableSizeChange(0.49)); + + slimmable->SetSlimmableSize(0.0); + assert(!slimmable->WillSlimmableSizeChange(0.0)); + assert(!slimmable->WillSlimmableSizeChange(0.49)); + assert(slimmable->WillSlimmableSizeChange(0.5)); + assert(slimmable->WillSlimmableSizeChange(1.0)); +} + void test_container_empty_submodels_throws() { nlohmann::json j; diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index 388ceeae..10fb1d5c 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -244,6 +244,45 @@ void test_ratio_mapping() assert(any_different); } +void test_slimmable_breakpoints_and_change_check() +{ + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); + assert(breakpoints.size() == 3); + assert(breakpoints[0] == 0.0); + assert(std::abs(breakpoints[1] - (1.0 / 3.0)) < 1e-12); + assert(std::abs(breakpoints[2] - (2.0 / 3.0)) < 1e-12); + + // Defaults to the full-size model. + assert(!slimmable->WillSlimmableSizeChange(1.0)); + assert(!slimmable->WillSlimmableSizeChange(0.67)); + assert(slimmable->WillSlimmableSizeChange(0.66)); + + // Staging the same slimmed channel count again should not require another rebuild. + slimmable->SetSlimmableSize(0.34); + assert(!slimmable->WillSlimmableSizeChange(0.5)); + assert(slimmable->WillSlimmableSizeChange(0.0)); + + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + const int buffer_size = 16; + dsp->Reset(sample_rate, buffer_size); + + std::vector input(buffer_size, 0.1); + std::vector output(buffer_size); + NAM_SAMPLE* in_ptr = input.data(); + NAM_SAMPLE* out_ptr = output.data(); + dsp->process(&in_ptr, &out_ptr, buffer_size); + + assert(!slimmable->WillSlimmableSizeChange(0.5)); + assert(slimmable->WillSlimmableSizeChange(0.0)); + assert(slimmable->WillSlimmableSizeChange(1.0)); +} + void test_from_json() { From 70d209a07111b6a9b6290a3dba6238efd7da04d3 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 23 Jun 2026 15:09:02 -0700 Subject: [PATCH 2/4] Remove slimmable size change check --- NAM/container.cpp | 5 ----- NAM/container.h | 1 - NAM/slimmable.h | 9 --------- NAM/wavenet/slimmable.cpp | 13 ------------- NAM/wavenet/slimmable.h | 1 - tools/run_tests.cpp | 4 ++-- tools/test/test_container.cpp | 13 +------------ tools/test/test_slimmable_wavenet.cpp | 26 +------------------------- 8 files changed, 4 insertions(+), 68 deletions(-) diff --git a/NAM/container.cpp b/NAM/container.cpp index 80fd85e3..0da65f1c 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -136,11 +136,6 @@ std::vector ContainerModel::GetSlimmableSizeBreakpoints() const return breakpoints; } -bool ContainerModel::WillSlimmableSizeChange(const double val) const -{ - return _get_index_for_slimmable_size(val) != _active_index.load(std::memory_order_acquire); -} - int ContainerModel::GetPrewarmSamples() { const size_t active_index = _active_index.load(std::memory_order_acquire); diff --git a/NAM/container.h b/NAM/container.h index dbe0030f..01eb4c63 100644 --- a/NAM/container.h +++ b/NAM/container.h @@ -40,7 +40,6 @@ class ContainerModel : public DSP, public SlimmableModel void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; std::vector GetSlimmableSizeBreakpoints() const override; - bool WillSlimmableSizeChange(const double val) const override; int GetPrewarmSamples() override; private: diff --git a/NAM/slimmable.h b/NAM/slimmable.h index 5105b69c..b02b50fb 100644 --- a/NAM/slimmable.h +++ b/NAM/slimmable.h @@ -25,15 +25,6 @@ class SlimmableModel /// \brief Get normalized size-control values where the selected slimmed model can change /// \return Sorted breakpoints in [0.0, 1.0], including 0.0 when known virtual std::vector GetSlimmableSizeBreakpoints() const { return {}; } - - /// \brief Check whether SetSlimmableSize(val) would require selecting or staging a different slimmed model - /// \param val Value between 0.0 (minimum size) and 1.0 (maximum size) - /// \return true if val maps to a slimmed model that is not already active or staged - virtual bool WillSlimmableSizeChange(const double val) const - { - (void)val; - return true; - } }; } // namespace nam diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 361655de..5d4dc004 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -534,19 +534,6 @@ std::vector SlimmableWavenet::GetSlimmableSizeBreakpoints() const return get_ratio_breakpoints(_per_array_allowed_channels); } -bool SlimmableWavenet::WillSlimmableSizeChange(const double val) const -{ - const auto target = _get_channels_for_slimmable_size(val); - - if (target == _current_channels && _active_model) - return false; - - if (auto pending = _pending_load_acquire()) - return pending->channels != target; - - return true; -} - // ============================================================================ // Config / factory / registration // ============================================================================ diff --git a/NAM/wavenet/slimmable.h b/NAM/wavenet/slimmable.h index 4eb8cb89..12c0268e 100644 --- a/NAM/wavenet/slimmable.h +++ b/NAM/wavenet/slimmable.h @@ -61,7 +61,6 @@ class SlimmableWavenet : public DSP, public SlimmableModel void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; std::vector GetSlimmableSizeBreakpoints() const override; - bool WillSlimmableSizeChange(const double val) const override; protected: int GetPrewarmSamples() override { return 0; } diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index a8c2ecf1..17a2e685 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -307,7 +307,7 @@ int main() test_container::test_container_processes_audio(); test_container::test_container_slimmable_selects_submodel(); test_container::test_container_boundary_values(); - test_container::test_container_slimmable_breakpoints_and_change_check(); + test_container::test_container_slimmable_breakpoints(); test_container::test_container_empty_submodels_throws(); test_container::test_container_last_max_value_must_cover_one(); test_container::test_container_unsorted_submodels_throws(); @@ -331,7 +331,7 @@ int main() test_slimmable_wavenet::test_boundary_values(); test_slimmable_wavenet::test_default_is_max_size(); test_slimmable_wavenet::test_ratio_mapping(); - test_slimmable_wavenet::test_slimmable_breakpoints_and_change_check(); + test_slimmable_wavenet::test_slimmable_breakpoints(); test_slimmable_wavenet::test_from_json(); test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular(); test_slimmable_wavenet::test_unsupported_method_throws(); diff --git a/tools/test/test_container.cpp b/tools/test/test_container.cpp index 0b160a12..8aafcf99 100644 --- a/tools/test/test_container.cpp +++ b/tools/test/test_container.cpp @@ -240,7 +240,7 @@ void test_container_boundary_values() assert(std::isfinite(output[i])); } -void test_container_slimmable_breakpoints_and_change_check() +void test_container_slimmable_breakpoints() { CountingDSP* small = nullptr; CountingDSP* large = nullptr; @@ -253,17 +253,6 @@ void test_container_slimmable_breakpoints_and_change_check() assert(breakpoints.size() == 2); assert(breakpoints[0] == 0.0); assert(breakpoints[1] == 0.5); - - // Defaults to the largest submodel. - assert(!slimmable->WillSlimmableSizeChange(1.0)); - assert(!slimmable->WillSlimmableSizeChange(0.5)); - assert(slimmable->WillSlimmableSizeChange(0.49)); - - slimmable->SetSlimmableSize(0.0); - assert(!slimmable->WillSlimmableSizeChange(0.0)); - assert(!slimmable->WillSlimmableSizeChange(0.49)); - assert(slimmable->WillSlimmableSizeChange(0.5)); - assert(slimmable->WillSlimmableSizeChange(1.0)); } void test_container_empty_submodels_throws() diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index 10fb1d5c..534cd872 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -244,7 +244,7 @@ void test_ratio_mapping() assert(any_different); } -void test_slimmable_breakpoints_and_change_check() +void test_slimmable_breakpoints() { auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); assert(dsp != nullptr); @@ -257,30 +257,6 @@ void test_slimmable_breakpoints_and_change_check() assert(breakpoints[0] == 0.0); assert(std::abs(breakpoints[1] - (1.0 / 3.0)) < 1e-12); assert(std::abs(breakpoints[2] - (2.0 / 3.0)) < 1e-12); - - // Defaults to the full-size model. - assert(!slimmable->WillSlimmableSizeChange(1.0)); - assert(!slimmable->WillSlimmableSizeChange(0.67)); - assert(slimmable->WillSlimmableSizeChange(0.66)); - - // Staging the same slimmed channel count again should not require another rebuild. - slimmable->SetSlimmableSize(0.34); - assert(!slimmable->WillSlimmableSizeChange(0.5)); - assert(slimmable->WillSlimmableSizeChange(0.0)); - - const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; - const int buffer_size = 16; - dsp->Reset(sample_rate, buffer_size); - - std::vector input(buffer_size, 0.1); - std::vector output(buffer_size); - NAM_SAMPLE* in_ptr = input.data(); - NAM_SAMPLE* out_ptr = output.data(); - dsp->process(&in_ptr, &out_ptr, buffer_size); - - assert(!slimmable->WillSlimmableSizeChange(0.5)); - assert(slimmable->WillSlimmableSizeChange(0.0)); - assert(slimmable->WillSlimmableSizeChange(1.0)); } void test_from_json() From 1a0d2f18da8395e6b85580810e466d8b815789d1 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 23 Jun 2026 15:14:49 -0700 Subject: [PATCH 3/4] Document future slimmable breakpoint API break --- NAM/slimmable.h | 1 + 1 file changed, 1 insertion(+) diff --git a/NAM/slimmable.h b/NAM/slimmable.h index b02b50fb..4054aca6 100644 --- a/NAM/slimmable.h +++ b/NAM/slimmable.h @@ -24,6 +24,7 @@ class SlimmableModel /// \brief Get normalized size-control values where the selected slimmed model can change /// \return Sorted breakpoints in [0.0, 1.0], including 0.0 when known + // TODO: Make this abstract in the next breaking release. virtual std::vector GetSlimmableSizeBreakpoints() const { return {}; } }; From 290cb00346ec1f9f1d4382e6596ad8ac68253cc7 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 23 Jun 2026 15:21:53 -0700 Subject: [PATCH 4/4] Return internal slimmable breakpoints --- NAM/container.cpp | 8 ++------ NAM/slimmable.h | 4 ++-- NAM/wavenet/slimmable.cpp | 2 +- tools/test/test_container.cpp | 5 ++--- tools/test/test_slimmable_wavenet.cpp | 7 +++---- 5 files changed, 10 insertions(+), 16 deletions(-) diff --git a/NAM/container.cpp b/NAM/container.cpp index 0da65f1c..1275ac4d 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -124,14 +124,10 @@ void ContainerModel::SetSlimmableSize(const double val) std::vector ContainerModel::GetSlimmableSizeBreakpoints() const { std::vector breakpoints; - breakpoints.reserve(_submodels.size()); - breakpoints.push_back(0.0); + breakpoints.reserve(_submodels.size() - 1); for (size_t i = 0; i + 1 < _submodels.size(); ++i) - { - if (_submodels[i].max_value != breakpoints.back()) - breakpoints.push_back(_submodels[i].max_value); - } + breakpoints.push_back(_submodels[i].max_value); return breakpoints; } diff --git a/NAM/slimmable.h b/NAM/slimmable.h index 4054aca6..84899035 100644 --- a/NAM/slimmable.h +++ b/NAM/slimmable.h @@ -22,8 +22,8 @@ class SlimmableModel /// Not real-time safe virtual void SetSlimmableSize(const double val) = 0; - /// \brief Get normalized size-control values where the selected slimmed model can change - /// \return Sorted breakpoints in [0.0, 1.0], including 0.0 when known + /// \brief Get normalized size-control values that divide the selectable slimmed models + /// \return Sorted internal breakpoints in (0.0, 1.0); 0.0 and 1.0 are implied bounds // TODO: Make this abstract in the next breaking release. virtual std::vector GetSlimmableSizeBreakpoints() const { return {}; } }; diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 5d4dc004..71121059 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -107,7 +107,7 @@ int ratio_to_channels(double ratio, const std::vector& allowed) std::vector get_ratio_breakpoints(const std::vector>& per_array_allowed_channels) { - std::vector breakpoints{0.0}; + std::vector breakpoints; for (const auto& allowed : per_array_allowed_channels) { diff --git a/tools/test/test_container.cpp b/tools/test/test_container.cpp index 8aafcf99..b4dc65c8 100644 --- a/tools/test/test_container.cpp +++ b/tools/test/test_container.cpp @@ -250,9 +250,8 @@ void test_container_slimmable_breakpoints() assert(slimmable != nullptr); const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); - assert(breakpoints.size() == 2); - assert(breakpoints[0] == 0.0); - assert(breakpoints[1] == 0.5); + assert(breakpoints.size() == 1); + assert(breakpoints[0] == 0.5); } void test_container_empty_submodels_throws() diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index 534cd872..131ead9f 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -253,10 +253,9 @@ void test_slimmable_breakpoints() assert(slimmable != nullptr); const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); - assert(breakpoints.size() == 3); - assert(breakpoints[0] == 0.0); - assert(std::abs(breakpoints[1] - (1.0 / 3.0)) < 1e-12); - assert(std::abs(breakpoints[2] - (2.0 / 3.0)) < 1e-12); + assert(breakpoints.size() == 2); + assert(std::abs(breakpoints[0] - (1.0 / 3.0)) < 1e-12); + assert(std::abs(breakpoints[1] - (2.0 / 3.0)) < 1e-12); } void test_from_json()