diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
index dd496480cbb..77441f8f858 100644
--- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
@@ -25,10 +25,19 @@
"""
import argparse
+import atexit
+import os
+import shutil
from pathlib import Path
import torch
-from common import add_aux_layers_args, resolve_aux_layers
+from common import (
+ add_answer_only_loss_args,
+ add_aux_layers_args,
+ load_chat_template,
+ tokenize_with_loss_mask,
+ verify_generation_tags,
+)
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
@@ -38,6 +47,48 @@
)
+def _resolve_aux_layers_standalone(
+ aux_layers: str, num_hidden_layers: int, num_draft: int = 5
+) -> list[int]:
+ """Resolve aux-layer ids without importing modelopt.
+
+ This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
+ 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
+ pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
+ container does not have, so the import fails. Resolve the 'dflash' preset inline
+ (mirroring ``modeling_dflash.build_target_layer_ids`` for ``num_draft`` draft layers)
+ and accept an explicit comma-separated int list. ``num_draft`` MUST match the recipe's
+ ``dflash.dflash_architecture_config.num_hidden_layers`` (pass --num-draft-layers) or the
+ dumped aux layers silently mis-align with what the draft consumes at training time.
+ Keep in sync with modelopt.
+
+ TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy
+ ``modelopt.torch`` import chain so it can be reused directly in a vLLM container.
+ """
+ spec = aux_layers.strip().lower()
+ if spec == "dflash":
+ if num_draft == 1:
+ return [num_hidden_layers // 2]
+ start = min(1, num_hidden_layers - 1)
+ end = max(start, num_hidden_layers - 3)
+ span = end - start
+ return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
+ ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
+ # Match the shared helper's contract: ids must be valid layer indices.
+ out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers]
+ if out_of_range:
+ raise ValueError(
+ f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers}) "
+ f"for a {num_hidden_layers}-layer model."
+ )
+ if not ids:
+ raise ValueError(
+ f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
+ "'dflash' preset or an explicit comma-separated layer-id list are supported."
+ )
+ return ids
+
+
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Collect hidden states from conversations using vLLM's native extractor."""
@@ -63,6 +114,14 @@ def parse_args() -> argparse.Namespace:
"--debug-max-num-conversations", type=int, default=None, help="Limit conversations."
)
add_aux_layers_args(parser)
+ parser.add_argument(
+ "--num-draft-layers",
+ type=int,
+ default=5,
+ help="DFlash draft depth, for resolving the 'dflash' --aux-layers preset. MUST match "
+ "the recipe's dflash.dflash_architecture_config.num_hidden_layers (default: 5).",
+ )
+ add_answer_only_loss_args(parser)
return parser.parse_args()
@@ -112,7 +171,9 @@ def keep_conversation(entry):
num_hidden_layers = getattr(config, "num_hidden_layers", None)
if num_hidden_layers is None:
raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}")
- aux_layer_ids = resolve_aux_layers(args, num_hidden_layers)
+ aux_layer_ids = _resolve_aux_layers_standalone(
+ args.aux_layers, num_hidden_layers, num_draft=args.num_draft_layers
+ )
# The trailing entry is the final output hidden state; the rest are aux layers.
extract_layer_ids = [*aux_layer_ids, num_hidden_layers]
print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)")
@@ -121,34 +182,37 @@ def keep_conversation(entry):
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
+ override_template = load_chat_template(args.chat_template)
+ if override_template is not None:
+ tokenizer.chat_template = override_template
if tokenizer.chat_template is not None:
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+ if args.answer_only_loss:
+ verify_generation_tags(tokenizer.chat_template)
# Prepare prompts for vLLM
prompts = []
conversation_ids = []
+ loss_masks = []
num_skipped_too_long = 0
num_invalid = 0
for entry in dataset:
conversation_id = entry.get("conversation_id", entry.get("uuid"))
- conversations = entry["conversations"]
+ # Accept either the "conversations" or OpenAI-style "messages" key (the
+ # MiniMax synthetic data uses "messages").
+ conversations = entry.get("conversations") or entry.get("messages")
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue
- tokenized = tokenizer.apply_chat_template(
- conversations, return_tensors="pt", add_generation_prompt=False
+ # One apply_chat_template call yields aligned input_ids + loss_mask. With
+ # --answer-only-loss the mask comes from the template's {% generation %} tags;
+ # otherwise it is all-ones. Same tokens are sent to vLLM, so the dumped hidden
+ # states line up with this loss_mask 1:1 (prefix caching is disabled below).
+ input_ids, loss_mask = tokenize_with_loss_mask(
+ tokenizer, conversations, args.answer_only_loss
)
- # transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids
- if hasattr(tokenized, "input_ids"):
- input_ids = tokenized.input_ids
- elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized:
- input_ids = tokenized["input_ids"]
- else:
- input_ids = tokenized
- if not hasattr(input_ids, "shape"):
- input_ids = torch.tensor(input_ids)
input_ids = input_ids.squeeze(0)
num_tokens = input_ids.shape[0]
if num_tokens <= 10 or num_tokens > args.max_seq_len:
@@ -157,6 +221,7 @@ def keep_conversation(entry):
prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist()))
conversation_ids.append(conversation_id)
+ loss_masks.append(loss_mask)
print(
f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)"
@@ -168,8 +233,16 @@ def keep_conversation(entry):
# Initialize vLLM with the native hidden-state extractor.
tp = args.tp if args.tp is not None else torch.cuda.device_count()
- storage_path = output_dir / ".vllm_hidden_states"
+ # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre)
+ # output dir: the producer writes one file per request and the client reads it back
+ # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so
+ # parallel shards don't collide. Overridable via DFLASH_HS_STAGING_DIR for containers
+ # where /dev/shm is unmapped or undersized; cleaned up on exit so a crash doesn't strand
+ # RAM-backed files until the node reboots.
+ staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm")
+ storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}"
storage_path.mkdir(parents=True, exist_ok=True)
+ atexit.register(shutil.rmtree, storage_path, ignore_errors=True)
llm = LLM(
model=args.model,
@@ -177,6 +250,11 @@ def keep_conversation(entry):
max_model_len=args.max_seq_len,
trust_remote_code=args.trust_remote_code,
enable_chunked_prefill=False, # required by extract_hidden_states
+ # With prefix caching on, vLLM serves shared prefixes from cache in block-sized
+ # chunks and the hidden-state connector only emits the freshly-computed suffix, so
+ # the dumped hidden_states come out short by N*block_size vs the full input_ids /
+ # loss_mask. Disabling it forces a full prefill so every token's state is dumped.
+ enable_prefix_caching=False,
speculative_config={
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
@@ -189,7 +267,10 @@ def keep_conversation(entry):
kv_role="kv_producer",
kv_connector_extra_config={
"shared_storage_path": str(storage_path),
- "use_synchronization_lock": False, # batch generation, no concurrent readers
+ # The client reads each request's safetensors right after generation; the
+ # lock makes the producer signal completion so the reader doesn't race the
+ # writer (without it the reader looks for a .lock the producer never wrote).
+ "use_synchronization_lock": True,
},
),
)
@@ -197,10 +278,11 @@ def keep_conversation(entry):
# max_tokens=1: we only need a single forward pass over the prompt tokens.
outputs = llm.generate(prompts, SamplingParams(max_tokens=1))
- # Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the
- # vLLM path does not compute).
+ # Save in the same format as compute_hidden_states_hf.py, including loss_mask.
num_success = 0
- for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"):
+ for conv_id, loss_mask, output in tqdm(
+ zip(conversation_ids, loss_masks, outputs), total=len(outputs), desc="Saving"
+ ):
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
if hidden_states_path is None:
print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping")
@@ -220,6 +302,16 @@ def keep_conversation(entry):
else:
aux_hidden_states = torch.empty(0)
+ # loss_mask is sliced to the dumped length below; a shorter loss_mask would slice
+ # to itself and silently misalign with the hidden states, so guard explicitly.
+ n_hs = output_hidden_states.shape[0]
+ if loss_mask.shape[0] < n_hs:
+ print(
+ f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden "
+ f"states ({n_hs}); skipping to avoid misalignment"
+ )
+ continue
+
output_file = output_dir / f"{conv_id}.pt"
with open(output_file, "wb") as f:
torch.save(
@@ -227,6 +319,7 @@ def keep_conversation(entry):
"input_ids": token_ids.cpu(),
"hidden_states": output_hidden_states,
"aux_hidden_states": aux_hidden_states,
+ "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(),
"conversation_id": conv_id,
},
f,
diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md
index 44db5d39e72..0150e0884e1 100644
--- a/examples/speculative_decoding/doc/dflash.md
+++ b/examples/speculative_decoding/doc/dflash.md
@@ -163,10 +163,15 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) |
| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) |
| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) |
+| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
-| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
| `training.answer_only_loss` | false | Mask loss on non-assistant tokens |
+> **Note on `dflash_mask_token_id`:** the draft reuses the target's `embed_tokens`, so the
+> mask id must be a token that exists in the target embedding. Pin it to an existing
+> reserved token id — e.g. MiniMax-M2.7 uses `200054` (embedding is 200064 rows; tokens
+> 0..200053 are real, 200054+ reserved).
+
> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the
> tokenizer's chat template must include `{% generation %}` / `{% endgeneration %}` tags
> around assistant content. HuggingFace uses these tags to produce `assistant_masks` via
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index 721b981eaae..6d88632022b 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -230,6 +230,110 @@ def on_step_begin(self, args, state, control, **kwargs):
return control
+class DFlashFSDP2ShardedSDExportCallback(TrainerCallback):
+ """Export the DFlash draft module after each checkpoint save, for FSDP2 sharded runs.
+
+ Applicable range: this is needed only under FSDP2 ``SHARDED_STATE_DICT``, where the
+ checkpoint holds distributed shards (``pytorch_model_fsdp_0/``) and no consolidated
+ ``model.safetensors`` — so the post-training ``export_hf_checkpoint.py`` pass can't read
+ it. It gathers just the small draft submodule and writes the deployable export.
+
+ Gating is the caller's responsibility: ``main.py`` adds this callback only when the
+ accelerator's FSDP state dict type is ``SHARDED_STATE_DICT`` (full-state-dict runs —
+ DDP, single-device, FSDP2 FULL_STATE_DICT — use the launcher's post-run export instead).
+ """
+
+ def on_save(self, args, state, control, **kwargs):
+ """Export DFlash draft module weights + config after checkpoint save."""
+ import json
+ import os
+
+ from safetensors.torch import save_file
+
+ model = kwargs["model"]
+ if not hasattr(model, "dflash_module"):
+ return control
+
+ step = state.global_step
+ export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}")
+
+ # All ranks participate in the state_dict gather (FSDP2 collective op). Only the
+ # dflash_module submodule is gathered (~328 MB for MiniMax-M2.7), not the 229B base.
+ try:
+ from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
+ get_model_state_dict,
+ )
+
+ options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+ try:
+ raw_sd = get_model_state_dict(
+ model, submodules={model.dflash_module}, options=options
+ )
+ except TypeError:
+ # Older PyTorch without the submodules= parameter: this gathers the FULL
+ # model (the entire base, e.g. ~229B for MiniMax-M2.7), defeating the
+ # submodule-only design and risking OOM. Warn loudly — upgrade PyTorch.
+ print_rank_0(
+ "WARNING: DFlash export: get_model_state_dict lacks submodules= on this "
+ "PyTorch — gathering the FULL base model (slow / may OOM). Upgrade PyTorch "
+ "for the submodule-only gather."
+ )
+ raw_sd = get_model_state_dict(model, options=options)
+ except ImportError:
+ # Non-distributed / single-GPU fallback
+ raw_sd = model.state_dict()
+
+ # Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary
+ # buffers) for the common full-model key layout. Some PyTorch versions return the
+ # submodule gather with keys already stripped of the prefix — handle that directly.
+ exporter = model.get_exporter()
+ drafter_sd = exporter._extract_state_dict(raw_sd)
+ if not drafter_sd:
+ # Fallback for the already-stripped-key layout: a denylist heuristic rather than
+ # the prefix-based extractor. Warn, since a key-naming change upstream could let
+ # a malformed draft ship and only fail at vLLM load time.
+ print_rank_0(
+ "WARNING: DFlash export: prefix-based extraction found no dflash_module keys; "
+ "falling back to a denylist heuristic on already-stripped keys. Verify the "
+ "exported draft loads in vLLM."
+ )
+ drafter_sd = {
+ k: v
+ for k, v in raw_sd.items()
+ if "rotary_emb" not in k
+ and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
+ }
+ del raw_sd
+ # Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the
+ # single-GPU fallback may return CUDA tensors).
+ drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()}
+
+ if not drafter_sd:
+ print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export")
+ return control
+
+ # Only rank 0 writes files
+ if is_master():
+ try:
+ os.makedirs(export_dir, exist_ok=True)
+ save_file(drafter_sd, os.path.join(export_dir, "model.safetensors"))
+
+ config = exporter._export_config()
+ with open(os.path.join(export_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ total_mb = sum(v.nbytes for v in drafter_sd.values()) / 1024 / 1024
+ print_rank_0(
+ f"Exported DFlash draft ({len(drafter_sd)} tensors, {total_mb:.0f}MB) "
+ f"to {export_dir}"
+ )
+ except Exception as e:
+ print_rank_0(f"Warning: DFlash export failed at step {step}: {e}")
+
+ return control
+
+
class EagleTrainingPlot(TrainerCallback):
"""Callback that plot training acc and AR during training."""
diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py
new file mode 100644
index 00000000000..388959c6b78
--- /dev/null
+++ b/examples/speculative_decoding/fsdp2_buffer_patch.py
@@ -0,0 +1,413 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.
+
+Applicability (scope of this patch)
+-----------------------------------
+This is **not** needed for FSDP2 in general. It is required only for the narrow
+combination of: **FSDP2 configured via an accelerate YAML config** (not torch-native
+``ParallelismConfig``) **with** ``cpu_ram_efficient_loading=True``. Today that path is
+forced by **models that require transformers 4.57.x** (their ``trust_remote_code`` code
+predates transformers 5.x ``ParallelismConfig`` support) **and** are too large to load on
+every rank — currently only **MiniMax-M2.7** (229B MoE). Models that run on transformers
+5.x (Qwen, Llama, Nemotron, ...) use native ``ParallelismConfig`` (``dp_shard_size > 1``),
+which handles buffers/dtypes correctly and never enters ``fsdp2_load_full_state_dict`` —
+they need none of this. Gated off by default; activate with ``PATCH_FSDP2_BUFFERS_TF457=1``.
+
+Problem
+-------
+accelerate's ``fsdp2_load_full_state_dict`` (called during model preparation
+when ``cpu_ram_efficient_loading=True``) iterates ``model.state_dict()`` and
+unconditionally accesses ``.device_mesh`` on every entry, assuming they are all
+DTensors. After ``fully_shard()``, **parameters** become DTensors but
+**persistent buffers** (e.g., rotary-embedding ``inv_freq``) remain plain
+``torch.Tensor``. This crashes with::
+
+ AttributeError: 'Tensor' object has no attribute 'device_mesh'
+
+Additionally, ``cpu_ram_efficient_loading`` causes a dtype divergence: rank 0
+loads the model on CPU (inheriting the checkpoint's dtype, e.g. bfloat16) while
+other ranks use ``meta`` device (defaulting to float32 for newly-added modules
+like the DFlash head). After ``fully_shard()``, the DTensor dtypes differ
+across ranks for these modules. Since ``dist.broadcast()`` requires matching
+dtypes and element sizes on all ranks, broadcasting a bfloat16 tensor (2
+bytes/elem) to a float32 receive buffer (4 bytes/elem) causes an NCCL deadlock.
+
+Why we need FSDP2 via accelerate config (not ParallelismConfig)
+---------------------------------------------------------------
+MiniMax-M2.7's ``trust_remote_code`` model code requires **transformers 4.57.x**.
+Transformers' native FSDP2 support via ``ParallelismConfig`` requires
+**transformers 5.x**. So we fall back to configuring FSDP2 through an
+``accelerate`` YAML config file (``accelerate_fsdp2.yaml``), which works with
+transformers 4.57.x. We set ``dp_shard_size=1`` to prevent ``main.py`` from
+creating a ``ParallelismConfig``, letting the accelerate config handle sharding.
+
+Why we need cpu_ram_efficient_loading
+-------------------------------------
+MiniMax-M2.7 is a 229B MoE model (~230 GB in FP8). Each GB200 node has 4 GPUs
+and ~800 GB system RAM. Without ``cpu_ram_efficient_loading``, all 4 ranks per
+node load the model to CPU simultaneously (4 × 230 GB ≈ 920 GB), exceeding
+system RAM and triggering OOM kills. With ``cpu_ram_efficient_loading``, only
+rank 0 loads the model; other ranks initialize on ``meta`` device. The weights
+are then broadcast via ``fsdp2_load_full_state_dict`` — which is where the bug
+hits.
+
+What this patch does
+--------------------
+1. Before accessing ``.device_mesh``, checks ``isinstance(entry, DTensor)``.
+ For non-DTensor entries (persistent buffers), broadcasts the raw tensor from
+ rank 0 without calling ``distribute_tensor()``.
+
+2. All ranks iterate ``model.state_dict()`` (post-shard) in the same order so
+ broadcast calls match 1-to-1. Rank 0 looks up the full parameter **by key
+ name** from the pre-shard state dict — never by positional ``zip``, because
+ ``model.to("meta")`` + ``fully_shard()`` can reorder keys.
+
+3. **Dtype synchronization**: rank 0 broadcasts a dtype code for each entry
+ before the main loop. All ranks then use the same dtype for their broadcast
+ tensors. This fixes the dtype divergence caused by rank 0 loading in
+ bfloat16 while other ranks default to float32 for newly-added modules.
+
+Accelerate config constraints (for reference)
+----------------------------------------------
+``accelerate_fsdp2.yaml`` also requires:
+
+- ``fsdp_use_orig_params: true`` — without this, FSDP flattens all params into
+ FlatParameter, losing per-parameter ``requires_grad`` flags. The DFlash head
+ can't train because its gradients mix with frozen base model zeros.
+- ``fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule`` —
+ DFlash head params at the model root must be in the wrap policy so they become
+ DTensors. Without this, ``fsdp2_load_full_state_dict`` also crashes.
+- ``fsdp_sync_module_states: true`` — accelerate's launch validator requires it
+ when ``cpu_ram_efficient_loading`` is enabled, even though FSDP2 ignores it at
+ runtime (sets it to None with a warning).
+
+Does NOT affect models on transformers 5.x
+-------------------------------------------
+This entire workaround exists ONLY because MiniMax-M2.7 requires
+transformers 4.57.x. Models that support transformers 5.x (Qwen, Llama,
+Nemotron, etc.) use ``ParallelismConfig`` natively by setting
+``dp_shard_size > 1`` in the training args. That code path handles buffers
+correctly and does not go through ``fsdp2_load_full_state_dict`` at all.
+No accelerate config file, no ``PATCH_FSDP2_BUFFERS_TF457``, no
+``OVERRIDE_TRANSFORMERS`` needed.
+
+When to remove
+--------------
+This patch can be removed when EITHER of these happens:
+
+1. MiniMax updates ``trust_remote_code`` for transformers 5.x, allowing native
+ ``ParallelismConfig`` (which handles this correctly).
+2. Upstream accelerate fixes ``fsdp2_load_full_state_dict`` to skip non-DTensor
+ entries. Track: https://github.com/huggingface/accelerate
+
+Activation
+----------
+Set ``PATCH_FSDP2_BUFFERS_TF457=1`` in the environment to activate. Off by default.
+Only needed in MiniMax-M2.7 pipeline YAMLs.
+"""
+
+import torch
+
+# Dtype encoding for the broadcast dtype-sync step.
+_DTYPE_TO_CODE = {
+ torch.float32: 0,
+ torch.bfloat16: 1,
+ torch.float16: 2,
+ torch.float8_e4m3fn: 3 if hasattr(torch, "float8_e4m3fn") else -1,
+}
+_CODE_TO_DTYPE = {v: k for k, v in _DTYPE_TO_CODE.items() if v >= 0}
+
+
+def apply():
+ """Patch fsdp2_load_full_state_dict if the buffer bug is present."""
+ try:
+ import accelerate.utils.fsdp_utils as fsdp_utils
+ from torch.distributed.tensor import DTensor
+
+ def _dtype_code(dt):
+ """Map a dtype to its broadcast sync code, raising on anything unmapped.
+
+ Silently coercing an unknown dtype to fp32 would cast data on the wire (or
+ make NCCL refuse on an element-size mismatch), so fail loudly instead.
+ """
+ code = _DTYPE_TO_CODE.get(dt)
+ if code is None or code < 0:
+ raise ValueError(
+ f"fsdp2_buffer_patch: unsupported dtype {dt} in the broadcast "
+ f"dtype-sync; add it to _DTYPE_TO_CODE."
+ )
+ return code
+
+ def _patched(accelerator, model, full_sd, cpu_offload=False):
+ import time
+
+ import torch.distributed as dist
+ from torch.distributed.tensor import distribute_tensor
+
+ meta_sharded_sd = model.state_dict()
+ sharded_sd = {}
+ n_total = len(meta_sharded_sd)
+ n_dtensor = sum(1 for v in meta_sharded_sd.values() if isinstance(v, DTensor))
+ n_buffer = n_total - n_dtensor
+
+ if accelerator.is_main_process:
+ print(
+ f"[fsdp2_buffer_patch] State dict: {n_total} entries "
+ f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}"
+ )
+ else:
+ print(
+ f"[fsdp2_buffer_patch] State dict: {n_total} entries "
+ f"({n_dtensor} DTensor, {n_buffer} buffer)"
+ )
+ t0 = time.time()
+
+ # --- Step 0: broadcast dtype codes from rank 0 ---
+ # cpu_ram_efficient_loading causes rank 0 to load in bfloat16 while
+ # other ranks default to float32 for newly-added modules (DFlash).
+ # After fully_shard(), DTensor dtypes diverge across ranks.
+ # Broadcast rank 0's dtypes so all ranks use the same dtype for
+ # each broadcast tensor.
+ if accelerator.is_main_process:
+ dtype_codes = torch.tensor(
+ [_dtype_code(full_sd[name].dtype) for name in meta_sharded_sd],
+ dtype=torch.int32,
+ device=accelerator.device,
+ )
+ else:
+ dtype_codes = torch.empty(
+ n_total,
+ dtype=torch.int32,
+ device=accelerator.device,
+ )
+ dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD)
+ broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes]
+ del dtype_codes
+
+ # Infer dtype/contiguity for cast — copied from upstream
+ def _infer_parameter_dtype(mdl, param_name, empty_param):
+ try:
+ old_param = mdl.get_parameter_or_buffer(param_name)
+ except AttributeError:
+ base, local = param_name.rsplit(".", 1)
+ old_param = getattr(mdl.get_submodule(base), local)
+ is_f8 = hasattr(torch, "float8_e4m3fn") and empty_param.dtype == torch.float8_e4m3fn
+ casting_dtype = (
+ old_param.dtype if (empty_param.dtype.is_floating_point and not is_f8) else None
+ )
+ return old_param is not None and old_param.is_contiguous(), casting_dtype
+
+ def _finish(st, contig, dtype, offload):
+ if dtype is not None:
+ st = st.to(dtype=dtype)
+ if contig:
+ st = st.contiguous()
+ if offload:
+ st = st.to("cpu")
+ return st
+
+ # --- Step 1: broadcast all entries ---
+ # All ranks iterate meta_sharded_sd in the same order to ensure
+ # identical broadcast sequences. Rank 0 looks up the full parameter
+ # by name — never positional zip (model.to("meta") + fully_shard()
+ # can reorder keys). broadcast_dtypes[idx] is used for the tensor
+ # dtype on ALL ranks to prevent NCCL deadlocks from dtype divergence.
+ for idx, (param_name, sharded_param) in enumerate(meta_sharded_sd.items()):
+ is_dtensor = isinstance(sharded_param, DTensor)
+ bcast_dtype = broadcast_dtypes[idx]
+
+ if not is_dtensor:
+ # Persistent buffer — broadcast raw, no distribute_tensor
+ if accelerator.is_main_process:
+ t = (
+ full_sd[param_name]
+ .detach()
+ .to(device=accelerator.device, dtype=bcast_dtype)
+ )
+ else:
+ t = torch.empty(
+ sharded_param.size(),
+ device=accelerator.device,
+ dtype=bcast_dtype,
+ )
+ dist.broadcast(t, src=0, group=dist.group.WORLD)
+ sharded_sd[param_name] = t
+ continue
+
+ device_mesh = sharded_param.device_mesh
+ if accelerator.is_main_process:
+ ft = (
+ full_sd[param_name]
+ .detach()
+ .to(device=device_mesh.device_type, dtype=bcast_dtype)
+ )
+ if isinstance(ft, DTensor):
+ ft = ft.to_local()
+ else:
+ ft = torch.empty(
+ sharded_param.size(),
+ device=device_mesh.device_type,
+ dtype=bcast_dtype,
+ )
+ dist.broadcast(ft, src=0, group=dist.group.WORLD)
+ st = distribute_tensor(ft, device_mesh, sharded_param.placements)
+ contig, _ = _infer_parameter_dtype(model, param_name, ft)
+ # Use bcast_dtype (from rank 0) instead of the model's local
+ # param dtype — with cpu_ram_efficient_loading, non-rank-0
+ # processes have fp32 meta-device params for DFlash, and
+ # _infer_parameter_dtype would incorrectly cast bf16 back to fp32.
+ is_f8 = hasattr(torch, "float8_e4m3fn") and bcast_dtype == torch.float8_e4m3fn
+ final_dtype = None if is_f8 else bcast_dtype
+ sharded_sd[param_name] = _finish(st, contig, final_dtype, cpu_offload)
+
+ elapsed = time.time() - t0
+ print(
+ f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, "
+ f"loading {len(sharded_sd)} entries into model..."
+ )
+ model.load_state_dict(sharded_sd, assign=True)
+ print(
+ f"[fsdp2_buffer_patch] State dict loaded successfully "
+ f"({time.time() - t0:.1f}s total)"
+ )
+ return model
+
+ fsdp_utils.fsdp2_load_full_state_dict = _patched
+ print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
+ except Exception as e:
+ print(f"[fsdp2_buffer_patch] Patch skipped: {e}")
+
+
+_clip_grad_norm_call_count = 0
+
+
+def _clip_grad_norm(parameters, max_norm, norm_type=2):
+ """Clip gradient norms for FSDP2 DTensor parameters.
+
+ Bypasses DTensor dispatch (which deadlocks with partially-frozen models
+ on the accelerate FSDP2 path) by extracting local tensor shards and
+ doing an explicit all_reduce for the global norm.
+
+ Handles Shard (need all_reduce) and Replicate/regular (already global)
+ placements. Safe for DFlash-only and LoRA co-training.
+ """
+ global _clip_grad_norm_call_count
+ import torch.distributed as dist
+ from torch.distributed.tensor import DTensor
+ from torch.distributed.tensor.placement_types import Shard
+
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+
+ parameters = list(parameters) # materialize generator
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ # Every rank MUST reach the all_reduce below: under FSDP2 sharding (and especially
+ # MoE + LoRA co-training) a rank can legitimately have no grads on a step — e.g. an
+ # expert that received no tokens, so the shard it owns gets no gradient. Early-
+ # returning here while other ranks call all_reduce would deadlock the job. So we
+ # never short-circuit before the collective; an empty-grad rank simply contributes a
+ # zero local norm and clips nothing.
+ if grads:
+ device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device
+ else:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Shard DTensors hold partial data — need all_reduce for global norm.
+ # Replicate DTensors and regular tensors already hold full data.
+ sharded_norm_p = torch.tensor(0.0, device=device)
+ local_norm_p = torch.tensor(0.0, device=device)
+
+ n_sharded = 0
+ n_replicate = 0
+ n_regular = 0
+
+ for g in grads:
+ if isinstance(g, DTensor):
+ is_sharded = any(isinstance(p, Shard) for p in g.placements)
+ t = g._local_tensor.detach().to(torch.float32)
+ n = torch.linalg.vector_norm(t, norm_type)
+ if is_sharded:
+ sharded_norm_p += n.pow(norm_type)
+ n_sharded += 1
+ else:
+ local_norm_p += n.pow(norm_type)
+ n_replicate += 1
+ else:
+ n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type)
+ local_norm_p += n.pow(norm_type)
+ n_regular += 1
+
+ # Symmetric across ranks: reached on every rank regardless of whether this rank had
+ # grads (see note above). Guard for the non-distributed case where local == global.
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM)
+ total_norm = (sharded_norm_p + local_norm_p).pow(1.0 / norm_type)
+
+ clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
+
+ # Debug: log computation breakdown on first 5 calls (no collectives — safe).
+ _clip_grad_norm_call_count += 1
+ _rank0 = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
+ if _clip_grad_norm_call_count <= 5 and _rank0:
+ print(
+ f"[clip_grad_norm debug] call={_clip_grad_norm_call_count} "
+ f"total_norm={total_norm.item():.6f} "
+ f"sharded_norm_p={sharded_norm_p.item():.6f} local_norm_p={local_norm_p.item():.6f} "
+ f"grads={len(grads)} (sharded={n_sharded} replicate={n_replicate} regular={n_regular}) "
+ f"max_norm={max_norm} clip_coef={clip_coef.item():.6f}"
+ )
+ for g in grads:
+ if isinstance(g, DTensor):
+ g._local_tensor.mul_(clip_coef)
+ else:
+ g.mul_(clip_coef)
+
+ return total_norm
+
+
+def patch_accelerator(accelerator):
+ """Replace accelerator's clip_grad_norm_ with FSDP2-safe version."""
+ accelerator.clip_grad_norm_ = _clip_grad_norm
+ print(
+ "[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility"
+ )
+
+
+def log_param_dtypes(model):
+ """Debug aid: log per-rank parameter dtype counts (gated by DFLASH_LOG_PARAM_DTYPES=1).
+
+ Used to verify the FSDP2 dtype synchronization above — after ``fully_shard()`` params
+ are DTensors whose dtype lives on ``_local_tensor``. Off by default; this is purely
+ diagnostic and has no effect on training.
+ """
+ import os
+
+ if os.environ.get("DFLASH_LOG_PARAM_DTYPES") != "1":
+ return
+ rank = int(os.environ.get("RANK", 0))
+ dtypes = {}
+ for name, p in model.named_parameters():
+ dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
+ dtypes.setdefault(dt_key, []).append(name)
+ for dt, names in dtypes.items():
+ print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index f62b099121d..bf247cc3db4 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -33,9 +33,11 @@
import dataclasses
import os
+import fsdp2_buffer_patch
import torch
import transformers
from eagle_utils import (
+ DFlashFSDP2ShardedSDExportCallback,
EagleTrainerWithAccLog,
EagleTrainingPlot,
LoRAWarmupCallback,
@@ -64,6 +66,9 @@
torch.manual_seed(0)
mto.enable_huggingface_checkpointing()
+if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1":
+ fsdp2_buffer_patch.apply()
+
# HF-compatible TrainingArguments with our speculative-decoding extensions, auto-derived
# from :class:`SpecTrainingArgs` so its field set can't drift from the Pydantic recipe schema.
@@ -145,6 +150,23 @@ def init_distributed_env(training_args: transformers.TrainingArguments) -> None:
)
+def _is_hf_format_checkpoint(checkpoint: str | None) -> bool:
+ """True if the checkpoint dir holds consolidated HF weights (from_pretrained-loadable).
+
+ FSDP2 SHARDED_STATE_DICT checkpoints contain only distributed shards
+ (``pytorch_model_fsdp_*/``), no ``model.safetensors`` — those return False, signalling
+ the caller to load the base model and resume via the Trainer instead. This inspects the
+ on-disk format of the *resume* checkpoint, which is a property of the existing bytes and
+ is independent of the current run's save mode (the two can differ across runs), so it's
+ intentionally separate from the save-time FSDP state-dict-type gate used for the export
+ callback.
+ """
+ if not checkpoint:
+ return False
+ hf_files = ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json")
+ return any(os.path.isfile(os.path.join(checkpoint, f)) for f in hf_files)
+
+
def train():
config_path, dry_run, overrides = _parse_cli()
recipe = load_recipe(config_path, overrides=overrides)
@@ -181,7 +203,13 @@ def train():
use_offline_training = recipe.data.mode != "online"
- if checkpoint:
+ # Resume path depends on the existing checkpoint's on-disk format: consolidated HF
+ # weights load via from_pretrained; FSDP sharded checkpoints load the base model and
+ # resume through the Trainer.
+ checkpoint_is_hf = _is_hf_format_checkpoint(checkpoint)
+
+ if checkpoint_is_hf:
+ assert checkpoint is not None # guaranteed by checkpoint_is_hf
with patch_transformers5_params_loading():
model = load_vlm_or_llm(
checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code
@@ -190,6 +218,11 @@ def train():
checkpoint, trust_remote_code=recipe.model.trust_remote_code
)
else:
+ if checkpoint:
+ print_rank_0(
+ f"Checkpoint {checkpoint} is not in HF format (FSDP distributed checkpoint). "
+ f"Loading base model and resuming via Trainer."
+ )
model_name_or_path = recipe.model.model_name_or_path
if model_name_or_path is None:
raise ValueError(
@@ -245,20 +278,6 @@ def train():
)
return
- # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast
- # them. We iterate named_buffers and reassign via the owning module to
- # keep the module tree consistent. Parameters are left on CPU — the HF
- # Trainer will move them during init.
- if torch.cuda.is_available():
- _target_dev = torch.device("cuda", 0)
- for name, buf in list(model.named_buffers()):
- if buf.device.type == "cpu":
- parts = name.split(".")
- mod = model
- for p in parts[:-1]:
- mod = getattr(mod, p)
- setattr(mod, parts[-1], buf.to(_target_dev))
-
print_rank_0("Loading dataset...")
is_dflash = isinstance(recipe, ModelOptDFlashRecipe)
data_module = make_speculative_data_module(
@@ -289,6 +308,26 @@ def train():
**data_module,
)
+ if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1":
+ fsdp2_buffer_patch.patch_accelerator(trainer.accelerator)
+
+ # DFlash: export the draft submodule after each checkpoint save — but only under FSDP2
+ # SHARDED_STATE_DICT, where checkpoints are distributed shards the post-training
+ # export_hf_checkpoint.py pass can't read. Gate by reading the live FSDP state dict
+ # type off the accelerator; full-state-dict runs (DDP, single-device, FSDP2
+ # FULL_STATE_DICT) use the launcher's post-run export instead.
+ if isinstance(recipe, ModelOptDFlashRecipe):
+ fsdp_plugin = getattr(trainer.accelerator.state, "fsdp_plugin", None)
+ sd_type = str(getattr(fsdp_plugin, "state_dict_type", "") or "")
+ if "SHARDED_STATE_DICT" in sd_type:
+ trainer.add_callback(DFlashFSDP2ShardedSDExportCallback())
+ print_rank_0("DFlash: FSDP2 SHARDED_STATE_DICT — enabling per-save draft export.")
+ else:
+ print_rank_0(
+ f"DFlash: checkpoints use {sd_type or 'a full state dict'}; relying on the "
+ "launcher's post-run export (no per-save export callback added)."
+ )
+
# Manually enable this to return loss in eval
trainer.can_return_loss = True
# Make sure label_smoother is None
@@ -296,6 +335,9 @@ def train():
"label_smoother is not supported in speculative decoding!"
)
+ # Diagnostic (no-op unless DFLASH_LOG_PARAM_DTYPES=1): verifies FSDP2 dtype sync.
+ fsdp2_buffer_patch.log_param_dtypes(trainer.model)
+
print_rank_0("Start training...")
trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_state()
diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index 54d6e493c25..95b2de864f4 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -376,11 +376,15 @@ def _export_config(self):
"initializer_range": getattr(base_config, "initializer_range", 0.02),
"attention_bias": getattr(draft_config, "attention_bias", False),
"attention_dropout": getattr(draft_config, "attention_dropout", 0.0),
- "rope_theta": getattr(
- draft_config, "rope_theta", getattr(base_config, "rope_theta", 1000000.0)
+ # Inherit the target's rope_theta: DFlash injects the target's KV into every
+ # draft layer, so the draft's RoPE base must match the target's. (The draft
+ # arch config carries no rope_theta of its own.)
+ "rope_theta": (
+ getattr(base_config, "rope_theta", None)
+ if getattr(base_config, "rope_theta", None) is not None
+ else getattr(draft_config, "rope_theta", 1000000.0)
),
- # DFlash draft uses standard Qwen3 RoPE, not M-RoPE from multimodal models.
- # z-lab uses null; vLLM handles null rope_scaling correctly.
+ # YaRN long-context scaling is injected below (see the rope_scaling block).
"rope_scaling": None,
"tie_word_embeddings": False,
"torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace(
@@ -395,6 +399,12 @@ def _export_config(self):
else:
config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers
+ # Inject the export-time YaRN rope_scaling from the dflash_export_rope_scaling
+ # config field (empty dict disables). Mirrors eagle's eagle_export_rope_scaling.
+ export_rope_scaling = getattr(self.model, "dflash_export_rope_scaling", None)
+ if export_rope_scaling:
+ config["rope_scaling"] = export_rope_scaling
+
return config
def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py
index 7649b2d0357..8da7b6ec93e 100644
--- a/modelopt/torch/speculative/config.py
+++ b/modelopt/torch/speculative/config.py
@@ -132,6 +132,20 @@ class DFlashConfig(ModeloptBaseConfig):
description="Whether to use torch.compile on DFlash forward/loss methods.",
)
+ dflash_export_rope_scaling: dict = ModeloptField(
+ default={},
+ description=(
+ "The rope_scaling config to inject into the exported HuggingFace draft config. "
+ "The DFlash draft trains on a short window but must draft for the target at long "
+ "context, so — mirroring published Eagle3 drafts such as nvidia/Kimi-K2.6-Eagle3 — "
+ "a YaRN rope_scaling is injected at export to extend the training window to the "
+ "target's full context. Example: "
+ '{"type": "yarn", "factor": 48.0, "original_max_position_embeddings": 4096, '
+ '"beta_fast": 1.0, "beta_slow": 1.0, "mscale": 1.0, "mscale_all_dim": 1.0}. '
+ "Set to empty dict {} (default) to disable rope scaling injection at export."
+ ),
+ )
+
class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py
index a99e93c816a..702d9812482 100644
--- a/modelopt/torch/speculative/dflash/dflash_model.py
+++ b/modelopt/torch/speculative/dflash/dflash_model.py
@@ -35,3 +35,4 @@ def modify(self, config):
self.dflash_num_anchors = config.dflash_num_anchors
self.dflash_report_acc = config.dflash_report_acc
self.dflash_use_torch_compile = config.dflash_use_torch_compile
+ self.dflash_export_rope_scaling = config.dflash_export_rope_scaling
diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py
index 1760cb2072d..9e54ee531ce 100644
--- a/modelopt/torch/speculative/plugins/hf_dflash.py
+++ b/modelopt/torch/speculative/plugins/hf_dflash.py
@@ -132,26 +132,45 @@ def modify(self, config):
self.dflash_config.hidden_size = base_config.hidden_size
self.dflash_config.vocab_size = base_config.vocab_size
- # Inherit architecture settings from base model when not specified by user.
- # Static defaults (hidden_act, attention_bias, etc.) are in dflash/default_config.py.
- # NOTE: rope_scaling is intentionally excluded. DFlash draft uses Qwen3
- # RotaryEmbedding which only supports standard RoPE. Inheriting M-RoPE
- # config from multimodal models (e.g. Qwen3.5) would be incorrect.
- _base_model_attrs = [
+ # Inherit architecture settings from base model when not specified by user
+ # (setdefault). Static defaults (hidden_act, attention_bias, etc.) are in
+ # dflash/default_config.py.
+ _setdefault_attrs = [
"max_position_embeddings",
"intermediate_size",
"num_attention_heads",
"num_key_value_heads",
- "rope_theta",
- "rope_type",
- "rope_interleaved",
"rms_norm_eps",
]
- for attr in _base_model_attrs:
+ for attr in _setdefault_attrs:
if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None:
if hasattr(base_config, attr):
setattr(self.dflash_config, attr, getattr(base_config, attr))
+ # RoPE base settings are ENFORCED to match the base model (not setdefault): the
+ # DFlash draft injects the target's KV into every layer, so its RoPE base must
+ # match the target's for the injected positions to align — and the exporter writes
+ # the base model's rope_theta. Letting dflash_architecture_config override these
+ # would make training (draft rope) and inference (base rope) disagree, so we
+ # overwrite any user value and warn. (rope_scaling is intentionally NOT inherited:
+ # DFlash uses standard Qwen3 RotaryEmbedding; the long-context YaRN scaling is
+ # added only at export via dflash_export_rope_scaling.)
+ for attr in ("rope_theta", "rope_type", "rope_interleaved"):
+ if not hasattr(base_config, attr):
+ continue
+ base_val = getattr(base_config, attr)
+ user_val = getattr(self.dflash_config, attr, None)
+ if user_val is not None and user_val != base_val:
+ logger.warning(
+ "DFlash: ignoring dflash_architecture_config.%s=%r and enforcing the "
+ "base model's value %r — the draft injects the target's KV, so its RoPE "
+ "base must match the target's.",
+ attr,
+ user_val,
+ base_val,
+ )
+ setattr(self.dflash_config, attr, base_val)
+
self.dflash_config.head_dim = getattr(
self.dflash_config,
"head_dim",
diff --git a/tests/unit/torch/export/test_hf_spec_rope_export.py b/tests/unit/torch/export/test_hf_spec_rope_export.py
index 171fe6263c6..720082bc617 100644
--- a/tests/unit/torch/export/test_hf_spec_rope_export.py
+++ b/tests/unit/torch/export/test_hf_spec_rope_export.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Unit tests for EAGLE export rope scaling logic in hf_spec_export.py."""
+"""Unit tests for EAGLE/DFlash export rope scaling logic in hf_spec_export.py."""
+from types import SimpleNamespace
from unittest.mock import MagicMock
-from modelopt.torch.export.plugins.hf_spec_export import EagleExporter
+import torch
+
+from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter, EagleExporter
DEFAULT_ROPE_SCALING = {
"rope_type": "yarn",
@@ -76,3 +79,63 @@ def test_rope_theta_fallback_from_rope_scaling():
"""rope_theta is populated from rope_scaling when not available as top-level attr."""
config = _make_exporter(rope_type="default", rope_theta=500000)._export_config()
assert config["rope_theta"] == 500000
+
+
+# ---------------------------------------------------------------------------
+# DFlash export rope scaling (config-field convergence, mirrors the eagle style)
+# ---------------------------------------------------------------------------
+
+DFLASH_YARN = {
+ "type": "yarn",
+ "factor": 48.0,
+ "original_max_position_embeddings": 4096,
+ "beta_fast": 1.0,
+ "beta_slow": 1.0,
+ "mscale": 1.0,
+ "mscale_all_dim": 1.0,
+}
+
+
+def _make_dflash_exporter(dflash_export_rope_scaling=None, base_rope_theta=5000000.0):
+ base_config = SimpleNamespace(
+ hidden_size=128,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ intermediate_size=256,
+ vocab_size=1000,
+ max_position_embeddings=196608,
+ initializer_range=0.02,
+ num_hidden_layers=8,
+ rope_theta=base_rope_theta,
+ torch_dtype=torch.bfloat16,
+ )
+ draft_config = SimpleNamespace(num_hidden_layers=2)
+ model = SimpleNamespace(
+ config=base_config,
+ dflash_config=draft_config,
+ dflash_block_size=8,
+ mask_token_id=999,
+ target_layer_ids=[1, 3, 5, 7],
+ dflash_export_rope_scaling=dflash_export_rope_scaling,
+ )
+ exporter = DFlashExporter.__new__(DFlashExporter)
+ exporter.model = model
+ return exporter
+
+
+def test_dflash_yarn_rope_injected_from_config_field():
+ """YaRN rope_scaling from dflash_export_rope_scaling is injected verbatim."""
+ config = _make_dflash_exporter(dflash_export_rope_scaling=DFLASH_YARN)._export_config()
+ assert config["rope_scaling"] == DFLASH_YARN
+
+
+def test_dflash_rope_not_injected_when_field_empty():
+ """Empty dict (default) disables rope scaling injection."""
+ config = _make_dflash_exporter(dflash_export_rope_scaling={})._export_config()
+ assert config["rope_scaling"] is None
+
+
+def test_dflash_rope_theta_inherits_base():
+ """rope_theta is inherited from the target/base config (draft drafts for the base)."""
+ config = _make_dflash_exporter(base_rope_theta=5000000.0)._export_config()
+ assert config["rope_theta"] == 5000000.0
diff --git a/tools/launcher/common/specdec/dflash_online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh
index 60c2b9901e1..efc38be3b17 100644
--- a/tools/launcher/common/specdec/dflash_online_training.sh
+++ b/tools/launcher/common/specdec/dflash_online_training.sh
@@ -42,6 +42,13 @@ pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirement
pip install huggingface-hub>=1.2.1
export PATH=$PATH:/workspace/.local/bin
+# Some trust_remote_code MoE models pin an older transformers (e.g. MiniMax-M2.7
+# needs 4.57.x; its modeling code is incompatible with the 5.x the recipe defaults
+# pull in). Override here when the model needs it.
+if [ -n "${OVERRIDE_TRANSFORMERS:-}" ]; then
+ pip install "transformers==${OVERRIDE_TRANSFORMERS}"
+fi
+
###################################################################################################
trap 'error_handler $0 $LINENO' ERR
@@ -99,9 +106,18 @@ fi
export TOKENIZERS_PARALLELISM=False
+# ACCELERATE_CONFIG selects an explicit accelerate config file (e.g. an FSDP2 YAML).
+# Required for trust_remote_code MoE models that need FSDP2 via accelerate config
+# rather than transformers-native ParallelismConfig (MiniMax-M2.7 on transformers 4.57.x).
+ACCEL_CONFIG_ARGS=""
+if [ -n "${ACCELERATE_CONFIG:-}" ] && [ -f "${ACCELERATE_CONFIG}" ]; then
+ ACCEL_CONFIG_ARGS="--config_file ${ACCELERATE_CONFIG}"
+ echo "Using accelerate config: ${ACCELERATE_CONFIG}"
+fi
+
set -x
start_time=$(date +%s)
-accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@"
+accelerate launch $ACCEL_CONFIG_ARGS --mixed_precision "${MIXED_PRECISION:-bf16}" $MULTI_NODE_ARGS $MAIN_PY "$@"
echo "Training time: $(( $(date +%s) - start_time )) seconds"
set +x
diff --git a/tools/launcher/core.py b/tools/launcher/core.py
index 7f12c368efe..17d6cb44138 100644
--- a/tools/launcher/core.py
+++ b/tools/launcher/core.py
@@ -316,8 +316,17 @@ def build_slurm_executor(
retries=0,
packager=packager,
srun_args=slurm_config.srun_args,
+ # Copy into a fresh dict so the requeue mutation below doesn't leak back into
+ # the shared slurm_config.additional_parameters.
+ additional_parameters=dict(getattr(slurm_config, "additional_parameters", None) or {}),
**optional_kwargs,
)
+ if getattr(slurm_config, "requeue", False):
+ executor.additional_parameters["requeue"] = True
+ # The nemo-run sbatch wrapper only calls `scontrol requeue` when
+ # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default)
+ # disables this, so bump it when requeue is requested.
+ executor.retries = max(executor.retries, 3)
return executor
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
new file mode 100644
index 00000000000..3310d1e5c7e
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
@@ -0,0 +1,11 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+ fsdp_version: 2
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule
+ fsdp_sharding_strategy: HYBRID_SHARD
+ fsdp_use_orig_params: true
+ fsdp_state_dict_type: SHARDED_STATE_DICT
+ fsdp_sync_module_states: true
+ fsdp_cpu_ram_efficient_loading: true
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
new file mode 100644
index 00000000000..178dc002069
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
@@ -0,0 +1,162 @@
+{# MiniMax-M2.7 chat template with {% generation %} tags for answer_only_loss training.
+ Adapted from https://huggingface.co/MiniMaxAI/MiniMax-M2.7/blob/main/chat_template.jinja
+ with {% generation %} / {% endgeneration %} wrapping assistant content.
+#}
+{%- set toolcall_begin_token = '' -%}
+{%- set toolcall_end_token = '' -%}
+{#- Tool Rendering Functions ============================================== -#}
+{%- macro render_tool_namespace(namespace_name, tool_list) -%}
+{%- for tool in tool_list -%}
+{{ tool.function | tojson(ensure_ascii=False) }}
+{% endfor -%}
+{%- endmacro -%}
+{%- macro visible_text(content) -%}
+ {%- if content is string -%}
+ {{ content }}
+ {%- elif content is iterable and content is not mapping -%}
+ {%- for item in content -%}
+ {%- if item is mapping and item.type == 'text' -%}
+ {{- item.text }}
+ {%- elif item is string -%}
+ {{- item }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{- content }}
+ {%- endif -%}
+{%- endmacro -%}
+{#- System Message Construction ============================================ -#}
+{%- macro build_system_message(system_message) -%}
+ {%- if system_message and system_message.content -%}
+ {{- visible_text(system_message.content) }}
+ {%- else -%}
+ {%- if model_identity is not defined -%}
+ {%- set model_identity = "You are a helpful assistant. Your name is MiniMax-M2.7 and is built by MiniMax." -%}
+ {%- endif -%}
+ {{- model_identity }}
+ {%- endif -%}
+
+ {#- Handle current_date -#}
+ {%- if system_message and system_message.current_date -%}
+ {{- '\n' ~ 'Current date: ' + system_message.current_date }}
+ {%- endif -%}
+ {#- Handle current_location -#}
+ {%- if system_message and system_message.current_location -%}
+ {{- '\n' ~ 'Current location: ' + system_message.current_location }}
+ {%- endif -%}
+{%- endmacro -%}
+{#- Main Template Logic ================================================= -#}
+{#- Extract system message (only first message if it's system) -#}
+{%- set system_message = none -%}
+{%- set conversation_messages = messages -%}
+{%- if messages and messages[0].role == "system" -%}
+ {%- set system_message = messages[0] -%}
+ {%- set conversation_messages = messages[1:] -%}
+{%- endif -%}
+{#- Get the last user message turn, for interleaved thinking -#}
+{%- set ns = namespace(last_user_index=-1) %}
+{% for m in conversation_messages %}
+ {%- if m.role == 'user' %}
+ {% set ns.last_user_index = loop.index0 -%}
+ {%- endif %}
+{%- endfor %}
+{#- Render system message -#}
+{{- ']~!b[' ~ ']~b]system' ~ '\n' }}
+{{- build_system_message(system_message) }}
+{#- Render tools if available -#}
+{%- if tools -%}
+ {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
+ {{- '\n' ~ '' ~ '\n' }}
+ {{- render_tool_namespace("functions", tools) }}
+ {{- '' ~ '\n\n' }}
+{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }}
+{{- '\n' ~ toolcall_begin_token }}
+
+param-value-1
+param-value-2
+...
+
+{{- '\n' ~ toolcall_end_token }}
+{%- endif -%}
+{{- '[e~[\n' }}
+
+{#- Render messages -#}
+{%- set last_tool_call = namespace(name=none) -%}
+{%- for message in conversation_messages -%}
+ {%- if message.role == 'assistant' -%}
+ {{- ']~b]ai' ~ '\n' }}
+ {%- generation -%}
+ {%- set reasoning_content = '' %}
+ {%- set content = visible_text(message.content) %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if '' in content %}
+ {%- set reasoning_content = content.split('')[0].strip('\n').split('')[-1].strip('\n') %}
+ {%- set content = content.split('')[-1].strip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if reasoning_content and loop.index0 > ns.last_user_index -%}
+ {{- '' ~ '\n' ~ reasoning_content ~ '\n' ~ '' ~ '\n\n' }}
+ {%- endif -%}
+ {%- if content -%}
+ {{- content }}
+ {%- endif -%}
+ {%- if message.tool_calls -%}
+ {{- '\n' ~ toolcall_begin_token ~ '\n' }}
+
+ {%- for tool_call in message.tool_calls -%}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '' }}
+ {% set _args = tool_call.arguments %}
+ {%- for k, v in _args.items() %}
+ {{- '' }}
+ {{- v | tojson(ensure_ascii=False) if v is not string else v }}
+ {{- '' }}
+ {% endfor %}
+ {{- '' ~ '\n' }}
+ {%- endfor -%}
+
+ {{- toolcall_end_token}}
+ {%- set last_tool_call.name = message.tool_calls[-1].name -%}
+ {%- else -%}
+ {%- set last_tool_call.name = none -%}
+ {%- endif -%}
+ {%- endgeneration -%}
+ {{- '[e~[' ~ '\n' }}
+
+ {%- elif message.role == 'tool' -%}
+ {%- if last_tool_call.name is none -%}
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
+ {%- endif -%}
+ {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
+ {{- ']~b]tool' }}
+ {%- endif -%}
+ {%- if message.content is string -%}
+ {{- '\n' }}
+ {{- message.content }}
+ {{- '' }}
+ {%- else -%}
+ {%- for tr in message.content -%}
+ {{- '\n' }}
+ {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
+ {{- '\n' }}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
+ {{- '[e~[\n' -}}
+ {%- endif -%}
+
+ {%- elif message.role == 'user' -%}
+ {{- ']~b]user' ~ '\n' }}
+ {{- visible_text(message.content) }}
+ {{- '[e~[' ~ '\n' }}
+ {%- endif -%}
+{%- endfor -%}
+
+{#- Generation prompt -#}
+{%- if add_generation_prompt -%}
+{{- ']~b]ai' ~ '\n' ~ '' ~ '\n' }}
+{%- endif -%}
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml
new file mode 100644
index 00000000000..dc3dba4a65a
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml
@@ -0,0 +1,97 @@
+# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE).
+#
+# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model
+# forward at training time instead):
+# task_0: Dump base-model hidden states once via vLLM extract_hidden_states.
+# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head +
+# embed_tokens, not the full 229B base).
+#
+# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because
+# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The
+# dump disables prefix caching so every token's hidden state is emitted and the dumped
+# sequence lengths line up with input_ids/loss_mask.
+#
+# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
+#
+# Usage:
+# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes
+
+job_name: MiniMax-M2.7-DFlash_offline
+pipeline:
+ global_vars:
+ hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7
+
+ # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4).
+ task_0:
+ script: common/eagle3/dump_offline_data_vllm.sh
+ args:
+ - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2
+ - --output-dir /scratchspace/dflash_minimax_m2.7_hidden_states
+ # Must match the draft model's num_hidden_layers (recipe default: 5).
+ - --aux-layers dflash
+ - --answer-only-loss
+ - --chat-template examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
+ - --max-seq-len 4096
+ - --tp 4
+ environment:
+ - HF_MODEL_CKPT: <>
+ - TRUST_REMOTE_CODE: "1"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
+ container: vllm/vllm-openai:nightly
+
+ # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids
+ # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint.
+ task_1:
+ script: common/specdec/dflash_online_training.sh
+ args:
+ - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+ - model.model_name_or_path=<>
+ - model.trust_remote_code=true
+ - model.use_fake_base_for_offline=true
+ - data.mode=offline
+ - data.offline_data_path=/scratchspace/dflash_minimax_m2.7_hidden_states
+ - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
+ - training.output_dir=/scratchspace/dflash_minimax_m2.7_offline
+ - training.num_train_epochs=10
+ - training.per_device_train_batch_size=2
+ - training.learning_rate=1.2e-3
+ - training.warmup_steps=100
+ - training.training_seq_len=4096
+ - training.logging_steps=100
+ - training.save_steps=400
+ - training.disable_tqdm=true
+ - training.dp_shard_size=1
+ - training.answer_only_loss=true
+ - training.ddp_timeout=3600
+ - training.bf16=false
+ - dflash.dflash_self_logit_distillation=true
+ - dflash.dflash_block_size=8
+ - dflash.dflash_num_anchors=512
+ - dflash.dflash_loss_decay_factor=4.0
+ - dflash.dflash_architecture_config.num_hidden_layers=5
+ # Mask token id (an existing reserved row in MiniMax-M2.7's embedding).
+ - dflash.dflash_mask_token_id=200054
+ # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48).
+ - dflash.dflash_export_rope_scaling.type=yarn
+ - dflash.dflash_export_rope_scaling.factor=48.0
+ - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096
+ - dflash.dflash_export_rope_scaling.beta_fast=1.0
+ - dflash.dflash_export_rope_scaling.beta_slow=1.0
+ - dflash.dflash_export_rope_scaling.mscale=1.0
+ - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
+ environment:
+ - NUM_NODES: "8"
+ # Offline training uses a lightweight FakeBaseModel, so plain DDP suffices (no
+ # ACCELERATE_CONFIG / FSDP2 patches). OVERRIDE_TRANSFORMERS pins 4.57.1 for the
+ # MiniMax config load.
+ - OVERRIDE_TRANSFORMERS: "4.57.1"
+ - MIXED_PRECISION: "no"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 8
+ ntasks_per_node: 1
+ gpus_per_node: 8
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml
new file mode 100644
index 00000000000..92c450ee550
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml
@@ -0,0 +1,108 @@
+# DFlash online speculative decoding training for MiniMax-M2.7 (229B MoE).
+#
+# 3-step pipeline:
+# task_0: Online DFlash training (8 nodes x 8 GPU, FSDP2 HYBRID_SHARD)
+# task_1: vLLM smoke test with the exported DFlash draft
+# task_2: HF AR (acceptance-length) evaluation on MT-Bench (1 GPU)
+#
+# MiniMax-M2.7 specifics:
+# - trust_remote_code model whose code requires transformers 4.57.x, so FSDP2 is
+# configured via an accelerate config (accelerate_fsdp2_hybrid.yaml) rather than
+# transformers-native ParallelismConfig. Hence OVERRIDE_TRANSFORMERS + ACCELERATE_CONFIG
+# + dp_shard_size=1 (keeps main.py from building a ParallelismConfig) + PATCH_FSDP2_BUFFERS_TF457.
+# - 229B in FP8 needs cpu_ram_efficient_loading (set in the accelerate config) and
+# MIXED_PRECISION=no (the recipe / checkpoint already carry the right dtypes).
+# - The DFlash draft is Qwen3-architecture (5 layers); the target/draft architectures
+# are independent by design.
+#
+# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
+#
+# Usage:
+# uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes
+# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes
+
+job_name: MiniMax-M2.7-DFlash_online
+pipeline:
+ global_vars:
+ hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7
+
+ # Step 1: Online DFlash training.
+ task_0:
+ script: common/specdec/dflash_online_training.sh
+ args:
+ - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+ - model.model_name_or_path=<>
+ - model.trust_remote_code=true
+ - data.data_path=/hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2
+ - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja
+ - training.output_dir=/scratchspace/dflash_minimax_m2.7
+ - training.num_train_epochs=10
+ - training.per_device_train_batch_size=2
+ - training.learning_rate=1.2e-3
+ - training.warmup_steps=100
+ - training.training_seq_len=4096
+ - training.logging_steps=100
+ - training.save_steps=400
+ - training.disable_tqdm=true
+ # dp_shard_size=1 stops main.py from creating a ParallelismConfig; the
+ # accelerate config below owns FSDP2 sharding.
+ - training.dp_shard_size=1
+ - training.answer_only_loss=true
+ - training.ddp_timeout=3600
+ - training.bf16=false
+ - dflash.dflash_self_logit_distillation=true
+ - dflash.dflash_block_size=8
+ - dflash.dflash_num_anchors=512
+ - dflash.dflash_loss_decay_factor=4.0
+ - dflash.dflash_architecture_config.num_hidden_layers=5
+ # Mask token id. The draft reuses the target's embed_tokens, so this must be an id
+ # that exists in the target embedding; 200054 is a reserved row in MiniMax-M2.7.
+ - dflash.dflash_mask_token_id=200054
+ # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48).
+ - dflash.dflash_export_rope_scaling.type=yarn
+ - dflash.dflash_export_rope_scaling.factor=48.0
+ - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096
+ - dflash.dflash_export_rope_scaling.beta_fast=1.0
+ - dflash.dflash_export_rope_scaling.beta_slow=1.0
+ - dflash.dflash_export_rope_scaling.mscale=1.0
+ - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0
+ environment:
+ - NUM_NODES: "8"
+ - OVERRIDE_TRANSFORMERS: "4.57.1"
+ - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml
+ - PATCH_FSDP2_BUFFERS_TF457: "1"
+ - MIXED_PRECISION: "no"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 8
+ ntasks_per_node: 1
+ gpus_per_node: 8
+
+ # Step 2: vLLM smoke test with the exported DFlash draft.
+ task_1:
+ script: common/specdec/vllm_smoke_test.sh
+ environment:
+ - HF_MODEL_CKPT: <>
+ - DRAFT_CKPT_DIR: /scratchspace/dflash_minimax_m2.7
+ - SPEC_METHOD: "dflash"
+ - NUM_SPEC_TOKENS: "7"
+ - TP_SIZE: "4"
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
+ container: vllm/vllm-openai:nightly
+
+ # Step 3: HF AR (acceptance-length) evaluation on MT-Bench.
+ task_2:
+ script: common/specdec/ar_eval_mtbench.sh
+ args:
+ - --ckpt_dir /scratchspace/dflash_minimax_m2.7
+ - --osl 512
+ - --steps 7
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
new file mode 100644
index 00000000000..ad538805fb6
--- /dev/null
+++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
@@ -0,0 +1,90 @@
+# SPEED-Bench run for MiniMax-M2.7 with the trained DFlash draft, via vLLM.
+#
+# Two-task pipeline:
+# task_0 qualitative split (nvidia/SPEED-Bench-Internal/qualitative, 80 prompts)
+# task_1 throughput_32k split (nvidia/SPEED-Bench-Internal/throughput_32k, long context)
+#
+# Both use --speculative_algorithm DFLASH with the exported draft. Acceptance length
+# (AL) is the primary metric; the throughput_32k split checks that AL holds at long
+# context (if it drops, add YaRN/rope_scaling to the draft export).
+#
+# Results write to /scratchspace/minimax_m2.7_dflash_vllm//. The
+# pensieve-intern specdec_bench workflow's wrap_up stage publishes these to
+# s3://team-specdec-workgroup/results/minimax_m2.7_dflash_vllm//.
+#
+# Set draft_model to the trained, exported draft checkpoint (an exported-checkpoint-*
+# dir containing model.safetensors), e.g. the output of hf_online_dflash.yaml.
+#
+# Usage:
+# Edit the two --draft_model_dir args to point at your exported draft checkpoint, then:
+# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml --yes
+#
+# NOTE: the placeholder below is `exported-checkpoint-final` — that dir is written by the
+# training launcher's post-run export (dflash_online_training.sh) for the final model. The
+# per-save DFlashFSDP2ShardedSDExportCallback instead writes `exported-checkpoint-`
+# dirs; to bench a specific intermediate step, point at one of those.
+
+job_name: MiniMax-M2.7-DFlash_specdec_bench
+pipeline:
+ global_vars:
+ hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7
+
+ # Step 1: qualitative split — quality / acceptance-length numbers.
+ task_0:
+ script: common/specdec_bench/run.sh
+ args:
+ - --dataset speed
+ - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/qualitative
+ - --engine VLLM
+ - --speculative_algorithm DFLASH
+ - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final
+ - --block_size 8
+ - --tp_size 4
+ - --ep_size 4
+ - --concurrency 32
+ - --output_length 4096
+ - --trust_remote_code
+ - --aa_timing
+ - --show_progress
+ - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/qualitative
+ environment:
+ - HF_MODEL_CKPT: <>
+ - HF_LOCAL: /hf-local
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
+ container: vllm/vllm-openai:nightly
+
+ # Step 2: throughput_32k split — long-context throughput / AL.
+ # --num_requests 80 caps the 1,536-sample split to fit the time limit; --max_seq_len
+ # 40960 = 32K input + 4K output + 4K headroom so vLLM doesn't auto-cap below 36K.
+ task_1:
+ script: common/specdec_bench/run.sh
+ args:
+ - --dataset speed
+ - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k
+ - --engine VLLM
+ - --speculative_algorithm DFLASH
+ - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final
+ - --block_size 8
+ - --tp_size 4
+ - --ep_size 4
+ - --concurrency 8
+ - --num_requests 80
+ - --output_length 4096
+ - --max_seq_len 40960
+ - --trust_remote_code
+ - --aa_timing
+ - --show_progress
+ - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/throughput_32k
+ environment:
+ - HF_MODEL_CKPT: <>
+ - HF_LOCAL: /hf-local
+ slurm_config:
+ _factory_: "slurm_factory"
+ nodes: 1
+ ntasks_per_node: 1
+ gpus_per_node: 8
+ container: vllm/vllm-openai:nightly
diff --git a/tools/launcher/tests/test_slurm_executor.py b/tools/launcher/tests/test_slurm_executor.py
index 900616136e3..6baec571e3b 100644
--- a/tools/launcher/tests/test_slurm_executor.py
+++ b/tools/launcher/tests/test_slurm_executor.py
@@ -34,6 +34,7 @@ def test_scratch_and_modelopt_mounts(self, mock_tunnel, mock_executor):
mock_tunnel.return_value = MagicMock()
slurm_config = MagicMock(
+ requeue=False,
host="test-host",
port=22,
account="test_account",
@@ -77,6 +78,7 @@ def test_scratch_path_uses_experiment_title(self, mock_tunnel, mock_executor):
mock_tunnel.return_value = MagicMock()
slurm_config = MagicMock(
+ requeue=False,
host="host",
port=22,
account="acct",
@@ -112,6 +114,7 @@ def test_tunnel_created_with_correct_params(self, mock_tunnel, mock_executor):
mock_tunnel.return_value = MagicMock()
slurm_config = MagicMock(
+ requeue=False,
host="login.cluster.com",
port=30022,
account="acct",
@@ -150,6 +153,7 @@ def test_executor_params(self, mock_tunnel, mock_executor):
mock_tunnel.return_value = MagicMock()
slurm_config = MagicMock(
+ requeue=False,
host="h",
port=22,
account="my_acct",
@@ -195,6 +199,7 @@ def test_none_container_mounts_handled(self, mock_tunnel, mock_executor):
mock_tunnel.return_value = MagicMock()
slurm_config = MagicMock(
+ requeue=False,
host="h",
port=22,
account="a",
@@ -222,3 +227,42 @@ def test_none_container_mounts_handled(self, mock_tunnel, mock_executor):
# Should not crash; mounts should still include scratch + modelopt + title
mounts = mock_executor.call_args[1]["container_mounts"]
assert len(mounts) >= 3
+
+ @patch("core.run.SlurmExecutor")
+ @patch("core.run.SSHTunnel")
+ def test_requeue_sets_param_and_bumps_retries(self, mock_tunnel, mock_executor):
+ mock_tunnel.return_value = MagicMock()
+ executor = mock_executor.return_value
+ executor.retries = 0
+ executor.additional_parameters = {}
+
+ slurm_config = MagicMock(
+ requeue=True,
+ host="h",
+ port=22,
+ account="a",
+ partition="b",
+ container="c",
+ modelopt_install_path="/m",
+ container_mounts=[],
+ srun_args=[],
+ nodes=1,
+ ntasks_per_node=1,
+ gpus_per_node=1,
+ array=None,
+ )
+
+ build_slurm_executor(
+ user="u",
+ identity=None,
+ slurm_config=slurm_config,
+ experiment_id="e",
+ job_dir="/j",
+ task_name="t",
+ packager=MagicMock(),
+ )
+
+ # requeue=True flags the additional parameter and bumps retries above 0 so
+ # nemo-run's sbatch wrapper actually issues `scontrol requeue` on preemption.
+ assert executor.additional_parameters["requeue"] is True
+ assert executor.retries == 3