Add low-memory streaming conversion for unscanned DeepSeek-family checkpoints#4160
Add low-memory streaming conversion for unscanned DeepSeek-family checkpoints#4160discobot wants to merge 1 commit into
Conversation
…ckpoints The converter buffered every (dequantized) tensor from all safetensors shards in one dict before assembling a second full copy of the model, so converting Kimi-K2.6 needed ~2.5 TB of host RAM (AI-Hypercomputer#4071). Tensor loading is now lazy: an index over the shard headers is built up front and each tensor is read (and int4-dequantized) only when the assembly consumes it. A new --low_memory flag additionally stages converted leaves in disk-backed numpy memmaps under TMPDIR and saves without simulated-device sharding, keeping peak RSS at O(one tensor) instead of O(2x model). Both the default path and the low-memory path produce bit-identical checkpoints to before; a new unit test covers streaming, disk spilling, and a save/restore round trip on a tiny synthetic kimi-style checkpoint.
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
|
How does this approach compare to lazy_load_tensors (which is now true by default) mode in src/maxtext/checkpoint_conversion/to_maxtext.py ? |
|
This is the standalone convert_deepseek_family_unscanned_ckpt.py, which didn't have lazy loading. _LazyShardLoader adds the same on-demand load/dequant. The part lazy_load_tensors doesn't cover is the save: with simulated_cpu_devices_count=16, device_put pulls the whole pytree back into RAM and still OOMs, so --low_memory spills leaves to .npy memmaps and skips simulated sharding to keep them disk-backed. Tested on a synthetic int4-expert checkpoint, not a full Kimi-K2.6 run (that's the >2TB-host case this is meant to fix), so I've left the e2e box unchecked. |
Description
Fixes #4071.
The unscanned converter buffered every dequantized tensor for all shards before
assembly, OOM-ing on hosts with less than ~2.5 TB RAM for Kimi-K2.6.
Two things beyond what the issue establishes: assembly builds a second full fp16 copy
of the model while the buffered dict is still alive, and even with loading fixed, the
default save path (
simulated_cpu_devices_count=16) re-materializes the whole pytreeas RAM-resident jax.Arrays inside
shard_jax_weights— so streaming the loads aloneis not enough.
This PR makes tensor loading lazy unconditionally (a header-only index over the shards;
each tensor is read and int4-dequantized exactly once, when assembly consumes it) and
adds an opt-in
--low_memoryflag that stages converted leaves in read-only disk-backedmemmaps under
TMPDIRand saves via the single-device path. Peak RSS drops fromO(2x model) to O(one tensor); on an 8-shard synthetic kimi-style checkpoint, RSS during
the shard scan goes from +400 MB monotonic growth to flat (+0.5 MB).
Checkpoints are bit-identical in all directions (verified via Orbax restore): old
converter vs new default, new default vs
--low_memory, and 16-simulated-devicesharded save vs the low-memory single-device save — the saved checkpoint is
topology-independent, so skipping simulated-device sharding in low-memory mode does not
change the artifact.
Adds
tests/unit/convert_deepseek_unscanned_low_memory_test.py(synthetic int4multi-shard checkpoint; asserts no tensor reads before assembly, no re-reads, bit-exact
low-memory equivalence, and a save/restore round trip against independently computed
values), ignored in default CI like the other torch-dependent conversion tests, and
documents the flag in the Kimi runbook.
Tests
python3 -m pytest tests/unit/convert_deepseek_unscanned_low_memory_test.py tests/unit/dequantize_pack_quantized_int4_test.py(9 passed)--low_memory true; restored pytrees bit-identicalpylint10.00/10 andpyink --pyink-indentation=2 --line-length=122clean on touched filesChecklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.This change was developed with assistance from Claude Code.