diff --git a/NAM/container.cpp b/NAM/container.cpp index 7d497fd4..1275ac4d 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -82,7 +82,7 @@ void ContainerModel::Reset(const double sampleRate, const int maxBufferSize) _submodels[active_index].model->Reset(sampleRate, maxBufferSize); } -void ContainerModel::SetSlimmableSize(const double val) +size_t ContainerModel::_get_index_for_slimmable_size(const double val) const { size_t active_index = _submodels.size() - 1; for (size_t i = 0; i < _submodels.size(); ++i) @@ -93,6 +93,12 @@ void ContainerModel::SetSlimmableSize(const double val) break; } } + return active_index; +} + +void ContainerModel::SetSlimmableSize(const double val) +{ + const size_t active_index = _get_index_for_slimmable_size(val); // Fast path: no change to active model. if (active_index == _active_index.load(std::memory_order_acquire)) @@ -115,6 +121,17 @@ void ContainerModel::SetSlimmableSize(const double val) _active_index.store(active_index, std::memory_order_release); } +std::vector ContainerModel::GetSlimmableSizeBreakpoints() const +{ + std::vector breakpoints; + breakpoints.reserve(_submodels.size() - 1); + + for (size_t i = 0; i + 1 < _submodels.size(); ++i) + breakpoints.push_back(_submodels[i].max_value); + + return breakpoints; +} + int ContainerModel::GetPrewarmSamples() { const size_t active_index = _active_index.load(std::memory_order_acquire); diff --git a/NAM/container.h b/NAM/container.h index 2c82baed..01eb4c63 100644 --- a/NAM/container.h +++ b/NAM/container.h @@ -39,9 +39,12 @@ class ContainerModel : public DSP, public SlimmableModel void Reset(const double sampleRate, const int maxBufferSize) override; void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; + std::vector GetSlimmableSizeBreakpoints() const override; int GetPrewarmSamples() override; private: + size_t _get_index_for_slimmable_size(const double val) const; + std::vector _submodels; std::atomic _active_index{0}; std::mutex _slim_set_mutex; diff --git a/NAM/slimmable.h b/NAM/slimmable.h index 804bc185..84899035 100644 --- a/NAM/slimmable.h +++ b/NAM/slimmable.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace nam { @@ -19,6 +21,11 @@ class SlimmableModel /// Thread-safe /// Not real-time safe virtual void SetSlimmableSize(const double val) = 0; + + /// \brief Get normalized size-control values that divide the selectable slimmed models + /// \return Sorted internal breakpoints in (0.0, 1.0); 0.0 and 1.0 are implied bounds + // TODO: Make this abstract in the next breaking release. + virtual std::vector GetSlimmableSizeBreakpoints() const { return {}; } }; } // namespace nam diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 6248618b..71121059 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -105,6 +105,21 @@ int ratio_to_channels(double ratio, const std::vector& allowed) return allowed[idx]; } +std::vector get_ratio_breakpoints(const std::vector>& per_array_allowed_channels) +{ + std::vector breakpoints; + + for (const auto& allowed : per_array_allowed_channels) + { + for (size_t i = 1; i < allowed.size(); ++i) + breakpoints.push_back((double)i / (double)allowed.size()); + } + + std::sort(breakpoints.begin(), breakpoints.end()); + breakpoints.erase(std::unique(breakpoints.begin(), breakpoints.end()), breakpoints.end()); + return breakpoints; +} + // ============================================================================ // Extract slimmed weights by walking the full weight vector in set_weights_ // order, using typed LayerArrayParams for dimensions. @@ -386,6 +401,23 @@ SlimmableWavenet::SlimmableWavenet(std::vector origin _rebuild_model(full_channels); } +std::vector SlimmableWavenet::_get_channels_for_slimmable_size(const double val) const +{ + const size_t num_arrays = _original_params.size(); + std::vector target(num_arrays); + + for (size_t i = 0; i < num_arrays; i++) + { + const auto& allowed = _per_array_allowed_channels[i]; + if (allowed.empty()) + target[i] = _original_params[i].channels; // Non-slimmable: keep full + else + target[i] = ratio_to_channels(val, allowed); + } + + return target; +} + std::unique_ptr SlimmableWavenet::_create_wavenet_for_channels(const std::vector& target_channels) { std::vector weights; @@ -494,19 +526,12 @@ void SlimmableWavenet::SetPrewarmOnReset(const bool prewarmOnReset) void SlimmableWavenet::SetSlimmableSize(const double val) { - const size_t num_arrays = _original_params.size(); - std::vector target(num_arrays); - - for (size_t i = 0; i < num_arrays; i++) - { - const auto& allowed = _per_array_allowed_channels[i]; - if (allowed.empty()) - target[i] = _original_params[i].channels; // Non-slimmable: keep full - else - target[i] = ratio_to_channels(val, allowed); - } + _stage_rebuild_model(_get_channels_for_slimmable_size(val)); +} - _stage_rebuild_model(target); +std::vector SlimmableWavenet::GetSlimmableSizeBreakpoints() const +{ + return get_ratio_breakpoints(_per_array_allowed_channels); } // ============================================================================ diff --git a/NAM/wavenet/slimmable.h b/NAM/wavenet/slimmable.h index df57b052..12c0268e 100644 --- a/NAM/wavenet/slimmable.h +++ b/NAM/wavenet/slimmable.h @@ -60,6 +60,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel void Reset(const double sampleRate, const int maxBufferSize) override; void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; + std::vector GetSlimmableSizeBreakpoints() const override; protected: int GetPrewarmSamples() override { return 0; } @@ -92,6 +93,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel int _current_buffer_size = 0; double _current_sample_rate = 0.0; + std::vector _get_channels_for_slimmable_size(const double val) const; std::unique_ptr _create_wavenet_for_channels(const std::vector& target_channels); void _rebuild_model(const std::vector& target_channels); void _stage_rebuild_model(const std::vector& target_channels); diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index fab71f8e..17a2e685 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -307,6 +307,7 @@ int main() test_container::test_container_processes_audio(); test_container::test_container_slimmable_selects_submodel(); test_container::test_container_boundary_values(); + test_container::test_container_slimmable_breakpoints(); test_container::test_container_empty_submodels_throws(); test_container::test_container_last_max_value_must_cover_one(); test_container::test_container_unsorted_submodels_throws(); @@ -330,6 +331,7 @@ int main() test_slimmable_wavenet::test_boundary_values(); test_slimmable_wavenet::test_default_is_max_size(); test_slimmable_wavenet::test_ratio_mapping(); + test_slimmable_wavenet::test_slimmable_breakpoints(); test_slimmable_wavenet::test_from_json(); test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular(); test_slimmable_wavenet::test_unsupported_method_throws(); diff --git a/tools/test/test_container.cpp b/tools/test/test_container.cpp index 46b17527..b4dc65c8 100644 --- a/tools/test/test_container.cpp +++ b/tools/test/test_container.cpp @@ -240,6 +240,20 @@ void test_container_boundary_values() assert(std::isfinite(output[i])); } +void test_container_slimmable_breakpoints() +{ + CountingDSP* small = nullptr; + CountingDSP* large = nullptr; + auto dsp = build_counting_container(small, large); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); + assert(breakpoints.size() == 1); + assert(breakpoints[0] == 0.5); +} + void test_container_empty_submodels_throws() { nlohmann::json j; diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index 388ceeae..131ead9f 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -244,6 +244,20 @@ void test_ratio_mapping() assert(any_different); } +void test_slimmable_breakpoints() +{ + auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam")); + assert(dsp != nullptr); + + auto* slimmable = dynamic_cast(dsp.get()); + assert(slimmable != nullptr); + + const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints(); + assert(breakpoints.size() == 2); + assert(std::abs(breakpoints[0] - (1.0 / 3.0)) < 1e-12); + assert(std::abs(breakpoints[1] - (2.0 / 3.0)) < 1e-12); +} + void test_from_json() {