Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion NAM/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void ContainerModel::Reset(const double sampleRate, const int maxBufferSize)
_submodels[active_index].model->Reset(sampleRate, maxBufferSize);
}

void ContainerModel::SetSlimmableSize(const double val)
size_t ContainerModel::_get_index_for_slimmable_size(const double val) const
{
size_t active_index = _submodels.size() - 1;
for (size_t i = 0; i < _submodels.size(); ++i)
Expand All @@ -93,6 +93,12 @@ void ContainerModel::SetSlimmableSize(const double val)
break;
}
}
return active_index;
}

void ContainerModel::SetSlimmableSize(const double val)
{
const size_t active_index = _get_index_for_slimmable_size(val);

// Fast path: no change to active model.
if (active_index == _active_index.load(std::memory_order_acquire))
Expand All @@ -115,6 +121,17 @@ void ContainerModel::SetSlimmableSize(const double val)
_active_index.store(active_index, std::memory_order_release);
}

std::vector<double> ContainerModel::GetSlimmableSizeBreakpoints() const
{
std::vector<double> breakpoints;
breakpoints.reserve(_submodels.size() - 1);

for (size_t i = 0; i + 1 < _submodels.size(); ++i)
breakpoints.push_back(_submodels[i].max_value);

return breakpoints;
}

int ContainerModel::GetPrewarmSamples()
{
const size_t active_index = _active_index.load(std::memory_order_acquire);
Expand Down
3 changes: 3 additions & 0 deletions NAM/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ class ContainerModel : public DSP, public SlimmableModel
void Reset(const double sampleRate, const int maxBufferSize) override;
void SetPrewarmOnReset(const bool prewarmOnReset) override;
void SetSlimmableSize(const double val) override;
std::vector<double> GetSlimmableSizeBreakpoints() const override;
int GetPrewarmSamples() override;

private:
size_t _get_index_for_slimmable_size(const double val) const;

std::vector<Submodel> _submodels;
std::atomic<size_t> _active_index{0};
std::mutex _slim_set_mutex;
Expand Down
7 changes: 7 additions & 0 deletions NAM/slimmable.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <vector>

namespace nam
{

Expand All @@ -19,6 +21,11 @@ class SlimmableModel
/// Thread-safe
/// Not real-time safe
virtual void SetSlimmableSize(const double val) = 0;

/// \brief Get normalized size-control values that divide the selectable slimmed models
/// \return Sorted internal breakpoints in (0.0, 1.0); 0.0 and 1.0 are implied bounds
// TODO: Make this abstract in the next breaking release.
virtual std::vector<double> GetSlimmableSizeBreakpoints() const { return {}; }
};

} // namespace nam
49 changes: 37 additions & 12 deletions NAM/wavenet/slimmable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ int ratio_to_channels(double ratio, const std::vector<int>& allowed)
return allowed[idx];
}

std::vector<double> get_ratio_breakpoints(const std::vector<std::vector<int>>& per_array_allowed_channels)
{
std::vector<double> breakpoints;

for (const auto& allowed : per_array_allowed_channels)
{
for (size_t i = 1; i < allowed.size(); ++i)
breakpoints.push_back((double)i / (double)allowed.size());
}

std::sort(breakpoints.begin(), breakpoints.end());
breakpoints.erase(std::unique(breakpoints.begin(), breakpoints.end()), breakpoints.end());
return breakpoints;
}

// ============================================================================
// Extract slimmed weights by walking the full weight vector in set_weights_
// order, using typed LayerArrayParams for dimensions.
Expand Down Expand Up @@ -386,6 +401,23 @@ SlimmableWavenet::SlimmableWavenet(std::vector<wavenet::LayerArrayParams> origin
_rebuild_model(full_channels);
}

std::vector<int> SlimmableWavenet::_get_channels_for_slimmable_size(const double val) const
{
const size_t num_arrays = _original_params.size();
std::vector<int> target(num_arrays);

for (size_t i = 0; i < num_arrays; i++)
{
const auto& allowed = _per_array_allowed_channels[i];
if (allowed.empty())
target[i] = _original_params[i].channels; // Non-slimmable: keep full
else
target[i] = ratio_to_channels(val, allowed);
}

return target;
}

std::unique_ptr<DSP> SlimmableWavenet::_create_wavenet_for_channels(const std::vector<int>& target_channels)
{
std::vector<float> weights;
Expand Down Expand Up @@ -494,19 +526,12 @@ void SlimmableWavenet::SetPrewarmOnReset(const bool prewarmOnReset)

void SlimmableWavenet::SetSlimmableSize(const double val)
{
const size_t num_arrays = _original_params.size();
std::vector<int> target(num_arrays);

for (size_t i = 0; i < num_arrays; i++)
{
const auto& allowed = _per_array_allowed_channels[i];
if (allowed.empty())
target[i] = _original_params[i].channels; // Non-slimmable: keep full
else
target[i] = ratio_to_channels(val, allowed);
}
_stage_rebuild_model(_get_channels_for_slimmable_size(val));
}

_stage_rebuild_model(target);
std::vector<double> SlimmableWavenet::GetSlimmableSizeBreakpoints() const
{
return get_ratio_breakpoints(_per_array_allowed_channels);
}

// ============================================================================
Expand Down
2 changes: 2 additions & 0 deletions NAM/wavenet/slimmable.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel
void Reset(const double sampleRate, const int maxBufferSize) override;
void SetPrewarmOnReset(const bool prewarmOnReset) override;
void SetSlimmableSize(const double val) override;
std::vector<double> GetSlimmableSizeBreakpoints() const override;

protected:
int GetPrewarmSamples() override { return 0; }
Expand Down Expand Up @@ -92,6 +93,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel
int _current_buffer_size = 0;
double _current_sample_rate = 0.0;

std::vector<int> _get_channels_for_slimmable_size(const double val) const;
std::unique_ptr<DSP> _create_wavenet_for_channels(const std::vector<int>& target_channels);
void _rebuild_model(const std::vector<int>& target_channels);
void _stage_rebuild_model(const std::vector<int>& target_channels);
Expand Down
2 changes: 2 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ int main()
test_container::test_container_processes_audio();
test_container::test_container_slimmable_selects_submodel();
test_container::test_container_boundary_values();
test_container::test_container_slimmable_breakpoints();
test_container::test_container_empty_submodels_throws();
test_container::test_container_last_max_value_must_cover_one();
test_container::test_container_unsorted_submodels_throws();
Expand All @@ -330,6 +331,7 @@ int main()
test_slimmable_wavenet::test_boundary_values();
test_slimmable_wavenet::test_default_is_max_size();
test_slimmable_wavenet::test_ratio_mapping();
test_slimmable_wavenet::test_slimmable_breakpoints();
test_slimmable_wavenet::test_from_json();
test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular();
test_slimmable_wavenet::test_unsupported_method_throws();
Expand Down
14 changes: 14 additions & 0 deletions tools/test/test_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ void test_container_boundary_values()
assert(std::isfinite(output[i]));
}

void test_container_slimmable_breakpoints()
{
CountingDSP* small = nullptr;
CountingDSP* large = nullptr;
auto dsp = build_counting_container(small, large);

auto* slimmable = dynamic_cast<nam::SlimmableModel*>(dsp.get());
assert(slimmable != nullptr);

const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints();
assert(breakpoints.size() == 1);
assert(breakpoints[0] == 0.5);
}

void test_container_empty_submodels_throws()
{
nlohmann::json j;
Expand Down
14 changes: 14 additions & 0 deletions tools/test/test_slimmable_wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,20 @@ void test_ratio_mapping()
assert(any_different);
}

void test_slimmable_breakpoints()
{
auto dsp = nam::get_dsp(std::filesystem::path("example_models/slimmable_wavenet.nam"));
assert(dsp != nullptr);

auto* slimmable = dynamic_cast<nam::SlimmableModel*>(dsp.get());
assert(slimmable != nullptr);

const auto breakpoints = slimmable->GetSlimmableSizeBreakpoints();
assert(breakpoints.size() == 2);
assert(std::abs(breakpoints[0] - (1.0 / 3.0)) < 1e-12);
assert(std::abs(breakpoints[1] - (2.0 / 3.0)) < 1e-12);
}

void test_from_json()
{

Expand Down
Loading