DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a DFlash draft-weight export callback, an FSDP2 buffer/DTensor monkey-patch (with DTensor-safe grad clipping), integrates both into the speculative-decoding training script (checkpoint-format detection and mask-token init), disables vLLM prefix caching for hidden-state dumps, and forwards Slurm requeue settings to the launcher executor. ChangesSpeculative Decoding Training Enhancements
Launcher Slurm Requeue Configuration
Package initialization
Benchmark spec
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/eagle_utils.py (1)
53-53: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick winUpdate
__all__to includeDFlashExportCallback.The coding guidelines require defining the public API with
__all__. SinceDFlashExportCallbackis imported bymain.py(line 39), it should be exported.-__all__ = ["EagleOfflineDataCollator", "OfflineSupervisedDataset"] +__all__ = ["DFlashExportCallback", "EagleOfflineDataCollator", "OfflineSupervisedDataset"]As per coding guidelines: "Define the public API with
__all__at the top of each Python module."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/eagle_utils.py` at line 53, The module's public API list __all__ is missing DFlashExportCallback; update the __all__ declaration (which currently lists "EagleOfflineDataCollator" and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the symbol is exported for consumers like main.py that import it.
🧹 Nitpick comments (5)
examples/speculative_decoding/fsdp2_buffer_patch.py (4)
238-240: ⚡ Quick winUse
print_rank_0to avoid noisy logs in multi-rank environments.These print statements execute on every rank, which can produce excessive output on large clusters. Consider using
print_rank_0frommodelopt.torch.utilsor guarding with a rank check.+from modelopt.torch.utils import print_rank_0 + # In apply() function: - print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") + print_rank_0("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") except Exception as e: - print(f"[fsdp2_buffer_patch] Patch skipped: {e}") + print_rank_0(f"[fsdp2_buffer_patch] Patch skipped: {e}")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 238 - 240, The two print calls in fsdp2_buffer_patch (the success message and the except message around fsdp2_load_full_state_dict) should be replaced with rank-safe logging: import and call print_rank_0 from modelopt.torch.utils (or guard with a rank check) so messages only appear on rank 0; update the success print and the exception print to use print_rank_0 and include the exception variable in the error message (e) while keeping the same text context.
1-3: 💤 Low valueAdd
__all__to define the public API.Per coding guidelines, each Python module should define
__all__to make the public API explicit.+__all__ = ["apply", "patch_accelerator"] + """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.As per coding guidelines: "Define the public API with
__all__at the top of each Python module."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 1 - 3, This module is missing an explicit public API; add a module-level __all__ declaration at the top (after the SPDX headers) listing the public names exported by this file (e.g. __all__ = ["Name1", "function_name", "CLASS_NAME"]), ensuring each symbol included matches the actual top-level functions/classes/variables defined later in the file; place the __all__ immediately below the license lines to satisfy the coding guideline.
322-326: ⚡ Quick winUse
print_rank_0here as well.def patch_accelerator(accelerator): """Replace accelerator's clip_grad_norm_ with FSDP2-safe version.""" accelerator.clip_grad_norm_ = _clip_grad_norm - print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " - "for FSDP2 DTensor compatibility") + from modelopt.torch.utils import print_rank_0 + print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " + "for FSDP2 DTensor compatibility")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 322 - 326, The patch_accelerator function currently uses print to log the patch; replace that call with print_rank_0 to follow logging guidelines. Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure print_rank_0 is imported at top of the module (the same utility used elsewhere), leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
269-270: 💤 Low valueReturn value should be on the same device as gradients.
When there are no gradients, the function returns a CPU tensor. For consistency with the non-empty case (which returns
total_normon device), consider returning on the same device.if len(grads) == 0: - return torch.tensor(0.0) + device = parameters[0].device if parameters else "cpu" + return torch.tensor(0.0, device=device)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270, The early-return creates a CPU tensor when grads is empty; change it to return a zero tensor on the same device as the gradients by selecting the first available grad device (e.g. device = next((g.device for g in grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0, device=device) so the empty-case matches the device of the non-empty case that returns total_norm.examples/speculative_decoding/main.py (1)
297-303: ⚡ Quick winConsider gating debug output or removing before merge.
This debug print executes on every rank and will produce verbose output on large clusters. If this is temporary debugging code, consider removing it or guarding with a debug flag.
- rank = int(os.environ.get("RANK", 0)) - dtypes = {} - for name, p in trainer.model.named_parameters(): - dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) - dtypes.setdefault(dt_key, []).append(name) - for dt, names in dtypes.items(): - print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})") + if os.environ.get("DEBUG_DTYPES"): + rank = int(os.environ.get("RANK", 0)) + dtypes = {} + for name, p in trainer.model.named_parameters(): + dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) + dtypes.setdefault(dt_key, []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/main.py` around lines 297 - 303, The debug loop printing per-rank dtype info (uses rank, dtypes, iterating trainer.model.named_parameters()) should be gated or replaced to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs (use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print to call print_rank_0 (or warn_rank_0) with the formatted message so only the main process emits the output, or conditionally execute the entire loop behind a debug configuration toggle.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tools/launcher/core.py`:
- Around line 280-287: The code assumes slurm_config.additional_parameters is a
mutable dict and mutates it directly, which can cause shared-state bugs; before
assigning to executor.additional_parameters (and before mutating it to set
"requeue"), validate and normalize slurm_config.additional_parameters to a plain
dict (e.g., treat None, mappings, or other types safely), create a shallow copy
for executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.
---
Outside diff comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Line 53: The module's public API list __all__ is missing DFlashExportCallback;
update the __all__ declaration (which currently lists "EagleOfflineDataCollator"
and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the
symbol is exported for consumers like main.py that import it.
---
Nitpick comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 238-240: The two print calls in fsdp2_buffer_patch (the success
message and the except message around fsdp2_load_full_state_dict) should be
replaced with rank-safe logging: import and call print_rank_0 from
modelopt.torch.utils (or guard with a rank check) so messages only appear on
rank 0; update the success print and the exception print to use print_rank_0 and
include the exception variable in the error message (e) while keeping the same
text context.
- Around line 1-3: This module is missing an explicit public API; add a
module-level __all__ declaration at the top (after the SPDX headers) listing the
public names exported by this file (e.g. __all__ = ["Name1", "function_name",
"CLASS_NAME"]), ensuring each symbol included matches the actual top-level
functions/classes/variables defined later in the file; place the __all__
immediately below the license lines to satisfy the coding guideline.
- Around line 322-326: The patch_accelerator function currently uses print to
log the patch; replace that call with print_rank_0 to follow logging guidelines.
Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched
accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure
print_rank_0 is imported at top of the module (the same utility used elsewhere),
leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
- Around line 269-270: The early-return creates a CPU tensor when grads is
empty; change it to return a zero tensor on the same device as the gradients by
selecting the first available grad device (e.g. device = next((g.device for g in
grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0,
device=device) so the empty-case matches the device of the non-empty case that
returns total_norm.
In `@examples/speculative_decoding/main.py`:
- Around line 297-303: The debug loop printing per-rank dtype info (uses rank,
dtypes, iterating trainer.model.named_parameters()) should be gated or replaced
to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs
(use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug
flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print
to call print_rank_0 (or warn_rank_0) with the formatted message so only the
main process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ec55dcce-a920-44ca-8e39-ee3167ca3eeb
📒 Files selected for processing (4)
examples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/fsdp2_buffer_patch.pyexamples/speculative_decoding/main.pytools/launcher/core.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1621 +/- ##
==========================================
+ Coverage 67.73% 76.21% +8.47%
==========================================
Files 511 511
Lines 56169 57193 +1024
==========================================
+ Hits 38044 43587 +5543
+ Misses 18125 13606 -4519
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
d2d0558 to
5496efc
Compare
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 269-270: The early-return creates a CPU tensor when no grads
exist, causing device mismatch with the normal path which returns total_norm on
the GPU (the variable device is set around line 274); update the early-return to
produce a tensor on the same device as total_norm by constructing the zero
tensor on the same device (e.g., using the device variable or
total_norm/new_tensor style) so the returned tensor's device matches the normal
path (refer to grads, total_norm, and device to locate the code).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b39e1775-7aa8-482b-84ba-391c5f4eaef7
📒 Files selected for processing (4)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.pyexamples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/fsdp2_buffer_patch.pyexamples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
There was a problem hiding this comment.
🧹 Nitpick comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py (1)
246-246: ⚡ Quick winAdd defensive length check before truncating loss_mask.
If
output_hidden_statesis longer thanloss_mask, Python slice semantics return the full (too-short)loss_mask, creating a mismatch with the savedinput_idslength. The downstreamOfflineSupervisedDatasetloader does not validate shape alignment, risking silent training errors. Withenable_prefix_caching=False(line 194), lengths should match exactly—consider adding a warning to catch violations:🛡️ Suggested defensive check
+ expected_len = output_hidden_states.shape[0] + if loss_mask.shape[0] != expected_len: + import warnings + warnings.warn( + f"Conversation {conv_id}: loss_mask length {loss_mask.shape[0]} != " + f"hidden_states length {expected_len}; may indicate tokenization mismatch." + ) + output_file = output_dir / f"{conv_id}.pt" with open(output_file, "wb") as f: torch.save( { "input_ids": token_ids.cpu(), "hidden_states": output_hidden_states, "aux_hidden_states": aux_hidden_states, - "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(), + "loss_mask": loss_mask[:expected_len].cpu(), "conversation_id": conv_id, }, f, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py` at line 246, Before truncating loss_mask for storage, ensure its length matches output_hidden_states.shape[0]; check if output_hidden_states.shape[0] > loss_mask.shape[0] and if so log a warning (or raise an error) calling out the mismatch between loss_mask and output_hidden_states lengths (mention enable_prefix_caching if relevant), otherwise perform the slice as before: "loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the variables loss_mask and output_hidden_states and the OfflineSupervisedDataset consumer when adding this defensive check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Line 246: Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 898bf778-31d1-4a22-a5f1-2672bcf3b208
📒 Files selected for processing (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
|
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Around line 47-73: The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b02075cf-1305-4259-b54f-425362cba18c
📒 Files selected for processing (2)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.pytools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
h-guo18
left a comment
There was a problem hiding this comment.
Can we add some tests that cover the new patches and callbacks?
|
Thanks for the review — addressed all comments. Summary of this round:
Both paths re-validated end-to-end: offline dump→train→export green, and the YaRN export config recovers 32k AL (1.17 → 2.62). |
|
Added in
The |
Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with ValueError during DFlash training. Instead of requiring the user to manually set dflash_mask_token_id, add the token to the tokenizer and resize model embeddings automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
When slurm_config.requeue is True, set additional_parameters["requeue"] = True so nemo-run emits #SBATCH --requeue in the sbatch script. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0 loads real weights on CPU; other ranks use meta device. FSDP2 distributes from rank 0. Also adds dp_replicate_size auto-computation so dp_replicate * dp_shard * cp == world_size. 2. core.py: Set retries=3 when requeue is requested. The nemo-run sbatch wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…heckpoint resume The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and patch_accelerator() calls and added a buffer-to-CUDA block that moved DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading, non-rank-0 processes have meta-device params, causing _infer_parameter_dtype() to return fp32 instead of bf16 on resume. Also detects FSDP distributed checkpoints (no HF model files) and loads the base model instead of trying from_pretrained on them. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
… patch _infer_parameter_dtype() reads the model's current param dtype to cast the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0 processes have fp32 meta-device params for DFlash, so _infer_parameter_dtype returns fp32 and _finish() casts the correctly- broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0) instead. Also prints dtype_check on all ranks to verify consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
The Pydantic-recipe refactor (7038dec) dropped DFlashExportCallback, which had exported the DFlash draft submodule after every checkpoint save. Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no model.safetensors) get no exported-checkpoint-{step}/, so downstream vLLM deployment / acceptance-length eval has nothing to load. The verify-only comment 'export happens during training via DFlashExportCallback' was left behind but the callback itself was gone. Restore the callback (gathers only the ~328MB draft submodule across shards via get_model_state_dict, so it works under SHARDED_STATE_DICT without materializing the full base model) and wire it into main.py for DFlash recipes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…g comment Per @hguo-nv: - rope_theta: the base-first choice is intentional for DFlash — the draft injects the target's KV into every layer, so its RoPE base must match the target's. The draft arch config carries no rope_theta, so draft-first fell back to the 1e6 default (a mismatch with the target's 5e6). Tightened the comment to state this. - Trimmed the verbose rope_scaling comment to the essentials. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…er extraction Addresses @hguo-nv's points 1-2 on the export callback: 1. The callback now self-skips when the checkpoint was saved as a full state dict (model.safetensors / pytorch_model.bin present) — i.e. it only does work under FSDP2 SHARDED_STATE_DICT, where checkpoints are distributed shards the post-run export can't read. The check is from the on-disk checkpoint format (identical across ranks, before the collective gather). main.py now appends it unconditionally for DFlash since it self-gates (DDP / single-device / FSDP2-full all self-skip), removing the coarse dp_shard_size/env heuristic. 2. Extraction now delegates to DFlashExporter._extract_state_dict (the same logic the normal export path uses) for the common prefixed-key layout, with a fallback for the already-stripped submodule-gather keys. Config generation already delegated to the exporter; this removes the remaining duplication. The gather (get_model_state_dict on the dflash_module submodule) — the load-bearing FSDP2 part that produced the release — is unchanged. The full split into separate reusable gather/export callbacks is deferred: offline DFlash uses DDP (full checkpoints, no callback needed) and eagle has its own export path, so there is no current consumer. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
… dtype diag Implements @hguo-nv's agreed asks on the DFlash export callback: - Rename DFlashExportCallback -> DFlashFSDP2ShardedSDExportCallback to make its applicable range (FSDP2 SHARDED_STATE_DICT only) explicit in the name. - Gate it in main.py by reading the live accelerator FSDP state dict type (trainer.accelerator.state.fsdp_plugin.state_dict_type) instead of the callback's on-disk checkpoint-filename heuristic; the callback is added only for SHARDED_STATE_DICT (full-state-dict runs use the launcher's post-run export). Removed the in-callback early-exit accordingly. - Move the per-rank dtype diagnostic out of main.py into fsdp2_buffer_patch.log_param_dtypes(), gated behind DFLASH_LOG_PARAM_DTYPES=1 (no-op by default), since it only exists to verify the FSDP2 dtype sync. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
|
/claude review |
Acceptance length (released MiniMax-M2.7 DFlash draft, step 20400)Measured with the
Draft length 7 is the recommended serving config (DFlash |
Addresses @hguo-nv: the modify() inheritance was a setdefault, so a rope_theta provided in dflash_architecture_config would NOT be overridden — the draft would train with that value while the exporter writes the base's rope_theta, causing a train/inference mismatch. DFlash injects the target's KV into every draft layer, so the draft's RoPE base must match the target's. Enforce rope_theta/rope_type/rope_interleaved from the base model (overwrite any user value, with a warning); keep the other architecture attrs as setdefault. This makes training and export agree by construction. rope_scaling stays excluded (added only at export via dflash_export_rope_scaling). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
… check Per @hguo-nv: tidy the inline checkpoint-format probe in main.py into a named helper. Kept it distinct from the export-callback's save-mode gate on purpose — this inspects the *resume* checkpoint's on-disk format (a property of existing bytes, checked pre-trainer), whereas the gate reads the *current run's* fsdp_state_dict_type (post-trainer); the two answer different questions and can differ across runs, so they shouldn't share one FSDP2+ShardedSD condition. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
There was a problem hiding this comment.
Claude Code Review
Thorough pass on the DFlash / FSDP2 / mask-token / per-checkpoint-export PR.
The motivation and architecture are solid — gated patches, well-documented
applicability, and the export-callback / config-field convergence with eagle
is the right shape. Below are the issues I found that I think the author
should look at before merging.
Findings by severity
- CRITICAL: 1
- IMPORTANT: 6
- SUGGESTION: 3
Highlights
CRITICAL — _clip_grad_norm deadlock when not all ranks see grads
fsdp2_buffer_patch.py:312-344 returns early on ranks with no grads but
unconditionally issues dist.all_reduce(sharded_norm_p) on the others.
First time a rank hits an empty-grad step (sharded MoE expert with no
tokens, etc.), the cluster hangs. Fix: every rank must reach the same
collective, or none.
IMPORTANT — silent fallback paths in the export callback (eagle_utils.py:269-291)
- The
TypeError-fallback branch reverts to a full-model gather of the
229B base, completely defeating the submodule-only design. Should fail
loudly (or at minimum warn) on unsupported PyTorch versions. - The fallback key heuristic differs from
DFlashExporter._extract_state_dict
and bypasses_check_valid_sd, so a malformed export ships silently and
only fails at vLLM load time.
IMPORTANT — _DTYPE_TO_CODE.get(..., 0) silently coerces unknown dtypes
to fp32 (fsdp2_buffer_patch.py:178). Any dtype outside the four-entry
map (int8 / fp64 / bool / fp8_e5m2 / …) becomes fp32 on rank 0 and the
broadcast either casts data on the wire or NCCL refuses on size mismatch.
Raise on unknown dtype instead of silently coercing.
IMPORTANT — standalone aux-layer resolver hard-codes num_draft=5
(compute_hidden_states_vllm.py:62). Today the recipes match, but a
deeper draft + --aux-layers dflash silently produces mis-aligned dumps
and downstream AR regressions. Make the depth a CLI arg.
IMPORTANT — rope_theta or-fallback (hf_spec_export.py:382-383)
swallows 0 / 0.0. Use is not None for parity with the eagle path.
IMPORTANT — /dev/shm staging (compute_hidden_states_vllm.py:224)
without cleanup leaks RAM-backed tmpfs across crashes; on some HPC
containers /dev/shm is unmapped or undersized. Make it overridable via
env var and clean up on exit.
IMPORTANT — HF-format checkpoint detection (main.py:189-194) does
not appear to re-apply mtsp.convert on resume of a previously-saved
modelopt HF dir; needs verification that enable_huggingface_checkpointing
covers the DFlash plugin.
SUGGESTIONs — duplicate doc row, dead _orig capture,
exported-checkpoint-final mismatch in specdec_bench.yaml (callback
writes per-step dirs, not -final).
Risk assessment
The core algorithmic and export logic look right; the FSDP2 buffer/dtype
broadcast patch is reasonable for the narrow scope it covers, and the
callback gating refactor (only under SHARDED_STATE_DICT) is the right
shape. The crit risk is concentrated in two places: (1) the clip_grad_norm
collective imbalance, (2) the silent-fallback paths in the export callback
and dtype-encoder. Either could turn into a hard-to-debug production
failure long after merge.
Recommend addressing the CRITICAL and the four "silent fallback" IMPORTANT
items before merging; the rest are good follow-ups.
| try: | ||
| raw_sd = get_model_state_dict( | ||
| model, submodules={model.dflash_module}, options=options | ||
| ) | ||
| except TypeError: | ||
| # Older PyTorch without submodules parameter — gather full model | ||
| raw_sd = get_model_state_dict(model, options=options) | ||
| except ImportError: | ||
| # Non-distributed / single-GPU fallback | ||
| raw_sd = model.state_dict() | ||
|
|
||
| # Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary | ||
| # buffers) for the common full-model key layout. Some PyTorch versions return the | ||
| # submodule gather with keys already stripped of the prefix — handle that directly. | ||
| exporter = model.get_exporter() | ||
| drafter_sd = exporter._extract_state_dict(raw_sd) | ||
| if not drafter_sd: | ||
| drafter_sd = { | ||
| k: v | ||
| for k, v in raw_sd.items() |
There was a problem hiding this comment.
[IMPORTANT Performance/ModeState] The silent full-model fallback when submodules= is not supported is dangerous for the very target this callback exists to support.
The whole point of this callback (per the docstring and PR description) is to gather only the ~328 MB DFlash submodule under FSDP2 SHARDED_STATE_DICT, without materializing the 229B base. But this fallback path:
except TypeError:
# Older PyTorch without submodules parameter — gather full model
raw_sd = get_model_state_dict(model, options=options)silently calls get_model_state_dict(model, options=options) (no submodule filter, full_state_dict=True, cpu_offload=True), which gathers the entire 229B base on every save. That is many minutes per checkpoint of all-gather plus ~230 GB of CPU memory pressure per node — exactly what the submodule gather was added to avoid. On older PyTorch it wouldn't fail loudly; the user would just see saves get mysteriously expensive (or OOM) without understanding why.
Suggestion: either (a) raise a clear error pointing the user at a minimum supported PyTorch version, or (b) at minimum print_rank_0 a loud warning naming the model size and that this path will gather the full base model. The current silent fallback masks a tail-latency / OOM failure mode in the only scenario that matters.
| if not drafter_sd: | ||
| drafter_sd = { | ||
| k: v | ||
| for k, v in raw_sd.items() | ||
| if "rotary_emb" not in k | ||
| and not any(p in k for p in ("model.", "lm_head.", "embed_tokens.")) | ||
| } | ||
| del raw_sd | ||
| # Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the | ||
| # single-GPU fallback may return CUDA tensors). | ||
| drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()} |
There was a problem hiding this comment.
[IMPORTANT Algorithm/Export] The fallback key heuristic doesn't match what DFlashExporter._extract_state_dict actually returns — it will produce an empty export silently.
DFlashExporter._extract_state_dict (modelopt/torch/export/plugins/hf_spec_export.py:322-332) keeps only entries that contain "dflash_module." and strips that prefix. If get_model_state_dict(..., submodules={model.dflash_module}, full_state_dict=True) returns keys already pre-stripped (without the dflash_module. prefix), then exporter._extract_state_dict(raw_sd) returns {} — that's why this fallback exists.
But the heuristic here:
drafter_sd = {
k: v
for k, v in raw_sd.items()
if "rotary_emb" not in k
and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}filters out anything containing "model.", which kills layer/decoder weights named like layers.0.self_attn.q_proj.weight (no model. substring) — fine — but it also kills any key that happens to legitimately contain "model." if a future DFlash module nests differently. More importantly, the substring tests "lm_head." and "embed_tokens." will NEVER match in the pre-stripped DFlash submodule case (the DFlash draft has neither). The exporter's _check_valid_sd is also bypassed in this branch — so a malformed state dict (e.g. weights from the wrong module after a refactor) would write a broken model.safetensors and only fail at vLLM load time.
Recommendation: instead of the substring heuristic, detect the pre-stripped case explicitly and either (a) re-prefix the keys with "dflash_module." and rerun exporter._extract_state_dict (so the same validation path runs), or (b) at minimum, assert the resulting key set looks sane (e.g. contains fc.weight, norm.weight, layers.0.self_attn.q_proj.weight) before writing. Right now a key-layout regression would pass through silently.
| grads = [p.grad for p in parameters if p.grad is not None] | ||
| if len(grads) == 0: | ||
| # Match the device of the normal return path (GPU when training on CUDA) so | ||
| # callers don't hit a device mismatch on the empty-grad case. | ||
| dev = "cuda" if torch.cuda.is_available() else "cpu" | ||
| return torch.tensor(0.0, device=dev) | ||
|
|
||
| # Shard DTensors hold partial data — need all_reduce for global norm. | ||
| # Replicate DTensors and regular tensors already hold full data. | ||
| device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device | ||
| sharded_norm_p = torch.tensor(0.0, device=device) | ||
| local_norm_p = torch.tensor(0.0, device=device) | ||
|
|
||
| n_sharded = 0 | ||
| n_replicate = 0 | ||
| n_regular = 0 | ||
|
|
||
| for g in grads: | ||
| if isinstance(g, DTensor): | ||
| is_sharded = any(isinstance(p, Shard) for p in g.placements) | ||
| t = g._local_tensor.detach().to(torch.float32) | ||
| n = torch.linalg.vector_norm(t, norm_type) | ||
| if is_sharded: | ||
| sharded_norm_p += n.pow(norm_type) | ||
| n_sharded += 1 | ||
| else: | ||
| local_norm_p += n.pow(norm_type) | ||
| n_replicate += 1 | ||
| else: | ||
| n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type) | ||
| local_norm_p += n.pow(norm_type) | ||
| n_regular += 1 | ||
|
|
||
| dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM) |
There was a problem hiding this comment.
[CRITICAL Algorithm] Unconditional dist.all_reduce(sharded_norm_p) deadlocks when not all ranks see grads.
grads = [p.grad for p in parameters if p.grad is not None] may legitimately be empty on some ranks (e.g. an FSDP2-sharded MoE where a rank's expert gets no tokens this micro-batch, or a pipeline-parallel-style setup, or simply early steps where some ranks' params haven't yet accumulated grads). The function returns early with torch.tensor(0.0) on those ranks (line 312-316), but the other ranks fall through and call dist.all_reduce(sharded_norm_p) (line 344). Since the empty-grad rank does NOT participate in the collective, the all_reduce hangs forever. The same hazard exists if every grad on a rank is a Replicate DTensor / regular tensor (so n_sharded == 0) while another rank has shards — the reduce is still issued, but on a tensor that wasn't summed against any sharded contribution; that is benign for the math but still requires every rank to reach the call, which the early return breaks.
Two issues to fix:
- The early-return must not skip the collective. Either every rank participates in the all_reduce (initialize
sharded_norm_pto zero before the early-return and unconditionally reduce), or none do. - Document/assert that
clip_grad_norm_is only ever called on aparametersset that is identical across ranks (which is the only case where "all ranks have at least one grad" is guaranteed). If that's the design contract, please at least add an assertion + comment so a future caller doesn't introduce uneven grads and silently deadlock.
The current shape — early-return on one rank, collective on the others — is a deadlock waiting for the first edge case.
| if accelerator.is_main_process: | ||
| dtype_codes = torch.tensor( | ||
| [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd], | ||
| dtype=torch.int32, | ||
| device=accelerator.device, | ||
| ) | ||
| else: | ||
| dtype_codes = torch.empty( | ||
| n_total, | ||
| dtype=torch.int32, | ||
| device=accelerator.device, | ||
| ) | ||
| dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD) | ||
| broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes] |
There was a problem hiding this comment.
[IMPORTANT Algorithm] Silent fallback _DTYPE_TO_CODE.get(..., 0) + _CODE_TO_DTYPE[c.item()] corrupts unsupported-dtype entries instead of failing loudly.
On rank 0 you encode each full_sd[name].dtype to a code via .get(dtype, 0) — meaning any dtype that isn't in the four-entry map (fp32 / bf16 / fp16 / fp8_e4m3fn) silently becomes 0 (= float32). Other ranks then build a fp32 receive buffer and dist.broadcast casts the rank-0 tensor to fp32 on the wire (or worse, NCCL refuses if the byte size doesn't match). This silently corrupts:
int8/uint8checkpoints (e.g. some quantized loads)float8_e5m2float64/int64(rare for weights but valid for some buffers like rotaryinv_freqcached as fp64 by some HF model code)- bool (some attention masks)
Because the comment explicitly says this path supports fp8 weights, a bf16 vs fp8 mis-encoding would silently broadcast garbage across the cluster and you'd only notice from a downstream NaN/AR collapse — exactly the kind of ghost-bug this patch is meant to prevent. Rank-0 should raise ValueError(f"Unsupported broadcast dtype {dtype} for {name}") rather than coerce to a default.
Bonus: the int32 broadcast carries n_total codes, but the _CODE_TO_DTYPE decode loop treats c.item() as authoritative — so if a future torch version re-orders the fp8 enum to a different code, the if hasattr(...) else -1 initializer here makes that change a silent encoding flip rather than an error. Pin codes by name, not by attribute existence.
| import accelerate.utils.fsdp_utils as fsdp_utils | ||
| from torch.distributed.tensor import DTensor | ||
|
|
||
| _orig = fsdp_utils.fsdp2_load_full_state_dict |
There was a problem hiding this comment.
[SUGGESTION] _orig is captured but never invoked — dead variable.
The original fsdp2_load_full_state_dict is captured here for an apparent rollback / wrapper pattern, but _patched never delegates to it. Either drop the capture (it confused me on first read into thinking there was a fallback path) or document why it exists (e.g. for an unapply() you plan to add). Trivial cleanup.
| | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | ||
| | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | | ||
| | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | |
There was a problem hiding this comment.
[SUGGESTION] Duplicate dflash_architecture_config.num_hidden_layers row.
| | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | |
| | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | | |
| | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | | |
| | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | |
| | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | |
The current diff inserts the new dflash_mask_token_id row but keeps the old dflash_architecture_config.num_hidden_layers line and ALSO repeats it on line 168 — the rendered table will have num_hidden_layers twice in a row. Drop the duplicate (the line below this one).
| if spec == "dflash": | ||
| num_draft = 5 | ||
| if num_draft == 1: | ||
| return [num_hidden_layers // 2] | ||
| start = min(1, num_hidden_layers - 1) | ||
| end = max(start, num_hidden_layers - 3) | ||
| span = end - start | ||
| return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) |
There was a problem hiding this comment.
[IMPORTANT Algorithm] The standalone resolver hard-codes num_draft=5 while the recipe value is configurable — a mismatch silently mis-aligns dumped aux layers.
build_target_layer_ids(num_target_layers, num_draft_layers) (modeling_dflash.py:58-69) is parameterized by num_draft_layers. The MiniMax recipes here pin dflash.dflash_architecture_config.num_hidden_layers=5, so today the constant matches — but if anyone runs offline DFlash with a different draft depth (4, 6, 8 etc.) and forgets to switch from --aux-layers dflash to an explicit comma-list, the dump silently captures a different set of target layers than the trained draft will actually consume at training time, since target_layer_ids is also derived from num_draft_layers via the same function in hf_dflash.py:169. The training run will load mis-aligned aux features and learn garbage; the failure mode is "AR mysteriously regresses" rather than a loud crash.
Two fixes either of which would close the gap:
- Take the draft depth as a CLI flag (
--dflash-num-draft 5) and pass it through. - At least raise a loud error / log if the recipe-default
5does not match the model's actualnum_hidden_layersdivisibility expectations, and document the constant-coupling explicitly in the docstring.
The TODO already acknowledges this is fragile; the comment-coupling to a hard-coded 5 is the actual breakage vector. Suggest making the value a required CLI arg and removing the default.
| # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre) | ||
| # output dir: the producer writes one file per request and the client reads it back | ||
| # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so | ||
| # parallel shards don't collide. | ||
| storage_path = Path("/dev/shm") / f"vllm_hidden_states_dp{args.dp_rank}" |
There was a problem hiding this comment.
[IMPORTANT Performance] Hard-coding /dev/shm and skipping cleanup is a footgun on shared/multi-tenant nodes.
Two issues with this hardcoded path:
-
No cleanup on failure / exit. A run that crashes after writing K safetensors leaves them in
/dev/shm(which is RAM-backed on Linux) and they persist until the node is rebooted or another job evicts them. With per-DP-rank dirs and ~MiniMax-M2.7-shape hidden states (vocab × hidden × seq_len), this can accumulate to tens or hundreds of GB of stranded RAM. Suggest wrapping withtempfile.TemporaryDirectory()(also pickable per DP rank) or at minimum registeratexit.registerto removestorage_path. -
/dev/shmmay not exist or may be too small. Some HPC scheduler containers either don't bind-mount/dev/shm(it's a tmpfs the orchestration must wire up) or cap it tightly (Slurm default is sometimes 64 MB). For MiniMax-M2.7 (229B target) and 4096-token sequences, per-conversation tensor sizes are non-trivial;/dev/shmoverflow yields a confusingOSError: [Errno 28] No space left on deviceat write time, mid-job. Worth at least making the staging path overridable with an env var (VLLM_HIDDEN_STATES_TMPDIR) and falling back throughtempfile.gettempdir().
Neither blocks the merge but both will bite production. The cleanup omission is the more serious of the two.
| # Inherit the target's rope_theta: DFlash injects the target's KV into every | ||
| # draft layer, so the draft's RoPE base must match the target's. (The draft | ||
| # arch config carries no rope_theta of its own.) | ||
| "rope_theta": getattr(base_config, "rope_theta", None) | ||
| or getattr(draft_config, "rope_theta", 1000000.0), |
There was a problem hiding this comment.
[IMPORTANT Compatibility] The or short-circuit silently swallows rope_theta=0 (and any other falsy numeric value).
"rope_theta": getattr(base_config, "rope_theta", None)
or getattr(draft_config, "rope_theta", 1000000.0),If base_config.rope_theta were ever 0 or 0.0 (legal for some models that disable RoPE entirely, though uncommon for production targets), or falls through to the draft default 1000000.0. The intent (per the comment) is "use the target's value if set, otherwise fall back" — which means the test should be is not None, not truthy:
| # Inherit the target's rope_theta: DFlash injects the target's KV into every | |
| # draft layer, so the draft's RoPE base must match the target's. (The draft | |
| # arch config carries no rope_theta of its own.) | |
| "rope_theta": getattr(base_config, "rope_theta", None) | |
| or getattr(draft_config, "rope_theta", 1000000.0), | |
| "rope_theta": ( | |
| getattr(base_config, "rope_theta", None) | |
| if getattr(base_config, "rope_theta", None) is not None | |
| else getattr(draft_config, "rope_theta", 1000000.0) | |
| ), |
Realistically rope_theta=0 is rare so this is more theoretical than acute, but the existing eagle path uses an explicit is None check (line 205-206) for exactly this reason — worth being consistent.
| - --tp_size 4 | ||
| - --ep_size 4 |
There was a problem hiding this comment.
[SUGGESTION] exported-checkpoint-final doesn't match the callback's actual output naming.
DFlashFSDP2ShardedSDExportCallback.on_save writes to exported-checkpoint-{step} (eagle_utils.py:258), where step = state.global_step — never the literal string final. So this YAML's --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final will fail to find model.safetensors unless the user manually renames or symlinks the last per-save export directory. Same on line 67. Either the callback should additionally write/symlink an exported-checkpoint-final/ after on_train_end, or this YAML should use a placeholder pattern (e.g. exported-checkpoint-${STEP} with the user expected to fill it in) and call out in a comment that the user must pick the desired step. As-written this is a copy-paste-ready bench config that won't work out of the box.
Addresses the Claude review's CRITICAL + silent-fallback findings: - CRITICAL: _clip_grad_norm no longer early-returns on an empty-grad rank before the collective — under FSDP2 sharding / MoE+LoRA co-train a rank can legitimately have no grads (e.g. an expert with no tokens), and returning early while other ranks all_reduce deadlocks the job. Every rank now reaches the (guarded) all_reduce; empty-grad ranks contribute a zero norm and clip nothing. - fsdp2_buffer_patch dtype-sync: raise on an unmapped dtype instead of silently coercing to fp32 (which would cast on the wire or trip an NCCL size mismatch). - DFlash export callback: warn loudly on the two fallback paths — the full-model gather when get_model_state_dict lacks submodules=, and the denylist heuristic when the prefix-based extraction finds nothing — so a slow/OOM gather or a malformed export isn't silent. - hf_spec_export rope_theta: use 'is not None' instead of 'or' so a base rope_theta of 0 isn't swallowed (parity with the eagle path). - Cleanups: drop the dead _orig capture; remove the duplicated num_hidden_layers doc row. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…g in offline dump Addresses the remaining offline-dump-path findings (relevant to the upcoming M3 offline work): - compute_hidden_states_vllm.py: the standalone 'dflash' aux-layer resolver no longer hard-codes num_draft=5; added --num-draft-layers (default 5) so the dumped aux layers match the recipe's dflash_architecture_config.num_hidden_layers. A mismatch would silently mis-align the dump with what the draft consumes at training time. - compute_hidden_states_vllm.py: the connector staging dir is now overridable via DFLASH_HS_STAGING_DIR (default /dev/shm) for containers where /dev/shm is unmapped or undersized, and is removed on exit (atexit) so a crash doesn't strand RAM-backed files. - specdec_bench.yaml: clarified that exported-checkpoint-final comes from the launcher's post-run export, while the per-save callback writes exported-checkpoint-<step>. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Claude review — resolutions (commits
|
…lambda
mypy couldn't infer the cleanup lambda's type ('Cannot infer type of lambda'). atexit
forwards *args/**kwargs to the callback, so register shutil.rmtree directly with its args.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
What
Brings up DFlash block-diffusion speculative decoding for large MoE targets (MiniMax-M2.7, 229B) trained under accelerate FSDP2, and fixes the regressions that broke checkpoint resume and per-checkpoint draft export.
Commits
build_slurm_executor+ FSDP2 cpu_ram_efficient_loading for 229B on multi-node.fsdp2_buffer_patch.py): handle non-DTensor buffers infsdp2_load_full_state_dict, broadcast dtype codes from rank 0, and an FSDP2-safeclip_grad_norm_. Required because MiniMax-M2.7 pins transformers 4.57.x (no nativeParallelismConfig).DFlashExportCallback(this PR's headline): the Pydantic-recipe refactor (7038dec) dropped the callback that exported the draft submodule after each checkpoint save, leaving a stale "export happens during training via DFlashExportCallback" comment with no callback. FSDP2 SHARDED_STATE_DICT checkpoints carry nomodel.safetensors, so without it there is nothing for vLLM / acceptance-length eval to load. The callback gathers only the ~328 MB draft submodule across shards viaget_model_state_dict(..., submodules={dflash_module}, full_state_dict=True, cpu_offload=True)— works under SHARDED_STATE_DICT without materializing the 229B base — and writesexported-checkpoint-{step}/.Testing
🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Chores