Skip to content

DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621

Merged
yeyu-nvidia merged 32 commits into
mainfrom
yeyu/dflash-auto-mask-token
Jun 15, 2026
Merged

DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621
yeyu-nvidia merged 32 commits into
mainfrom
yeyu/dflash-auto-mask-token

Conversation

@yeyu-nvidia

@yeyu-nvidia yeyu-nvidia commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

What

Brings up DFlash block-diffusion speculative decoding for large MoE targets (MiniMax-M2.7, 229B) trained under accelerate FSDP2, and fixes the regressions that broke checkpoint resume and per-checkpoint draft export.

Commits

  • auto-add mask token for DFlash when the tokenizer lacks one (resize embeddings, restore dtype).
  • requeue support in build_slurm_executor + FSDP2 cpu_ram_efficient_loading for 229B on multi-node.
  • FSDP2 buffer patch (fsdp2_buffer_patch.py): handle non-DTensor buffers in fsdp2_load_full_state_dict, broadcast dtype codes from rank 0, and an FSDP2-safe clip_grad_norm_. Required because MiniMax-M2.7 pins transformers 4.57.x (no native ParallelismConfig).
  • dtype fix: use the broadcast dtype (rank 0) rather than the local meta-device param dtype, so non-leader ranks don't cast bf16 back to fp32 on resume.
  • restore DFlashExportCallback (this PR's headline): the Pydantic-recipe refactor (7038dec) dropped the callback that exported the draft submodule after each checkpoint save, leaving a stale "export happens during training via DFlashExportCallback" comment with no callback. FSDP2 SHARDED_STATE_DICT checkpoints carry no model.safetensors, so without it there is nothing for vLLM / acceptance-length eval to load. The callback gathers only the ~328 MB draft submodule across shards via get_model_state_dict(..., submodules={dflash_module}, full_state_dict=True, cpu_offload=True) — works under SHARDED_STATE_DICT without materializing the 229B base — and writes exported-checkpoint-{step}/.

Testing

  • Resume from FSDP2 sharded checkpoints verified end-to-end (loss/AR continuity).
  • Draft export validated against vLLM: exported drafts load and produce acceptance-length metrics on MT-Bench across the full checkpoint sweep.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Export draft-submodule weights to dedicated exported checkpoints during training.
    • FSDP2 buffer compatibility and DTensor-aware gradient clipping for safer distributed loading/training.
    • Detect HF-format checkpoints for smarter resume/load behavior.
    • Auto-add and handle a mask special token for draft workflows.
    • vLLM: disable prefix caching to preserve full prompt hidden states.
    • Add CLI option for answer-only loss and save aligned loss masks.
    • Add a SPEED-Bench config for DFLASH/vLLM benchmarking.
  • Chores

    • Robust package version fallback to avoid import failures.

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner June 3, 2026 18:42
@yeyu-nvidia yeyu-nvidia requested a review from h-guo18 June 3, 2026 18:42
@coderabbitai

coderabbitai Bot commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a DFlash draft-weight export callback, an FSDP2 buffer/DTensor monkey-patch (with DTensor-safe grad clipping), integrates both into the speculative-decoding training script (checkpoint-format detection and mask-token init), disables vLLM prefix caching for hidden-state dumps, and forwards Slurm requeue settings to the launcher executor.

Changes

Speculative Decoding Training Enhancements

Layer / File(s) Summary
DFlash export callback
examples/speculative_decoding/eagle_utils.py
DFlashExportCallback gathers model state (distributed-aware with fallbacks), filters dflash_module.* (excludes rotary_emb), and writes model.safetensors and config.json on master rank; skips empty exports and logs failures.
FSDP2 buffer & DTensor patching
examples/speculative_decoding/fsdp2_buffer_patch.py
Adds monkey-patch to accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict to synchronize dtypes, broadcast non-DTensor buffers from rank 0, reconstruct DTensors via distribute_tensor(), and provides DTensor-aware _clip_grad_norm plus patch_accelerator.
Main script integration and checkpoint/token handling
examples/speculative_decoding/main.py
Imports/conditionally applies fsdp2_buffer_patch, detects HF-formatted checkpoints vs. sharded checkpoints and falls back to base-model loading, ensures dflash_mask_token_id exists (derives or adds `<
vLLM hidden-state config and loss-mask capture
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Adds --answer-only-loss arg, allows chat-template overrides/cleanup, tokenizes with aligned loss_mask via tokenize_with_loss_mask, disables enable_prefix_caching for full-prompt hidden states, and includes loss_mask in saved .pt aligned to hidden-state length.`

Launcher Slurm Requeue Configuration

Layer / File(s) Summary
Slurm requeue and retries
tools/launcher/core.py
build_slurm_executor forwards additional_parameters to run.SlurmExecutor, sets requeue=True and ensures retries>=3 when slurm_config.requeue is enabled.

Package initialization

Layer / File(s) Summary
modelopt version fallback
modelopt/__init__.py
Wraps importlib.metadata.version in try/except PackageNotFoundError, falling back to __version__ = "0.0.0+unknown" when distribution metadata is absent.

Benchmark spec

Layer / File(s) Summary
MiniMax DFlash SPEED-Bench spec
tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
Adds YAML configuration for running MiniMax-M2.7 DFLASH benchmarks with vLLM, two tasks, and SLURM/container settings pointing to the exported draft checkpoint.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • ChenhanYu
  • shengliangxu
  • kevalmorabia97
  • h-guo18
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 53.85% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and comprehensively summarizes the three main changes: auto mask-token support, FSDP2 resume fixes, and per-checkpoint draft export, all in context of DFlash speculative decoding for MiniMax-M2.7.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found: no unsafe torch.load/numpy.load, trust_remote_code is configurable (not hardcoded), no eval/exec of untrusted input, and no new unsafe dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch yeyu/dflash-auto-mask-token

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/eagle_utils.py (1)

53-53: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Update __all__ to include DFlashExportCallback.

The coding guidelines require defining the public API with __all__. Since DFlashExportCallback is imported by main.py (line 39), it should be exported.

-__all__ = ["EagleOfflineDataCollator", "OfflineSupervisedDataset"]
+__all__ = ["DFlashExportCallback", "EagleOfflineDataCollator", "OfflineSupervisedDataset"]

As per coding guidelines: "Define the public API with __all__ at the top of each Python module."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/eagle_utils.py` at line 53, The module's public
API list __all__ is missing DFlashExportCallback; update the __all__ declaration
(which currently lists "EagleOfflineDataCollator" and
"OfflineSupervisedDataset") to also include "DFlashExportCallback" so the symbol
is exported for consumers like main.py that import it.
🧹 Nitpick comments (5)
examples/speculative_decoding/fsdp2_buffer_patch.py (4)

238-240: ⚡ Quick win

Use print_rank_0 to avoid noisy logs in multi-rank environments.

These print statements execute on every rank, which can produce excessive output on large clusters. Consider using print_rank_0 from modelopt.torch.utils or guarding with a rank check.

+from modelopt.torch.utils import print_rank_0
+
 # In apply() function:
-        print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
+        print_rank_0("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
     except Exception as e:
-        print(f"[fsdp2_buffer_patch] Patch skipped: {e}")
+        print_rank_0(f"[fsdp2_buffer_patch] Patch skipped: {e}")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 238 - 240,
The two print calls in fsdp2_buffer_patch (the success message and the except
message around fsdp2_load_full_state_dict) should be replaced with rank-safe
logging: import and call print_rank_0 from modelopt.torch.utils (or guard with a
rank check) so messages only appear on rank 0; update the success print and the
exception print to use print_rank_0 and include the exception variable in the
error message (e) while keeping the same text context.

1-3: 💤 Low value

Add __all__ to define the public API.

Per coding guidelines, each Python module should define __all__ to make the public API explicit.

+__all__ = ["apply", "patch_accelerator"]
+
 """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.

As per coding guidelines: "Define the public API with __all__ at the top of each Python module."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 1 - 3, This
module is missing an explicit public API; add a module-level __all__ declaration
at the top (after the SPDX headers) listing the public names exported by this
file (e.g. __all__ = ["Name1", "function_name", "CLASS_NAME"]), ensuring each
symbol included matches the actual top-level functions/classes/variables defined
later in the file; place the __all__ immediately below the license lines to
satisfy the coding guideline.

322-326: ⚡ Quick win

Use print_rank_0 here as well.

 def patch_accelerator(accelerator):
     """Replace accelerator's clip_grad_norm_ with FSDP2-safe version."""
     accelerator.clip_grad_norm_ = _clip_grad_norm
-    print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
-          "for FSDP2 DTensor compatibility")
+    from modelopt.torch.utils import print_rank_0
+    print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
+                 "for FSDP2 DTensor compatibility")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 322 - 326,
The patch_accelerator function currently uses print to log the patch; replace
that call with print_rank_0 to follow logging guidelines. Update the function to
call print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for
FSDP2 DTensor compatibility") and ensure print_rank_0 is imported at top of the
module (the same utility used elsewhere), leaving accelerator.clip_grad_norm_ =
_clip_grad_norm unchanged.

269-270: 💤 Low value

Return value should be on the same device as gradients.

When there are no gradients, the function returns a CPU tensor. For consistency with the non-empty case (which returns total_norm on device), consider returning on the same device.

     if len(grads) == 0:
-        return torch.tensor(0.0)
+        device = parameters[0].device if parameters else "cpu"
+        return torch.tensor(0.0, device=device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270,
The early-return creates a CPU tensor when grads is empty; change it to return a
zero tensor on the same device as the gradients by selecting the first available
grad device (e.g. device = next((g.device for g in grads if g is not None),
torch.device('cpu'))) and then return torch.tensor(0.0, device=device) so the
empty-case matches the device of the non-empty case that returns total_norm.
examples/speculative_decoding/main.py (1)

297-303: ⚡ Quick win

Consider gating debug output or removing before merge.

This debug print executes on every rank and will produce verbose output on large clusters. If this is temporary debugging code, consider removing it or guarding with a debug flag.

-    rank = int(os.environ.get("RANK", 0))
-    dtypes = {}
-    for name, p in trainer.model.named_parameters():
-        dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
-        dtypes.setdefault(dt_key, []).append(name)
-    for dt, names in dtypes.items():
-        print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")
+    if os.environ.get("DEBUG_DTYPES"):
+        rank = int(os.environ.get("RANK", 0))
+        dtypes = {}
+        for name, p in trainer.model.named_parameters():
+            dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
+            dtypes.setdefault(dt_key, []).append(name)
+        for dt, names in dtypes.items():
+            print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/main.py` around lines 297 - 303, The debug loop
printing per-rank dtype info (uses rank, dtypes, iterating
trainer.model.named_parameters()) should be gated or replaced to avoid noisy
logs: either remove the prints or wrap them so only rank 0 logs (use existing
print_rank_0 or warn_rank_0 utility) and/or guard with a debug flag (e.g., if
DEBUG:). Update the block that builds dtypes and the final print to call
print_rank_0 (or warn_rank_0) with the formatted message so only the main
process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tools/launcher/core.py`:
- Around line 280-287: The code assumes slurm_config.additional_parameters is a
mutable dict and mutates it directly, which can cause shared-state bugs; before
assigning to executor.additional_parameters (and before mutating it to set
"requeue"), validate and normalize slurm_config.additional_parameters to a plain
dict (e.g., treat None, mappings, or other types safely), create a shallow copy
for executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.

---

Outside diff comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Line 53: The module's public API list __all__ is missing DFlashExportCallback;
update the __all__ declaration (which currently lists "EagleOfflineDataCollator"
and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the
symbol is exported for consumers like main.py that import it.

---

Nitpick comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 238-240: The two print calls in fsdp2_buffer_patch (the success
message and the except message around fsdp2_load_full_state_dict) should be
replaced with rank-safe logging: import and call print_rank_0 from
modelopt.torch.utils (or guard with a rank check) so messages only appear on
rank 0; update the success print and the exception print to use print_rank_0 and
include the exception variable in the error message (e) while keeping the same
text context.
- Around line 1-3: This module is missing an explicit public API; add a
module-level __all__ declaration at the top (after the SPDX headers) listing the
public names exported by this file (e.g. __all__ = ["Name1", "function_name",
"CLASS_NAME"]), ensuring each symbol included matches the actual top-level
functions/classes/variables defined later in the file; place the __all__
immediately below the license lines to satisfy the coding guideline.
- Around line 322-326: The patch_accelerator function currently uses print to
log the patch; replace that call with print_rank_0 to follow logging guidelines.
Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched
accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure
print_rank_0 is imported at top of the module (the same utility used elsewhere),
leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
- Around line 269-270: The early-return creates a CPU tensor when grads is
empty; change it to return a zero tensor on the same device as the gradients by
selecting the first available grad device (e.g. device = next((g.device for g in
grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0,
device=device) so the empty-case matches the device of the non-empty case that
returns total_norm.

In `@examples/speculative_decoding/main.py`:
- Around line 297-303: The debug loop printing per-rank dtype info (uses rank,
dtypes, iterating trainer.model.named_parameters()) should be gated or replaced
to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs
(use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug
flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print
to call print_rank_0 (or warn_rank_0) with the formatted message so only the
main process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ec55dcce-a920-44ca-8e39-ee3167ca3eeb

📥 Commits

Reviewing files that changed from the base of the PR and between 88fd7ff and d2d0558.

📒 Files selected for processing (4)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/fsdp2_buffer_patch.py
  • examples/speculative_decoding/main.py
  • tools/launcher/core.py

Comment thread tools/launcher/core.py Outdated
@codecov

codecov Bot commented Jun 3, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 66.66667% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.21%. Comparing base (46eddab) to head (69ad872).
⚠️ Report is 28 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/plugins/hf_dflash.py 50.00% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1621      +/-   ##
==========================================
+ Coverage   67.73%   76.21%   +8.47%     
==========================================
  Files         511      511              
  Lines       56169    57193    +1024     
==========================================
+ Hits        38044    43587    +5543     
+ Misses      18125    13606    -4519     
Flag Coverage Δ
examples 41.83% <6.66%> (+0.52%) ⬆️
gpu 57.69% <46.66%> (+25.74%) ⬆️
regression 14.67% <60.00%> (+0.03%) ⬆️
unit 54.39% <66.66%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/dflash-auto-mask-token branch from d2d0558 to 5496efc Compare June 9, 2026 17:26

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 269-270: The early-return creates a CPU tensor when no grads
exist, causing device mismatch with the normal path which returns total_norm on
the GPU (the variable device is set around line 274); update the early-return to
produce a tensor on the same device as total_norm by constructing the zero
tensor on the same device (e.g., using the device variable or
total_norm/new_tensor style) so the returned tensor's device matches the normal
path (refer to grads, total_norm, and device to locate the code).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b39e1775-7aa8-482b-84ba-391c5f4eaef7

📥 Commits

Reviewing files that changed from the base of the PR and between d2d0558 and 5496efc.

📒 Files selected for processing (4)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/fsdp2_buffer_patch.py
  • examples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py

Comment thread examples/speculative_decoding/fsdp2_buffer_patch.py Outdated

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py (1)

246-246: ⚡ Quick win

Add defensive length check before truncating loss_mask.

If output_hidden_states is longer than loss_mask, Python slice semantics return the full (too-short) loss_mask, creating a mismatch with the saved input_ids length. The downstream OfflineSupervisedDataset loader does not validate shape alignment, risking silent training errors. With enable_prefix_caching=False (line 194), lengths should match exactly—consider adding a warning to catch violations:

🛡️ Suggested defensive check
+        expected_len = output_hidden_states.shape[0]
+        if loss_mask.shape[0] != expected_len:
+            import warnings
+            warnings.warn(
+                f"Conversation {conv_id}: loss_mask length {loss_mask.shape[0]} != "
+                f"hidden_states length {expected_len}; may indicate tokenization mismatch."
+            )
+
         output_file = output_dir / f"{conv_id}.pt"
         with open(output_file, "wb") as f:
             torch.save(
                 {
                     "input_ids": token_ids.cpu(),
                     "hidden_states": output_hidden_states,
                     "aux_hidden_states": aux_hidden_states,
-                    "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(),
+                    "loss_mask": loss_mask[:expected_len].cpu(),
                     "conversation_id": conv_id,
                 },
                 f,
             )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`
at line 246, Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Line 246: Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 898bf778-31d1-4a22-a5f1-2672bcf3b208

📥 Commits

Reviewing files that changed from the base of the PR and between 5496efc and d5ea663.

📒 Files selected for processing (1)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py

@github-actions

github-actions Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-06-15 23:48 UTC

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Around line 47-73: The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b02075cf-1305-4259-b54f-425362cba18c

📥 Commits

Reviewing files that changed from the base of the PR and between e6c552f and f7844eb.

📒 Files selected for processing (2)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
  • tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml

Comment thread examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py Outdated
Comment thread modelopt/__init__.py Outdated
Comment thread tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml Outdated
Comment thread examples/speculative_decoding/main.py Outdated
Comment thread examples/speculative_decoding/main.py Outdated

@h-guo18 h-guo18 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some tests that cover the new patches and callbacks?

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner June 10, 2026 17:00
@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu June 10, 2026 17:00
Comment thread modelopt/torch/export/plugins/hf_spec_export.py Outdated
Comment thread examples/speculative_decoding/fsdp2_buffer_patch.py
@yeyu-nvidia

Copy link
Copy Markdown
Contributor Author

Thanks for the review — addressed all comments. Summary of this round:

  • c69df2e — CodeRabbit nitpicks (additional_parameters dict copy, clip_grad empty-grad device, aux-layer bounds, loss_mask length guard) + reverted the modelopt/__init__ version guard per @kevalmorabia97.
  • 2aeda4f — converged DFlash export RoPE to a config field dflash_export_rope_scaling (eagle-style) per @hguo-nv; gated DFlashExportCallback on FSDP2; switched the offline recipe to DDP (FakeBaseModel doesn't need FSDP).
  • 1f1d9ed — embedding-resize fix: reuse an existing reserved embedding row for the mask token instead of resizing (the draft ships no embeddings, so a resized row is neither trained nor exported); MiniMax recipes pin dflash_mask_token_id=200054. The released checkpoint ships as-is.
  • 4327683 — documented the applicability/scope of fsdp2_buffer_patch (FSDP2-via-accelerate-config on transformers 4.57.x, currently MiniMax-only).

Both paths re-validated end-to-end: offline dump→train→export green, and the YaRN export config recovers 32k AL (1.17 → 2.62).

Comment thread examples/speculative_decoding/doc/dflash.md Outdated
Comment thread examples/speculative_decoding/main.py Outdated
@yeyu-nvidia

Copy link
Copy Markdown
Contributor Author

Added in e25fd71:

  • DFlash export rope-scaling — unit tests in tests/unit/torch/export/test_hf_spec_rope_export.py (alongside the eagle ones): YaRN injected from the dflash_export_rope_scaling config field, empty dict disables it, and rope_theta inherits the base/target. (7 passed locally.)
  • Mask-token resolution — extracted the main.py inline logic into a pure helper resolve_dflash_mask_token_id() in modeling_dflash.py and unit-tested all four branches (configured id / tokenizer mask id / reuse existing reserved row / must-resize) in test_hf_dflash.py::TestResolveMaskTokenId.

The DFlashExportCallback gather and the fsdp2_buffer_patch state-dict path are inherently FSDP2/distributed, so they stay covered by the dflash regression suite (tests/regression/torch/speculative/test_dflash*.py, which already runs train→export→AR-validate). Happy to add a multi-GPU test for the callback if you think the regression coverage is not enough.

yeyu-nvidia and others added 6 commits June 10, 2026 12:18
Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with
ValueError during DFlash training. Instead of requiring the user to
manually set dflash_mask_token_id, add the token to the tokenizer
and resize model embeddings automatically.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
When slurm_config.requeue is True, set additional_parameters["requeue"]
= True so nemo-run emits #SBATCH --requeue in the sbatch script.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0
   loads real weights on CPU; other ranks use meta device. FSDP2
   distributes from rank 0. Also adds dp_replicate_size auto-computation
   so dp_replicate * dp_shard * cp == world_size.

2. core.py: Set retries=3 when requeue is requested. The nemo-run sbatch
   wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES >
   SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…heckpoint resume

The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and
patch_accelerator() calls and added a buffer-to-CUDA block that moved
DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading,
non-rank-0 processes have meta-device params, causing
_infer_parameter_dtype() to return fp32 instead of bf16 on resume.

Also detects FSDP distributed checkpoints (no HF model files) and
loads the base model instead of trying from_pretrained on them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
… patch

_infer_parameter_dtype() reads the model's current param dtype to cast
the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0
processes have fp32 meta-device params for DFlash, so
_infer_parameter_dtype returns fp32 and _finish() casts the correctly-
broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0)
instead.

Also prints dtype_check on all ranks to verify consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The Pydantic-recipe refactor (7038dec) dropped DFlashExportCallback,
which had exported the DFlash draft submodule after every checkpoint save.
Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no
model.safetensors) get no exported-checkpoint-{step}/, so downstream
vLLM deployment / acceptance-length eval has nothing to load. The
verify-only comment 'export happens during training via DFlashExportCallback'
was left behind but the callback itself was gone.

Restore the callback (gathers only the ~328MB draft submodule across
shards via get_model_state_dict, so it works under SHARDED_STATE_DICT
without materializing the full base model) and wire it into main.py for
DFlash recipes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
yeyu-nvidia and others added 2 commits June 12, 2026 10:50
…g comment

Per @hguo-nv:
- rope_theta: the base-first choice is intentional for DFlash — the draft injects the
  target's KV into every layer, so its RoPE base must match the target's. The draft arch
  config carries no rope_theta, so draft-first fell back to the 1e6 default (a mismatch
  with the target's 5e6). Tightened the comment to state this.
- Trimmed the verbose rope_scaling comment to the essentials.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…er extraction

Addresses @hguo-nv's points 1-2 on the export callback:

1. The callback now self-skips when the checkpoint was saved as a full state dict
   (model.safetensors / pytorch_model.bin present) — i.e. it only does work under FSDP2
   SHARDED_STATE_DICT, where checkpoints are distributed shards the post-run export can't
   read. The check is from the on-disk checkpoint format (identical across ranks, before
   the collective gather). main.py now appends it unconditionally for DFlash since it
   self-gates (DDP / single-device / FSDP2-full all self-skip), removing the coarse
   dp_shard_size/env heuristic.

2. Extraction now delegates to DFlashExporter._extract_state_dict (the same logic the
   normal export path uses) for the common prefixed-key layout, with a fallback for the
   already-stripped submodule-gather keys. Config generation already delegated to the
   exporter; this removes the remaining duplication.

The gather (get_model_state_dict on the dflash_module submodule) — the load-bearing
FSDP2 part that produced the release — is unchanged. The full split into separate
reusable gather/export callbacks is deferred: offline DFlash uses DDP (full checkpoints,
no callback needed) and eagle has its own export path, so there is no current consumer.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Comment thread examples/speculative_decoding/main.py Outdated
… dtype diag

Implements @hguo-nv's agreed asks on the DFlash export callback:

- Rename DFlashExportCallback -> DFlashFSDP2ShardedSDExportCallback to make its applicable
  range (FSDP2 SHARDED_STATE_DICT only) explicit in the name.
- Gate it in main.py by reading the live accelerator FSDP state dict type
  (trainer.accelerator.state.fsdp_plugin.state_dict_type) instead of the callback's
  on-disk checkpoint-filename heuristic; the callback is added only for SHARDED_STATE_DICT
  (full-state-dict runs use the launcher's post-run export). Removed the in-callback
  early-exit accordingly.
- Move the per-rank dtype diagnostic out of main.py into
  fsdp2_buffer_patch.log_param_dtypes(), gated behind DFLASH_LOG_PARAM_DTYPES=1 (no-op by
  default), since it only exists to verify the FSDP2 dtype sync.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@ChenhanYu

Copy link
Copy Markdown
Collaborator

/claude review

@yeyu-nvidia

Copy link
Copy Markdown
Contributor Author

Acceptance length (released MiniMax-M2.7 DFlash draft, step 20400)

Measured with the specdec_bench harness on vLLM DFlash, greedy (temperature 0), 4096-token generations, target MiniMaxAI/MiniMax-M2.7. The released draft ships the YaRN rope_scaling (so it drafts at the target's full context); numbers are for that shipped config. Records uploaded to s3://team-specdec-workgroup/results/minimax_m2.7_dflash_vllm/.

Benchmark AL (draft 3) AL (draft 7)
MT-Bench (80 prompts) 3.05 3.08
SPEED-Bench qualitative (880 prompts) 2.90 3.06
SPEED-Bench throughput-32k (long context) 2.61

Draft length 7 is the recommended serving config (DFlash block_size − 1). The 32k long-context AL (2.61) confirms the YaRN export holds at the target's full window (vs ~1.17 without it).

Comment thread examples/speculative_decoding/main.py Outdated
Addresses @hguo-nv: the modify() inheritance was a setdefault, so a rope_theta provided
in dflash_architecture_config would NOT be overridden — the draft would train with that
value while the exporter writes the base's rope_theta, causing a train/inference mismatch.

DFlash injects the target's KV into every draft layer, so the draft's RoPE base must match
the target's. Enforce rope_theta/rope_type/rope_interleaved from the base model (overwrite
any user value, with a warning); keep the other architecture attrs as setdefault. This
makes training and export agree by construction. rope_scaling stays excluded (added only
at export via dflash_export_rope_scaling).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner June 15, 2026 21:10
@yeyu-nvidia yeyu-nvidia requested a review from chadvoegele June 15, 2026 21:10
… check

Per @hguo-nv: tidy the inline checkpoint-format probe in main.py into a named helper.
Kept it distinct from the export-callback's save-mode gate on purpose — this inspects the
*resume* checkpoint's on-disk format (a property of existing bytes, checked pre-trainer),
whereas the gate reads the *current run's* fsdp_state_dict_type (post-trainer); the two
answer different questions and can differ across runs, so they shouldn't share one
FSDP2+ShardedSD condition.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

Thorough pass on the DFlash / FSDP2 / mask-token / per-checkpoint-export PR.
The motivation and architecture are solid — gated patches, well-documented
applicability, and the export-callback / config-field convergence with eagle
is the right shape. Below are the issues I found that I think the author
should look at before merging.

Findings by severity

  • CRITICAL: 1
  • IMPORTANT: 6
  • SUGGESTION: 3

Highlights

CRITICAL — _clip_grad_norm deadlock when not all ranks see grads
fsdp2_buffer_patch.py:312-344 returns early on ranks with no grads but
unconditionally issues dist.all_reduce(sharded_norm_p) on the others.
First time a rank hits an empty-grad step (sharded MoE expert with no
tokens, etc.), the cluster hangs. Fix: every rank must reach the same
collective, or none.

IMPORTANT — silent fallback paths in the export callback (eagle_utils.py:269-291)

  1. The TypeError-fallback branch reverts to a full-model gather of the
    229B base, completely defeating the submodule-only design. Should fail
    loudly (or at minimum warn) on unsupported PyTorch versions.
  2. The fallback key heuristic differs from DFlashExporter._extract_state_dict
    and bypasses _check_valid_sd, so a malformed export ships silently and
    only fails at vLLM load time.

IMPORTANT — _DTYPE_TO_CODE.get(..., 0) silently coerces unknown dtypes
to fp32
(fsdp2_buffer_patch.py:178). Any dtype outside the four-entry
map (int8 / fp64 / bool / fp8_e5m2 / …) becomes fp32 on rank 0 and the
broadcast either casts data on the wire or NCCL refuses on size mismatch.
Raise on unknown dtype instead of silently coercing.

IMPORTANT — standalone aux-layer resolver hard-codes num_draft=5
(compute_hidden_states_vllm.py:62). Today the recipes match, but a
deeper draft + --aux-layers dflash silently produces mis-aligned dumps
and downstream AR regressions. Make the depth a CLI arg.

IMPORTANT — rope_theta or-fallback (hf_spec_export.py:382-383)
swallows 0 / 0.0. Use is not None for parity with the eagle path.

IMPORTANT — /dev/shm staging (compute_hidden_states_vllm.py:224)
without cleanup leaks RAM-backed tmpfs across crashes; on some HPC
containers /dev/shm is unmapped or undersized. Make it overridable via
env var and clean up on exit.

IMPORTANT — HF-format checkpoint detection (main.py:189-194) does
not appear to re-apply mtsp.convert on resume of a previously-saved
modelopt HF dir; needs verification that enable_huggingface_checkpointing
covers the DFlash plugin.

SUGGESTIONs — duplicate doc row, dead _orig capture,
exported-checkpoint-final mismatch in specdec_bench.yaml (callback
writes per-step dirs, not -final).

Risk assessment

The core algorithmic and export logic look right; the FSDP2 buffer/dtype
broadcast patch is reasonable for the narrow scope it covers, and the
callback gating refactor (only under SHARDED_STATE_DICT) is the right
shape. The crit risk is concentrated in two places: (1) the clip_grad_norm
collective imbalance, (2) the silent-fallback paths in the export callback
and dtype-encoder. Either could turn into a hard-to-debug production
failure long after merge.

Recommend addressing the CRITICAL and the four "silent fallback" IMPORTANT
items before merging; the rest are good follow-ups.

Comment on lines +269 to +288
try:
raw_sd = get_model_state_dict(
model, submodules={model.dflash_module}, options=options
)
except TypeError:
# Older PyTorch without submodules parameter — gather full model
raw_sd = get_model_state_dict(model, options=options)
except ImportError:
# Non-distributed / single-GPU fallback
raw_sd = model.state_dict()

# Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary
# buffers) for the common full-model key layout. Some PyTorch versions return the
# submodule gather with keys already stripped of the prefix — handle that directly.
exporter = model.get_exporter()
drafter_sd = exporter._extract_state_dict(raw_sd)
if not drafter_sd:
drafter_sd = {
k: v
for k, v in raw_sd.items()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Performance/ModeState] The silent full-model fallback when submodules= is not supported is dangerous for the very target this callback exists to support.

The whole point of this callback (per the docstring and PR description) is to gather only the ~328 MB DFlash submodule under FSDP2 SHARDED_STATE_DICT, without materializing the 229B base. But this fallback path:

except TypeError:
    # Older PyTorch without submodules parameter — gather full model
    raw_sd = get_model_state_dict(model, options=options)

silently calls get_model_state_dict(model, options=options) (no submodule filter, full_state_dict=True, cpu_offload=True), which gathers the entire 229B base on every save. That is many minutes per checkpoint of all-gather plus ~230 GB of CPU memory pressure per node — exactly what the submodule gather was added to avoid. On older PyTorch it wouldn't fail loudly; the user would just see saves get mysteriously expensive (or OOM) without understanding why.

Suggestion: either (a) raise a clear error pointing the user at a minimum supported PyTorch version, or (b) at minimum print_rank_0 a loud warning naming the model size and that this path will gather the full base model. The current silent fallback masks a tail-latency / OOM failure mode in the only scenario that matters.

Comment on lines +285 to +295
if not drafter_sd:
drafter_sd = {
k: v
for k, v in raw_sd.items()
if "rotary_emb" not in k
and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}
del raw_sd
# Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the
# single-GPU fallback may return CUDA tensors).
drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Algorithm/Export] The fallback key heuristic doesn't match what DFlashExporter._extract_state_dict actually returns — it will produce an empty export silently.

DFlashExporter._extract_state_dict (modelopt/torch/export/plugins/hf_spec_export.py:322-332) keeps only entries that contain "dflash_module." and strips that prefix. If get_model_state_dict(..., submodules={model.dflash_module}, full_state_dict=True) returns keys already pre-stripped (without the dflash_module. prefix), then exporter._extract_state_dict(raw_sd) returns {} — that's why this fallback exists.

But the heuristic here:

drafter_sd = {
    k: v
    for k, v in raw_sd.items()
    if "rotary_emb" not in k
    and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}

filters out anything containing "model.", which kills layer/decoder weights named like layers.0.self_attn.q_proj.weight (no model. substring) — fine — but it also kills any key that happens to legitimately contain "model." if a future DFlash module nests differently. More importantly, the substring tests "lm_head." and "embed_tokens." will NEVER match in the pre-stripped DFlash submodule case (the DFlash draft has neither). The exporter's _check_valid_sd is also bypassed in this branch — so a malformed state dict (e.g. weights from the wrong module after a refactor) would write a broken model.safetensors and only fail at vLLM load time.

Recommendation: instead of the substring heuristic, detect the pre-stripped case explicitly and either (a) re-prefix the keys with "dflash_module." and rerun exporter._extract_state_dict (so the same validation path runs), or (b) at minimum, assert the resulting key set looks sane (e.g. contains fc.weight, norm.weight, layers.0.self_attn.q_proj.weight) before writing. Right now a key-layout regression would pass through silently.

Comment on lines +311 to +344
grads = [p.grad for p in parameters if p.grad is not None]
if len(grads) == 0:
# Match the device of the normal return path (GPU when training on CUDA) so
# callers don't hit a device mismatch on the empty-grad case.
dev = "cuda" if torch.cuda.is_available() else "cpu"
return torch.tensor(0.0, device=dev)

# Shard DTensors hold partial data — need all_reduce for global norm.
# Replicate DTensors and regular tensors already hold full data.
device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device
sharded_norm_p = torch.tensor(0.0, device=device)
local_norm_p = torch.tensor(0.0, device=device)

n_sharded = 0
n_replicate = 0
n_regular = 0

for g in grads:
if isinstance(g, DTensor):
is_sharded = any(isinstance(p, Shard) for p in g.placements)
t = g._local_tensor.detach().to(torch.float32)
n = torch.linalg.vector_norm(t, norm_type)
if is_sharded:
sharded_norm_p += n.pow(norm_type)
n_sharded += 1
else:
local_norm_p += n.pow(norm_type)
n_replicate += 1
else:
n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type)
local_norm_p += n.pow(norm_type)
n_regular += 1

dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[CRITICAL Algorithm] Unconditional dist.all_reduce(sharded_norm_p) deadlocks when not all ranks see grads.

grads = [p.grad for p in parameters if p.grad is not None] may legitimately be empty on some ranks (e.g. an FSDP2-sharded MoE where a rank's expert gets no tokens this micro-batch, or a pipeline-parallel-style setup, or simply early steps where some ranks' params haven't yet accumulated grads). The function returns early with torch.tensor(0.0) on those ranks (line 312-316), but the other ranks fall through and call dist.all_reduce(sharded_norm_p) (line 344). Since the empty-grad rank does NOT participate in the collective, the all_reduce hangs forever. The same hazard exists if every grad on a rank is a Replicate DTensor / regular tensor (so n_sharded == 0) while another rank has shards — the reduce is still issued, but on a tensor that wasn't summed against any sharded contribution; that is benign for the math but still requires every rank to reach the call, which the early return breaks.

Two issues to fix:

  1. The early-return must not skip the collective. Either every rank participates in the all_reduce (initialize sharded_norm_p to zero before the early-return and unconditionally reduce), or none do.
  2. Document/assert that clip_grad_norm_ is only ever called on a parameters set that is identical across ranks (which is the only case where "all ranks have at least one grad" is guaranteed). If that's the design contract, please at least add an assertion + comment so a future caller doesn't introduce uneven grads and silently deadlock.

The current shape — early-return on one rank, collective on the others — is a deadlock waiting for the first edge case.

Comment on lines +176 to +189
if accelerator.is_main_process:
dtype_codes = torch.tensor(
[_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd],
dtype=torch.int32,
device=accelerator.device,
)
else:
dtype_codes = torch.empty(
n_total,
dtype=torch.int32,
device=accelerator.device,
)
dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD)
broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Algorithm] Silent fallback _DTYPE_TO_CODE.get(..., 0) + _CODE_TO_DTYPE[c.item()] corrupts unsupported-dtype entries instead of failing loudly.

On rank 0 you encode each full_sd[name].dtype to a code via .get(dtype, 0) — meaning any dtype that isn't in the four-entry map (fp32 / bf16 / fp16 / fp8_e4m3fn) silently becomes 0 (= float32). Other ranks then build a fp32 receive buffer and dist.broadcast casts the rank-0 tensor to fp32 on the wire (or worse, NCCL refuses if the byte size doesn't match). This silently corrupts:

  • int8 / uint8 checkpoints (e.g. some quantized loads)
  • float8_e5m2
  • float64 / int64 (rare for weights but valid for some buffers like rotary inv_freq cached as fp64 by some HF model code)
  • bool (some attention masks)

Because the comment explicitly says this path supports fp8 weights, a bf16 vs fp8 mis-encoding would silently broadcast garbage across the cluster and you'd only notice from a downstream NaN/AR collapse — exactly the kind of ghost-bug this patch is meant to prevent. Rank-0 should raise ValueError(f"Unsupported broadcast dtype {dtype} for {name}") rather than coerce to a default.

Bonus: the int32 broadcast carries n_total codes, but the _CODE_TO_DTYPE decode loop treats c.item() as authoritative — so if a future torch version re-orders the fp8 enum to a different code, the if hasattr(...) else -1 initializer here makes that change a silent encoding flip rather than an error. Pin codes by name, not by attribute existence.

import accelerate.utils.fsdp_utils as fsdp_utils
from torch.distributed.tensor import DTensor

_orig = fsdp_utils.fsdp2_load_full_state_dict

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] _orig is captured but never invoked — dead variable.

The original fsdp2_load_full_state_dict is captured here for an apparent rollback / wrapper pattern, but _patched never delegates to it. Either drop the capture (it confused me on first read into thinking there was a fallback path) or document why it exists (e.g. for an unapply() you plan to add). Trivial cleanup.

Comment on lines +166 to +167
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] Duplicate dflash_architecture_config.num_hidden_layers row.

Suggested change
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |

The current diff inserts the new dflash_mask_token_id row but keeps the old dflash_architecture_config.num_hidden_layers line and ALSO repeats it on line 168 — the rendered table will have num_hidden_layers twice in a row. Drop the duplicate (the line below this one).

Comment on lines +61 to +68
if spec == "dflash":
num_draft = 5
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Algorithm] The standalone resolver hard-codes num_draft=5 while the recipe value is configurable — a mismatch silently mis-aligns dumped aux layers.

build_target_layer_ids(num_target_layers, num_draft_layers) (modeling_dflash.py:58-69) is parameterized by num_draft_layers. The MiniMax recipes here pin dflash.dflash_architecture_config.num_hidden_layers=5, so today the constant matches — but if anyone runs offline DFlash with a different draft depth (4, 6, 8 etc.) and forgets to switch from --aux-layers dflash to an explicit comma-list, the dump silently captures a different set of target layers than the trained draft will actually consume at training time, since target_layer_ids is also derived from num_draft_layers via the same function in hf_dflash.py:169. The training run will load mis-aligned aux features and learn garbage; the failure mode is "AR mysteriously regresses" rather than a loud crash.

Two fixes either of which would close the gap:

  1. Take the draft depth as a CLI flag (--dflash-num-draft 5) and pass it through.
  2. At least raise a loud error / log if the recipe-default 5 does not match the model's actual num_hidden_layers divisibility expectations, and document the constant-coupling explicitly in the docstring.

The TODO already acknowledges this is fragile; the comment-coupling to a hard-coded 5 is the actual breakage vector. Suggest making the value a required CLI arg and removing the default.

Comment on lines +220 to +224
# Stage the connector's intermediate safetensors on local tmpfs, not the (lustre)
# output dir: the producer writes one file per request and the client reads it back
# immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so
# parallel shards don't collide.
storage_path = Path("/dev/shm") / f"vllm_hidden_states_dp{args.dp_rank}"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Performance] Hard-coding /dev/shm and skipping cleanup is a footgun on shared/multi-tenant nodes.

Two issues with this hardcoded path:

  1. No cleanup on failure / exit. A run that crashes after writing K safetensors leaves them in /dev/shm (which is RAM-backed on Linux) and they persist until the node is rebooted or another job evicts them. With per-DP-rank dirs and ~MiniMax-M2.7-shape hidden states (vocab × hidden × seq_len), this can accumulate to tens or hundreds of GB of stranded RAM. Suggest wrapping with tempfile.TemporaryDirectory() (also pickable per DP rank) or at minimum register atexit.register to remove storage_path.

  2. /dev/shm may not exist or may be too small. Some HPC scheduler containers either don't bind-mount /dev/shm (it's a tmpfs the orchestration must wire up) or cap it tightly (Slurm default is sometimes 64 MB). For MiniMax-M2.7 (229B target) and 4096-token sequences, per-conversation tensor sizes are non-trivial; /dev/shm overflow yields a confusing OSError: [Errno 28] No space left on device at write time, mid-job. Worth at least making the staging path overridable with an env var (VLLM_HIDDEN_STATES_TMPDIR) and falling back through tempfile.gettempdir().

Neither blocks the merge but both will bite production. The cleanup omission is the more serious of the two.

Comment on lines +379 to +383
# Inherit the target's rope_theta: DFlash injects the target's KV into every
# draft layer, so the draft's RoPE base must match the target's. (The draft
# arch config carries no rope_theta of its own.)
"rope_theta": getattr(base_config, "rope_theta", None)
or getattr(draft_config, "rope_theta", 1000000.0),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Compatibility] The or short-circuit silently swallows rope_theta=0 (and any other falsy numeric value).

"rope_theta": getattr(base_config, "rope_theta", None)
or getattr(draft_config, "rope_theta", 1000000.0),

If base_config.rope_theta were ever 0 or 0.0 (legal for some models that disable RoPE entirely, though uncommon for production targets), or falls through to the draft default 1000000.0. The intent (per the comment) is "use the target's value if set, otherwise fall back" — which means the test should be is not None, not truthy:

Suggested change
# Inherit the target's rope_theta: DFlash injects the target's KV into every
# draft layer, so the draft's RoPE base must match the target's. (The draft
# arch config carries no rope_theta of its own.)
"rope_theta": getattr(base_config, "rope_theta", None)
or getattr(draft_config, "rope_theta", 1000000.0),
"rope_theta": (
getattr(base_config, "rope_theta", None)
if getattr(base_config, "rope_theta", None) is not None
else getattr(draft_config, "rope_theta", 1000000.0)
),

Realistically rope_theta=0 is rare so this is more theoretical than acute, but the existing eagle path uses an explicit is None check (line 205-206) for exactly this reason — worth being consistent.

Comment on lines +38 to +39
- --tp_size 4
- --ep_size 4

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] exported-checkpoint-final doesn't match the callback's actual output naming.

DFlashFSDP2ShardedSDExportCallback.on_save writes to exported-checkpoint-{step} (eagle_utils.py:258), where step = state.global_step — never the literal string final. So this YAML's --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final will fail to find model.safetensors unless the user manually renames or symlinks the last per-save export directory. Same on line 67. Either the callback should additionally write/symlink an exported-checkpoint-final/ after on_train_end, or this YAML should use a placeholder pattern (e.g. exported-checkpoint-${STEP} with the user expected to fill it in) and call out in a comment that the user must pick the desired step. As-written this is a copy-paste-ready bench config that won't work out of the box.

yeyu-nvidia and others added 2 commits June 15, 2026 14:29
Addresses the Claude review's CRITICAL + silent-fallback findings:

- CRITICAL: _clip_grad_norm no longer early-returns on an empty-grad rank before the
  collective — under FSDP2 sharding / MoE+LoRA co-train a rank can legitimately have no
  grads (e.g. an expert with no tokens), and returning early while other ranks all_reduce
  deadlocks the job. Every rank now reaches the (guarded) all_reduce; empty-grad ranks
  contribute a zero norm and clip nothing.
- fsdp2_buffer_patch dtype-sync: raise on an unmapped dtype instead of silently coercing
  to fp32 (which would cast on the wire or trip an NCCL size mismatch).
- DFlash export callback: warn loudly on the two fallback paths — the full-model gather
  when get_model_state_dict lacks submodules=, and the denylist heuristic when the
  prefix-based extraction finds nothing — so a slow/OOM gather or a malformed export
  isn't silent.
- hf_spec_export rope_theta: use 'is not None' instead of 'or' so a base rope_theta of 0
  isn't swallowed (parity with the eagle path).
- Cleanups: drop the dead _orig capture; remove the duplicated num_hidden_layers doc row.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…g in offline dump

Addresses the remaining offline-dump-path findings (relevant to the upcoming M3 offline work):

- compute_hidden_states_vllm.py: the standalone 'dflash' aux-layer resolver no longer
  hard-codes num_draft=5; added --num-draft-layers (default 5) so the dumped aux layers
  match the recipe's dflash_architecture_config.num_hidden_layers. A mismatch would
  silently mis-align the dump with what the draft consumes at training time.
- compute_hidden_states_vllm.py: the connector staging dir is now overridable via
  DFLASH_HS_STAGING_DIR (default /dev/shm) for containers where /dev/shm is unmapped or
  undersized, and is removed on exit (atexit) so a crash doesn't strand RAM-backed files.
- specdec_bench.yaml: clarified that exported-checkpoint-final comes from the launcher's
  post-run export, while the per-save callback writes exported-checkpoint-<step>.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia

Copy link
Copy Markdown
Contributor Author

Claude review — resolutions (commits 82814228 + a8c67e8744)

Thanks — went through all findings. Summary:

CRITICAL — _clip_grad_norm deadlock ✅ fixed. The empty-grad early-return before the collective is removed; every rank now reaches the (distributed-guarded) all_reduce, so an empty-grad rank (e.g. an FSDP2-sharded MoE expert with no tokens — exactly the LoRA-co-train case) contributes a zero norm and clips nothing instead of hanging the others.

IMPORTANT

  • Full-model fallback when submodules= unsupported ✅ now warns loudly (slow / may-OOM gather; upgrade PyTorch).
  • Fallback key heuristic could ship a malformed export silently ✅ now warns when the prefix-based extraction finds nothing and the denylist heuristic is used.
  • _DTYPE_TO_CODE.get(...,0) coerces unknown dtype → fp32 ✅ now raises on any unmapped dtype.
  • aux-layer resolver hard-coded num_draft=5 ✅ added --num-draft-layers (default 5); must match the recipe's draft depth.
  • rope_theta or swallows 0 ✅ switched to is not None (eagle parity).
  • /dev/shm staging no cleanup / not overridable ✅ overridable via DFLASH_HS_STAGING_DIR, removed on exit via atexit.

SUGGESTION — dead _orig capture removed ✅; duplicate num_hidden_layers doc row removed ✅; specdec_bench.yaml clarified that exported-checkpoint-final is the launcher's post-run export while the callback writes exported-checkpoint-<step> ✅.

On the summary's "resume doesn't re-apply mtsp.convert" (main.py HF-format path): this is handled — mto.enable_huggingface_checkpointing() (called at import) patches from_pretrained to restore the saved modelopt_state (incl. the DFlash module) when loading a modelopt HF checkpoint dir, so load_vlm_or_llm(checkpoint) re-materializes the draft automatically. Happy to add an explicit assertion if you'd prefer a hard guard.

…lambda

mypy couldn't infer the cleanup lambda's type ('Cannot infer type of lambda'). atexit
forwards *args/**kwargs to the callback, so register shutil.rmtree directly with its args.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia merged commit e004d8d into main Jun 15, 2026
54 checks passed
@yeyu-nvidia yeyu-nvidia deleted the yeyu/dflash-auto-mask-token branch June 15, 2026 23:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants