From c23d002ede8f408b231f76f70f41417bfe900859 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 16 Apr 2026 17:36:57 -0700 Subject: [PATCH 01/31] fix: auto-add mask token for DFlash when tokenizer lacks one Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with ValueError during DFlash training. Instead of requiring the user to manually set dflash_mask_token_id, add the token to the tokenizer and resize model embeddings automatically. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 9b7a9f44d2e..bdc6014826c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -219,13 +219,16 @@ def train(): # Load draft vocab cache mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache) elif isinstance(recipe, ModelOptDFlashRecipe): - # Fall back to tokenizer.mask_token_id when not set in the recipe; require one of the two. if recipe.dflash.dflash_mask_token_id is None: recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) if recipe.dflash.dflash_mask_token_id is None: - raise ValueError( - "dflash.dflash_mask_token_id is required: set it in the recipe YAML " - "or use a tokenizer that defines mask_token_id." 
+ mask_token = "<|mask|>" + tokenizer.add_special_tokens({"mask_token": mask_token}) + model.resize_token_embeddings(len(tokenizer)) + recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id + print_rank_0( + f"Added {mask_token} (ID={tokenizer.mask_token_id}), " + f"resized embeddings to {len(tokenizer)}" ) dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) From 4a83b91bd69c345cb6f7f3e2126015567a7ad60d Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 17 Apr 2026 09:39:48 -0700 Subject: [PATCH 02/31] feat: add requeue support to build_slurm_executor When slurm_config.requeue is True, set additional_parameters["requeue"] = True so nemo-run emits #SBATCH --requeue in the sbatch script. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ye Yu --- tools/launcher/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 7d53cca0db4..278a14d4aa4 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -297,7 +297,10 @@ def build_slurm_executor( retries=0, packager=packager, srun_args=slurm_config.srun_args, + additional_parameters=getattr(slurm_config, "additional_parameters", None) or {}, ) + if getattr(slurm_config, "requeue", False): + executor.additional_parameters["requeue"] = True return executor From a26022e798ca04c95cd6cc6c134d3258d57a557c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 20 Apr 2026 14:38:36 -0700 Subject: [PATCH 03/31] feat: FSDP2 efficient loading + requeue retries fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0 loads real weights on CPU; other ranks use meta device. FSDP2 distributes from rank 0. Also adds dp_replicate_size auto-computation so dp_replicate * dp_shard * cp == world_size. 2. core.py: Set retries=3 when requeue is requested. 
The nemo-run sbatch wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ye Yu --- tools/launcher/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 278a14d4aa4..6eb9473aa71 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -301,6 +301,10 @@ def build_slurm_executor( ) if getattr(slurm_config, "requeue", False): executor.additional_parameters["requeue"] = True + # The nemo-run sbatch wrapper only calls `scontrol requeue` when + # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default) + # disables this, so bump it when requeue is requested. + executor.retries = max(executor.retries, 3) return executor From b43c7f9f62599a564542568ce2f4fe92c575f4cb Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 1 Jun 2026 14:21:57 -0700 Subject: [PATCH 04/31] fix: restore FSDP2 buffer patch and remove buffer-to-CUDA block for checkpoint resume The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and patch_accelerator() calls and added a buffer-to-CUDA block that moved DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading, non-rank-0 processes have meta-device params, causing _infer_parameter_dtype() to return fp32 instead of bf16 on resume. Also detects FSDP distributed checkpoints (no HF model files) and loads the base model instead of trying from_pretrained on them. 
Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 46 ++++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index bdc6014826c..92954df39cf 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -64,6 +64,11 @@ torch.manual_seed(0) mto.enable_huggingface_checkpointing() +if os.environ.get("PATCH_FSDP2_BUFFERS") == "1": + import fsdp2_buffer_patch + + fsdp2_buffer_patch.apply() + # HF-compatible TrainingArguments with our speculative-decoding extensions, auto-derived # from :class:`SpecTrainingArgs` so its field set can't drift from the Pydantic recipe schema. @@ -181,7 +186,14 @@ def train(): use_offline_training = recipe.data.mode != "online" - if checkpoint: + # Check if checkpoint has HF-format model files (compatible with from_pretrained). + # FSDP distributed checkpoints (pytorch_model_fsdp_*) don't — load base model instead. + _hf_ckpt_files = ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json") + checkpoint_is_hf = checkpoint and any( + os.path.isfile(os.path.join(checkpoint, f)) for f in _hf_ckpt_files + ) + + if checkpoint_is_hf: with patch_transformers5_params_loading(): model = load_vlm_or_llm( checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code @@ -190,6 +202,11 @@ def train(): checkpoint, trust_remote_code=recipe.model.trust_remote_code ) else: + if checkpoint: + print_rank_0( + f"Checkpoint {checkpoint} is not in HF format (FSDP distributed checkpoint). " + f"Loading base model and resuming via Trainer." 
+ ) model_name_or_path = recipe.model.model_name_or_path if model_name_or_path is None: raise ValueError( @@ -224,7 +241,10 @@ def train(): if recipe.dflash.dflash_mask_token_id is None: mask_token = "<|mask|>" tokenizer.add_special_tokens({"mask_token": mask_token}) + orig_dtype = model.dtype model.resize_token_embeddings(len(tokenizer)) + if model.dtype != orig_dtype: + model.to(orig_dtype) recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id print_rank_0( f"Added {mask_token} (ID={tokenizer.mask_token_id}), " @@ -248,20 +268,6 @@ def train(): ) return - # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast - # them. We iterate named_buffers and reassign via the owning module to - # keep the module tree consistent. Parameters are left on CPU — the HF - # Trainer will move them during init. - if torch.cuda.is_available(): - _target_dev = torch.device("cuda", 0) - for name, buf in list(model.named_buffers()): - if buf.device.type == "cpu": - parts = name.split(".") - mod = model - for p in parts[:-1]: - mod = getattr(mod, p) - setattr(mod, parts[-1], buf.to(_target_dev)) - print_rank_0("Loading dataset...") is_dflash = isinstance(recipe, ModelOptDFlashRecipe) data_module = make_speculative_data_module( @@ -296,6 +302,9 @@ def train(): **data_module, ) + if os.environ.get("PATCH_FSDP2_BUFFERS") == "1": + fsdp2_buffer_patch.patch_accelerator(trainer.accelerator) + # Manually enable this to return loss in eval trainer.can_return_loss = True # Make sure label_smoother is None @@ -303,6 +312,13 @@ def train(): "label_smoother is not supported in speculative decoding!" ) + if is_master(): + dtypes = {} + for name, p in trainer.model.named_parameters(): + dtypes.setdefault(str(p.dtype), []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check] {dt}: {len(names)} params (e.g. 
{names[0]})") + print_rank_0("Start training...") trainer.train(resume_from_checkpoint=checkpoint) trainer.save_state() From 967a2c03f05fdb5e03ad624ff2fde86265840437 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 1 Jun 2026 14:50:00 -0700 Subject: [PATCH 05/31] fix: use broadcast dtype instead of local param dtype in FSDP2 buffer patch _infer_parameter_dtype() reads the model's current param dtype to cast the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0 processes have fp32 meta-device params for DFlash, so _infer_parameter_dtype returns fp32 and _finish() casts the correctly- broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0) instead. Also prints dtype_check on all ranks to verify consistency. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ye Yu --- .../fsdp2_buffer_patch.py | 326 ++++++++++++++++++ examples/speculative_decoding/main.py | 13 +- 2 files changed, 333 insertions(+), 6 deletions(-) create mode 100644 examples/speculative_decoding/fsdp2_buffer_patch.py diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py new file mode 100644 index 00000000000..3de31a0ee3b --- /dev/null +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling. + +Problem +------- +accelerate's ``fsdp2_load_full_state_dict`` (called during model preparation +when ``cpu_ram_efficient_loading=True``) iterates ``model.state_dict()`` and +unconditionally accesses ``.device_mesh`` on every entry, assuming they are all +DTensors. After ``fully_shard()``, **parameters** become DTensors but +**persistent buffers** (e.g., rotary-embedding ``inv_freq``) remain plain +``torch.Tensor``. 
This crashes with:: + + AttributeError: 'Tensor' object has no attribute 'device_mesh' + +Additionally, ``cpu_ram_efficient_loading`` causes a dtype divergence: rank 0 +loads the model on CPU (inheriting the checkpoint's dtype, e.g. bfloat16) while +other ranks use ``meta`` device (defaulting to float32 for newly-added modules +like the DFlash head). After ``fully_shard()``, the DTensor dtypes differ +across ranks for these modules. Since ``dist.broadcast()`` requires matching +dtypes and element sizes on all ranks, broadcasting a bfloat16 tensor (2 +bytes/elem) to a float32 receive buffer (4 bytes/elem) causes an NCCL deadlock. + +Why we need FSDP2 via accelerate config (not ParallelismConfig) +--------------------------------------------------------------- +MiniMax-M2.7's ``trust_remote_code`` model code requires **transformers 4.57.x**. +Transformers' native FSDP2 support via ``ParallelismConfig`` requires +**transformers 5.x**. So we fall back to configuring FSDP2 through an +``accelerate`` YAML config file (``accelerate_fsdp2.yaml``), which works with +transformers 4.57.x. We set ``dp_shard_size=1`` to prevent ``main.py`` from +creating a ``ParallelismConfig``, letting the accelerate config handle sharding. + +Why we need cpu_ram_efficient_loading +------------------------------------- +MiniMax-M2.7 is a 229B MoE model (~230 GB in FP8). Each GB200 node has 4 GPUs +and ~800 GB system RAM. Without ``cpu_ram_efficient_loading``, all 4 ranks per +node load the model to CPU simultaneously (4 × 230 GB ≈ 920 GB), exceeding +system RAM and triggering OOM kills. With ``cpu_ram_efficient_loading``, only +rank 0 loads the model; other ranks initialize on ``meta`` device. The weights +are then broadcast via ``fsdp2_load_full_state_dict`` — which is where the bug +hits. + +What this patch does +-------------------- +1. Before accessing ``.device_mesh``, checks ``isinstance(entry, DTensor)``. 
+ For non-DTensor entries (persistent buffers), broadcasts the raw tensor from + rank 0 without calling ``distribute_tensor()``. + +2. All ranks iterate ``model.state_dict()`` (post-shard) in the same order so + broadcast calls match 1-to-1. Rank 0 looks up the full parameter **by key + name** from the pre-shard state dict — never by positional ``zip``, because + ``model.to("meta")`` + ``fully_shard()`` can reorder keys. + +3. **Dtype synchronization**: rank 0 broadcasts a dtype code for each entry + before the main loop. All ranks then use the same dtype for their broadcast + tensors. This fixes the dtype divergence caused by rank 0 loading in + bfloat16 while other ranks default to float32 for newly-added modules. + +Accelerate config constraints (for reference) +---------------------------------------------- +``accelerate_fsdp2.yaml`` also requires: + +- ``fsdp_use_orig_params: true`` — without this, FSDP flattens all params into + FlatParameter, losing per-parameter ``requires_grad`` flags. The DFlash head + can't train because its gradients mix with frozen base model zeros. +- ``fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule`` — + DFlash head params at the model root must be in the wrap policy so they become + DTensors. Without this, ``fsdp2_load_full_state_dict`` also crashes. +- ``fsdp_sync_module_states: true`` — accelerate's launch validator requires it + when ``cpu_ram_efficient_loading`` is enabled, even though FSDP2 ignores it at + runtime (sets it to None with a warning). + +Does NOT affect models on transformers 5.x +------------------------------------------- +This entire workaround exists ONLY because MiniMax-M2.7 requires +transformers 4.57.x. Models that support transformers 5.x (Qwen, Llama, +Nemotron, etc.) use ``ParallelismConfig`` natively by setting +``dp_shard_size > 1`` in the training args. That code path handles buffers +correctly and does not go through ``fsdp2_load_full_state_dict`` at all. 
+No accelerate config file, no ``PATCH_FSDP2_BUFFERS``, no +``OVERRIDE_TRANSFORMERS`` needed. + +When to remove +-------------- +This patch can be removed when EITHER of these happens: + +1. MiniMax updates ``trust_remote_code`` for transformers 5.x, allowing native + ``ParallelismConfig`` (which handles this correctly). +2. Upstream accelerate fixes ``fsdp2_load_full_state_dict`` to skip non-DTensor + entries. Track: https://github.com/huggingface/accelerate + +Activation +---------- +Set ``PATCH_FSDP2_BUFFERS=1`` in the environment to activate. Off by default. +Only needed in MiniMax-M2.7 pipeline YAMLs. +""" + +import torch + + +# Dtype encoding for the broadcast dtype-sync step. +_DTYPE_TO_CODE = { + torch.float32: 0, + torch.bfloat16: 1, + torch.float16: 2, + torch.float8_e4m3fn: 3 if hasattr(torch, "float8_e4m3fn") else -1, +} +_CODE_TO_DTYPE = {v: k for k, v in _DTYPE_TO_CODE.items() if v >= 0} + + +def apply(): + """Patch fsdp2_load_full_state_dict if the buffer bug is present.""" + try: + import accelerate.utils.fsdp_utils as fsdp_utils + from torch.distributed.tensor import DTensor + + _orig = fsdp_utils.fsdp2_load_full_state_dict # noqa: F841 + + def _patched(accelerator, model, full_sd, cpu_offload=False): + import time + import torch.distributed as dist + from torch.distributed.tensor import distribute_tensor + + meta_sharded_sd = model.state_dict() + sharded_sd = {} + n_total = len(meta_sharded_sd) + n_dtensor = sum(1 for v in meta_sharded_sd.values() if isinstance(v, DTensor)) + n_buffer = n_total - n_dtensor + + if accelerator.is_main_process: + print(f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}") + else: + print(f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer)") + t0 = time.time() + + # --- Step 0: broadcast dtype codes from rank 0 --- + # cpu_ram_efficient_loading causes rank 0 to load in bfloat16 while + # other ranks 
default to float32 for newly-added modules (DFlash). + # After fully_shard(), DTensor dtypes diverge across ranks. + # Broadcast rank 0's dtypes so all ranks use the same dtype for + # each broadcast tensor. + if accelerator.is_main_process: + dtype_codes = torch.tensor( + [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) + for name in meta_sharded_sd.keys()], + dtype=torch.int32, device=accelerator.device, + ) + else: + dtype_codes = torch.empty( + n_total, dtype=torch.int32, device=accelerator.device, + ) + dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD) + broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes] + del dtype_codes + + # Infer dtype/contiguity for cast — copied from upstream + def _infer_parameter_dtype(mdl, param_name, empty_param): + try: + old_param = mdl.get_parameter_or_buffer(param_name) + except AttributeError: + base, local = param_name.rsplit(".", 1) + old_param = getattr(mdl.get_submodule(base), local) + is_f8 = hasattr(torch, "float8_e4m3fn") and empty_param.dtype == torch.float8_e4m3fn + casting_dtype = ( + old_param.dtype if (empty_param.dtype.is_floating_point and not is_f8) else None + ) + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _finish(st, contig, dtype, offload): + if dtype is not None: + st = st.to(dtype=dtype) + if contig: + st = st.contiguous() + if offload: + st = st.to("cpu") + return st + + # --- Step 1: broadcast all entries --- + # All ranks iterate meta_sharded_sd in the same order to ensure + # identical broadcast sequences. Rank 0 looks up the full parameter + # by name — never positional zip (model.to("meta") + fully_shard() + # can reorder keys). broadcast_dtypes[idx] is used for the tensor + # dtype on ALL ranks to prevent NCCL deadlocks from dtype divergence. 
+ for idx, (param_name, sharded_param) in enumerate(meta_sharded_sd.items()): + is_dtensor = isinstance(sharded_param, DTensor) + bcast_dtype = broadcast_dtypes[idx] + + if not is_dtensor: + # Persistent buffer — broadcast raw, no distribute_tensor + if accelerator.is_main_process: + t = full_sd[param_name].detach().to( + device=accelerator.device, dtype=bcast_dtype) + else: + t = torch.empty( + sharded_param.size(), + device=accelerator.device, + dtype=bcast_dtype, + ) + dist.broadcast(t, src=0, group=dist.group.WORLD) + sharded_sd[param_name] = t + continue + + device_mesh = sharded_param.device_mesh + if accelerator.is_main_process: + ft = full_sd[param_name].detach().to( + device=device_mesh.device_type, dtype=bcast_dtype) + if isinstance(ft, DTensor): + ft = ft.to_local() + else: + ft = torch.empty( + sharded_param.size(), + device=device_mesh.device_type, + dtype=bcast_dtype, + ) + dist.broadcast(ft, src=0, group=dist.group.WORLD) + st = distribute_tensor(ft, device_mesh, sharded_param.placements) + contig, _ = _infer_parameter_dtype(model, param_name, ft) + # Use bcast_dtype (from rank 0) instead of the model's local + # param dtype — with cpu_ram_efficient_loading, non-rank-0 + # processes have fp32 meta-device params for DFlash, and + # _infer_parameter_dtype would incorrectly cast bf16 back to fp32. 
+ is_f8 = hasattr(torch, "float8_e4m3fn") and bcast_dtype == torch.float8_e4m3fn + final_dtype = None if is_f8 else bcast_dtype + sharded_sd[param_name] = _finish(st, contig, final_dtype, cpu_offload) + + elapsed = time.time() - t0 + print(f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, " + f"loading {len(sharded_sd)} entries into model...") + model.load_state_dict(sharded_sd, assign=True) + print(f"[fsdp2_buffer_patch] State dict loaded successfully " + f"({time.time() - t0:.1f}s total)") + return model + + fsdp_utils.fsdp2_load_full_state_dict = _patched + print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") + except Exception as e: + print(f"[fsdp2_buffer_patch] Patch skipped: {e}") + + +_clip_grad_norm_call_count = 0 + + +def _clip_grad_norm(parameters, max_norm, norm_type=2): + """Clip gradient norms for FSDP2 DTensor parameters. + + Bypasses DTensor dispatch (which deadlocks with partially-frozen models + on the accelerate FSDP2 path) by extracting local tensor shards and + doing an explicit all_reduce for the global norm. + + Handles Shard (need all_reduce) and Replicate/regular (already global) + placements. Safe for DFlash-only and LoRA co-training. + """ + global _clip_grad_norm_call_count + import torch.distributed as dist + from torch.distributed.tensor import DTensor + from torch.distributed.tensor.placement_types import Shard + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + parameters = [p for p in parameters] # materialize generator + max_norm = float(max_norm) + norm_type = float(norm_type) + + grads = [p.grad for p in parameters if p.grad is not None] + if len(grads) == 0: + return torch.tensor(0.0) + + # Shard DTensors hold partial data — need all_reduce for global norm. + # Replicate DTensors and regular tensors already hold full data. 
+ device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device + sharded_norm_p = torch.tensor(0.0, device=device) + local_norm_p = torch.tensor(0.0, device=device) + + n_sharded = 0 + n_replicate = 0 + n_regular = 0 + + for g in grads: + if isinstance(g, DTensor): + is_sharded = any(isinstance(p, Shard) for p in g.placements) + t = g._local_tensor.detach().to(torch.float32) + n = torch.linalg.vector_norm(t, norm_type) + if is_sharded: + sharded_norm_p += n.pow(norm_type) + n_sharded += 1 + else: + local_norm_p += n.pow(norm_type) + n_replicate += 1 + else: + n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type) + local_norm_p += n.pow(norm_type) + n_regular += 1 + + dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM) + total_norm = (sharded_norm_p + local_norm_p).pow(1.0 / norm_type) + + clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0) + + # Debug: log computation breakdown on first 5 calls (no collectives — safe). + _clip_grad_norm_call_count += 1 + if _clip_grad_norm_call_count <= 5 and dist.get_rank() == 0: + print( + f"[clip_grad_norm debug] call={_clip_grad_norm_call_count} " + f"total_norm={total_norm.item():.6f} " + f"sharded_norm_p={sharded_norm_p.item():.6f} local_norm_p={local_norm_p.item():.6f} " + f"grads={len(grads)} (sharded={n_sharded} replicate={n_replicate} regular={n_regular}) " + f"max_norm={max_norm} clip_coef={clip_coef.item():.6f}" + ) + for g in grads: + if isinstance(g, DTensor): + g._local_tensor.mul_(clip_coef) + else: + g.mul_(clip_coef) + + return total_norm + + +def patch_accelerator(accelerator): + """Replace accelerator's clip_grad_norm_ with FSDP2-safe version.""" + accelerator.clip_grad_norm_ = _clip_grad_norm + print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " + "for FSDP2 DTensor compatibility") diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 92954df39cf..f79c68c5d39 100644 --- 
a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -312,12 +312,13 @@ def train(): "label_smoother is not supported in speculative decoding!" ) - if is_master(): - dtypes = {} - for name, p in trainer.model.named_parameters(): - dtypes.setdefault(str(p.dtype), []).append(name) - for dt, names in dtypes.items(): - print(f"[dtype_check] {dt}: {len(names)} params (e.g. {names[0]})") + rank = int(os.environ.get("RANK", 0)) + dtypes = {} + for name, p in trainer.model.named_parameters(): + dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) + dtypes.setdefault(dt_key, []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})") print_rank_0("Start training...") trainer.train(resume_from_checkpoint=checkpoint) From ca89c84892ca96c645d3f8d035eae4dce94152c8 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 3 Jun 2026 11:37:49 -0700 Subject: [PATCH 06/31] feat: restore DFlashExportCallback for per-checkpoint draft export The Pydantic-recipe refactor (7038dec918) dropped DFlashExportCallback, which had exported the DFlash draft submodule after every checkpoint save. Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no model.safetensors) get no exported-checkpoint-{step}/, so downstream vLLM deployment / acceptance-length eval has nothing to load. The verify-only comment 'export happens during training via DFlashExportCallback' was left behind but the callback itself was gone. Restore the callback (gathers only the ~328MB draft submodule across shards via get_model_state_dict, so it works under SHARDED_STATE_DICT without materializing the full base model) and wire it into main.py for DFlash recipes. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 83 ++++++++++++++++++++ examples/speculative_decoding/main.py | 8 ++ 2 files changed, 91 insertions(+) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index f9675e54161..ba929eb9a91 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -228,6 +228,89 @@ def on_step_begin(self, args, state, control, **kwargs): return control +class DFlashExportCallback(TrainerCallback): + """Export DFlash draft module after each checkpoint save. + + Under FSDP2 SHARDED_STATE_DICT, checkpoints only contain distributed shards + (pytorch_model_fsdp_0/), not model.safetensors. This callback extracts the + small draft module weights and saves them in deployment format after each save. + """ + + def on_save(self, args, state, control, **kwargs): + """Export DFlash draft module weights + config after checkpoint save.""" + import json + import os + + from safetensors.torch import save_file + + model = kwargs["model"] + if not hasattr(model, "dflash_module"): + return control + + step = state.global_step + export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}") + + # All ranks participate in state_dict gather (FSDP2 collective op). + # Use get_model_state_dict to get the full (unsharded) weights regardless + # of fsdp_state_dict_type setting. Only the dflash_module submodule is + # gathered (~328 MB for MiniMax-M2.7), not the full 229B base model. 
+ try: + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) + + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + try: + raw_sd = get_model_state_dict( + model, submodules={model.dflash_module}, options=options + ) + except TypeError: + # Older PyTorch without submodules parameter — gather full model + raw_sd = get_model_state_dict(model, options=options) + except ImportError: + # Non-distributed / single-GPU fallback + raw_sd = model.state_dict() + + # Extract dflash_module keys and strip prefix + drafter_sd = {} + for key, value in raw_sd.items(): + if "dflash_module." in key: + export_key = key.split("dflash_module.", 1)[1] + if "rotary_emb" not in export_key: + drafter_sd[export_key] = value.cpu() if value.device.type != "cpu" else value + elif not any(prefix in key for prefix in ("model.", "lm_head.", "embed_tokens.")): + # Keys already without prefix (from submodule state_dict) + if "rotary_emb" not in key: + drafter_sd[key] = value.cpu() if value.device.type != "cpu" else value + del raw_sd + + if not drafter_sd: + print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export") + return control + + # Only rank 0 writes files + if is_master(): + try: + os.makedirs(export_dir, exist_ok=True) + save_file(drafter_sd, os.path.join(export_dir, "model.safetensors")) + + exporter = model.get_exporter() + config = exporter._export_config() + with open(os.path.join(export_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + total_mb = sum(v.nbytes for v in drafter_sd.values()) / 1024 / 1024 + print_rank_0( + f"Exported DFlash draft ({len(drafter_sd)} tensors, {total_mb:.0f}MB) " + f"to {export_dir}" + ) + except Exception as e: + print_rank_0(f"Warning: DFlash export failed at step {step}: {e}") + + return control + + class EagleTrainingPlot(TrainerCallback): """Callback that plot training acc and AR during training.""" diff --git 
a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index f79c68c5d39..832ee0f72a3 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -36,6 +36,7 @@ import torch import transformers from eagle_utils import ( + DFlashExportCallback, EagleTrainerWithAccLog, EagleTrainingPlot, LoRAWarmupCallback, @@ -194,6 +195,7 @@ def train(): ) if checkpoint_is_hf: + assert checkpoint is not None # guaranteed by checkpoint_is_hf with patch_transformers5_params_loading(): model = load_vlm_or_llm( checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code @@ -293,6 +295,12 @@ def train(): training_args.ignore_data_skip = True callbacks.append(StreamingResumeCallback()) + # DFlash: export the draft submodule after every checkpoint save. Under + # FSDP2 SHARDED_STATE_DICT, checkpoint-* dirs hold only distributed shards; + # this callback gathers just the small draft module and writes a deployable + # exported-checkpoint-{step}/ that vLLM (and the AL tests) can load. 
+ if isinstance(recipe, ModelOptDFlashRecipe): + callbacks.append(DFlashExportCallback()) trainer = EagleTrainerWithAccLog( model=model, From 7a215364ef981234ebe738da3353a6b4dfd0240e Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Jun 2026 10:26:34 -0700 Subject: [PATCH 07/31] feat: MiniMax-M2.7-DFlash launcher example (online + offline + specdec_bench) Adds a self-contained launcher example so MiniMax-M2.7 (229B MoE) DFlash training is reproducible end-to-end, plus two common-script enablers it needs: - tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/: - hf_online_dflash.yaml (train -> vLLM smoke -> AR eval; 8-node FSDP2) - hf_offline_dflash.yaml (vLLM hidden-state dump -> offline train) - specdec_bench.yaml (qualitative + throughput_32k, DFLASH) - accelerate_fsdp2_hybrid.yaml, chat_template_train.jinja - common/specdec/dflash_online_training.sh: honor OVERRIDE_TRANSFORMERS, ACCELERATE_CONFIG, and MIXED_PRECISION so trust_remote_code MoE models that need FSDP2 via an accelerate config (MiniMax-M2.7 on transformers 4.57.x) work. - collect_hidden_states/compute_hidden_states_vllm.py: disable prefix caching. With it on, vLLM serves shared prefixes from cache in block chunks and the hidden-state connector emits only the fresh suffix, so dumped hidden_states came out short by N*block_size vs input_ids/loss_mask (observed gaps 0/16/32). Validated: all three YAMLs resolve under launch.py --dryrun; online training and the vLLM hidden-state dump (aligned output) were exercised on CW-DFW. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 5 + .../common/specdec/dflash_online_training.sh | 18 +- .../accelerate_fsdp2_hybrid.yaml | 11 ++ .../chat_template_train.jinja | 162 ++++++++++++++++++ .../hf_offline_dflash.yaml | 86 ++++++++++ .../MiniMax-M2.7-DFlash/hf_online_dflash.yaml | 97 +++++++++++ .../MiniMax-M2.7-DFlash/specdec_bench.yaml | 84 +++++++++ 7 files changed, 462 insertions(+), 1 deletion(-) create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index dd496480cbb..8166960a189 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -177,6 +177,11 @@ def keep_conversation(entry): max_model_len=args.max_seq_len, trust_remote_code=args.trust_remote_code, enable_chunked_prefill=False, # required by extract_hidden_states + # With prefix caching on, vLLM serves shared prefixes from cache in block-sized + # chunks and the hidden-state connector only emits the freshly-computed suffix, so + # the dumped hidden_states come out short by N*block_size vs the full input_ids / + # loss_mask. Disabling it forces a full prefill so every token's state is dumped. 
+ enable_prefix_caching=False, speculative_config={ "method": "extract_hidden_states", "num_speculative_tokens": 1, diff --git a/tools/launcher/common/specdec/dflash_online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh index 60c2b9901e1..efc38be3b17 100644 --- a/tools/launcher/common/specdec/dflash_online_training.sh +++ b/tools/launcher/common/specdec/dflash_online_training.sh @@ -42,6 +42,13 @@ pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirement pip install huggingface-hub>=1.2.1 export PATH=$PATH:/workspace/.local/bin +# Some trust_remote_code MoE models pin an older transformers (e.g. MiniMax-M2.7 +# needs 4.57.x; its modeling code is incompatible with the 5.x the recipe defaults +# pull in). Override here when the model needs it. +if [ -n "${OVERRIDE_TRANSFORMERS:-}" ]; then + pip install "transformers==${OVERRIDE_TRANSFORMERS}" +fi + ################################################################################################### trap 'error_handler $0 $LINENO' ERR @@ -99,9 +106,18 @@ fi export TOKENIZERS_PARALLELISM=False +# ACCELERATE_CONFIG selects an explicit accelerate config file (e.g. an FSDP2 YAML). +# Required for trust_remote_code MoE models that need FSDP2 via accelerate config +# rather than transformers-native ParallelismConfig (MiniMax-M2.7 on transformers 4.57.x). 
+ACCEL_CONFIG_ARGS="" +if [ -n "${ACCELERATE_CONFIG:-}" ] && [ -f "${ACCELERATE_CONFIG}" ]; then + ACCEL_CONFIG_ARGS="--config_file ${ACCELERATE_CONFIG}" + echo "Using accelerate config: ${ACCELERATE_CONFIG}" +fi + set -x start_time=$(date +%s) -accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@" +accelerate launch $ACCEL_CONFIG_ARGS --mixed_precision "${MIXED_PRECISION:-bf16}" $MULTI_NODE_ARGS $MAIN_PY "$@" echo "Training time: $(( $(date +%s) - start_time )) seconds" set +x diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml new file mode 100644 index 00000000000..3310d1e5c7e --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml @@ -0,0 +1,11 @@ +compute_environment: LOCAL_MACHINE +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule + fsdp_sharding_strategy: HYBRID_SHARD + fsdp_use_orig_params: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_cpu_ram_efficient_loading: true diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja new file mode 100644 index 00000000000..178dc002069 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja @@ -0,0 +1,162 @@ +{# MiniMax-M2.7 chat template with {% generation %} tags for answer_only_loss training. + Adapted from https://huggingface.co/MiniMaxAI/MiniMax-M2.7/blob/main/chat_template.jinja + with {% generation %} / {% endgeneration %} wrapping assistant content. 
+#} +{%- set toolcall_begin_token = '' -%} +{%- set toolcall_end_token = '' -%} +{#- Tool Rendering Functions ============================================== -#} +{%- macro render_tool_namespace(namespace_name, tool_list) -%} +{%- for tool in tool_list -%} +{{ tool.function | tojson(ensure_ascii=False) }} +{% endfor -%} +{%- endmacro -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{ content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{#- System Message Construction ============================================ -#} +{%- macro build_system_message(system_message) -%} + {%- if system_message and system_message.content -%} + {{- visible_text(system_message.content) }} + {%- else -%} + {%- if model_identity is not defined -%} + {%- set model_identity = "You are a helpful assistant. Your name is MiniMax-M2.7 and is built by MiniMax." 
-%} + {%- endif -%} + {{- model_identity }} + {%- endif -%} + + {#- Handle current_date -#} + {%- if system_message and system_message.current_date -%} + {{- '\n' ~ 'Current date: ' + system_message.current_date }} + {%- endif -%} + {#- Handle current_location -#} + {%- if system_message and system_message.current_location -%} + {{- '\n' ~ 'Current location: ' + system_message.current_location }} + {%- endif -%} +{%- endmacro -%} +{#- Main Template Logic ================================================= -#} +{#- Extract system message (only first message if it's system) -#} +{%- set system_message = none -%} +{%- set conversation_messages = messages -%} +{%- if messages and messages[0].role == "system" -%} + {%- set system_message = messages[0] -%} + {%- set conversation_messages = messages[1:] -%} +{%- endif -%} +{#- Get the last user message turn, for interleaved thinking -#} +{%- set ns = namespace(last_user_index=-1) %} +{% for m in conversation_messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{#- Render system message -#} +{{- ']~!b[' ~ ']~b]system' ~ '\n' }} +{{- build_system_message(system_message) }} +{#- Render tools if available -#} +{%- if tools -%} + {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }} + {{- '\n' ~ '' ~ '\n' }} + {{- render_tool_namespace("functions", tools) }} + {{- '' ~ '\n\n' }} +{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }} +{{- '\n' ~ toolcall_begin_token }} + +param-value-1 +param-value-2 +... 
+ +{{- '\n' ~ toolcall_end_token }} +{%- endif -%} +{{- '[e~[\n' }} + +{#- Render messages -#} +{%- set last_tool_call = namespace(name=none) -%} +{%- for message in conversation_messages -%} + {%- if message.role == 'assistant' -%} + {{- ']~b]ai' ~ '\n' }} + {%- generation -%} + {%- set reasoning_content = '' %} + {%- set content = visible_text(message.content) %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].strip('\n').split('')[-1].strip('\n') %} + {%- set content = content.split('')[-1].strip('\n') %} + {%- endif %} + {%- endif %} + {%- if reasoning_content and loop.index0 > ns.last_user_index -%} + {{- '' ~ '\n' ~ reasoning_content ~ '\n' ~ '' ~ '\n\n' }} + {%- endif -%} + {%- if content -%} + {{- content }} + {%- endif -%} + {%- if message.tool_calls -%} + {{- '\n' ~ toolcall_begin_token ~ '\n' }} + + {%- for tool_call in message.tool_calls -%} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' }} + {% set _args = tool_call.arguments %} + {%- for k, v in _args.items() %} + {{- '' }} + {{- v | tojson(ensure_ascii=False) if v is not string else v }} + {{- '' }} + {% endfor %} + {{- '' ~ '\n' }} + {%- endfor -%} + + {{- toolcall_end_token}} + {%- set last_tool_call.name = message.tool_calls[-1].name -%} + {%- else -%} + {%- set last_tool_call.name = none -%} + {%- endif -%} + {%- endgeneration -%} + {{- '[e~[' ~ '\n' }} + + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none -%} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif -%} + {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%} + {{- ']~b]tool' }} + {%- endif -%} + {%- if message.content is string -%} + {{- '\n' }} + {{- message.content }} + {{- '' }} + {%- else -%} + {%- for tr in 
message.content -%} + {{- '\n' }} + {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }} + {{- '\n' }} + {%- endfor -%} + {%- endif -%} + {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%} + {{- '[e~[\n' -}} + {%- endif -%} + + {%- elif message.role == 'user' -%} + {{- ']~b]user' ~ '\n' }} + {{- visible_text(message.content) }} + {{- '[e~[' ~ '\n' }} + {%- endif -%} +{%- endfor -%} + +{#- Generation prompt -#} +{%- if add_generation_prompt -%} +{{- ']~b]ai' ~ '\n' ~ '' ~ '\n' }} +{%- endif -%} diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml new file mode 100644 index 00000000000..d81de3c8fbc --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml @@ -0,0 +1,86 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. +# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). +# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. 
+# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/dflash_minimax_m2.7_hidden_states + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - --max-seq-len 4096 + - --tp 4 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/dflash_minimax_m2.7_hidden_states + - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - training.output_dir=/scratchspace/dflash_minimax_m2.7_offline + - training.num_train_epochs=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml new file mode 100644 index 00000000000..bfd4bfc9847 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -0,0 +1,97 @@ +# DFlash online speculative decoding training for MiniMax-M2.7 (229B MoE). 
+# +# 3-step pipeline: +# task_0: Online DFlash training (8 nodes x 8 GPU, FSDP2 HYBRID_SHARD) +# task_1: vLLM smoke test with the exported DFlash draft +# task_2: HF AR (acceptance-length) evaluation on MT-Bench (1 GPU) +# +# MiniMax-M2.7 specifics: +# - trust_remote_code model whose code requires transformers 4.57.x, so FSDP2 is +# configured via an accelerate config (accelerate_fsdp2_hybrid.yaml) rather than +# transformers-native ParallelismConfig. Hence OVERRIDE_TRANSFORMERS + ACCELERATE_CONFIG +# + dp_shard_size=1 (keeps main.py from building a ParallelismConfig) + PATCH_FSDP2_BUFFERS. +# - 229B in FP8 needs cpu_ram_efficient_loading (set in the accelerate config) and +# MIXED_PRECISION=no (the recipe / checkpoint already carry the right dtypes). +# - The DFlash draft is Qwen3-architecture (5 layers); the target/draft architectures +# are independent by design. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_online +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Online DFlash training. 
+ task_0: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - data.data_path=/hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - training.output_dir=/scratchspace/dflash_minimax_m2.7 + - training.num_train_epochs=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + # dp_shard_size=1 stops main.py from creating a ParallelismConfig; the + # accelerate config below owns FSDP2 sharding. + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: vLLM smoke test with the exported DFlash draft. + task_1: + script: common/specdec/vllm_smoke_test.sh + environment: + - HF_MODEL_CKPT: <> + - DRAFT_CKPT_DIR: /scratchspace/dflash_minimax_m2.7 + - SPEC_METHOD: "dflash" + - NUM_SPEC_TOKENS: "7" + - TP_SIZE: "4" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 3: HF AR (acceptance-length) evaluation on MT-Bench. 
+ task_2: + script: common/specdec/ar_eval_mtbench.sh + args: + - --ckpt_dir /scratchspace/dflash_minimax_m2.7 + - --osl 512 + - --steps 7 + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml new file mode 100644 index 00000000000..25107e67917 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml @@ -0,0 +1,84 @@ +# SPEED-Bench run for MiniMax-M2.7 with the trained DFlash draft, via vLLM. +# +# Two-task pipeline: +# task_0 qualitative split (nvidia/SPEED-Bench-Internal/qualitative, 80 prompts) +# task_1 throughput_32k split (nvidia/SPEED-Bench-Internal/throughput_32k, long context) +# +# Both use --speculative_algorithm DFLASH with the exported draft. Acceptance length +# (AL) is the primary metric; the throughput_32k split checks that AL holds at long +# context (if it drops, add YaRN/rope_scaling to the draft export). +# +# Results write to /scratchspace/minimax_m2.7_dflash_vllm/<split>/. The +# pensieve-intern specdec_bench workflow's wrap_up stage publishes these to +# s3://team-specdec-workgroup/results/minimax_m2.7_dflash_vllm/<split>/. +# +# Set --draft_model_dir to the trained, exported draft checkpoint (an exported-checkpoint-* +# dir containing model.safetensors), e.g. the output of hf_online_dflash.yaml. +# +# Usage: +# Edit the two --draft_model_dir args to point at your exported draft checkpoint +# (an exported-checkpoint-* dir with model.safetensors), then: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml --yes + +job_name: MiniMax-M2.7-DFlash_specdec_bench +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: qualitative split — quality / acceptance-length numbers. 
+ task_0: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/qualitative + - --engine VLLM + - --speculative_algorithm DFLASH + - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --tp_size 4 + - --ep_size 4 + - --concurrency 32 + - --output_length 4096 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/qualitative + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: throughput_32k split — long-context throughput / AL. + # --num_requests 80 caps the 1,536-sample split to fit the time limit; --max_seq_len + # 40960 = 32K input + 4K output + 4K headroom so vLLM doesn't auto-cap below 36K. + task_1: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k + - --engine VLLM + - --speculative_algorithm DFLASH + - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --tp_size 4 + - --ep_size 4 + - --concurrency 8 + - --num_requests 80 + - --output_length 4096 + - --max_seq_len 40960 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/throughput_32k + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly From 207f257197cbcb5509f2bfceb8598494b34850e7 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Jun 2026 11:00:12 -0700 Subject: [PATCH 08/31] feat: loss_mask + chat-template support in compute_hidden_states_vllm.py The vLLM hidden-state dump rejected --answer-only-loss / --chat-template and emitted no loss_mask, so DFlash offline training with 
answer-only loss could not use it (the HF dump supports this but is impractical for 229B on one GPU). Mirror the HF dump: register add_answer_only_loss_args, apply an optional override chat template, verify {% generation %} tags, and tokenize via tokenize_with_loss_mask so each .pt carries an aligned loss_mask. Prefix caching is already disabled, so the dumped hidden states line up 1:1 with input_ids/loss_mask. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index 8166960a189..624d915f962 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -28,7 +28,14 @@ from pathlib import Path import torch -from common import add_aux_layers_args, resolve_aux_layers +from common import ( + add_answer_only_loss_args, + add_aux_layers_args, + load_chat_template, + resolve_aux_layers, + tokenize_with_loss_mask, + verify_generation_tags, +) from datasets import load_dataset from tqdm import tqdm from transformers import AutoConfig, AutoTokenizer @@ -63,6 +70,7 @@ def parse_args() -> argparse.Namespace: "--debug-max-num-conversations", type=int, default=None, help="Limit conversations." 
) add_aux_layers_args(parser) + add_answer_only_loss_args(parser) return parser.parse_args() @@ -121,12 +129,18 @@ def keep_conversation(entry): tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + override_template = load_chat_template(args.chat_template) + if override_template is not None: + tokenizer.chat_template = override_template if tokenizer.chat_template is not None: tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + if args.answer_only_loss: + verify_generation_tags(tokenizer.chat_template) # Prepare prompts for vLLM prompts = [] conversation_ids = [] + loss_masks = [] num_skipped_too_long = 0 num_invalid = 0 @@ -137,18 +151,13 @@ def keep_conversation(entry): num_invalid += 1 continue - tokenized = tokenizer.apply_chat_template( - conversations, return_tensors="pt", add_generation_prompt=False + # One apply_chat_template call yields aligned input_ids + loss_mask. With + # --answer-only-loss the mask comes from the template's {% generation %} tags; + # otherwise it is all-ones. Same tokens are sent to vLLM, so the dumped hidden + # states line up with this loss_mask 1:1 (prefix caching is disabled below). 
+ input_ids, loss_mask = tokenize_with_loss_mask( + tokenizer, conversations, args.answer_only_loss ) - # transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids - if hasattr(tokenized, "input_ids"): - input_ids = tokenized.input_ids - elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized: - input_ids = tokenized["input_ids"] - else: - input_ids = tokenized - if not hasattr(input_ids, "shape"): - input_ids = torch.tensor(input_ids) input_ids = input_ids.squeeze(0) num_tokens = input_ids.shape[0] if num_tokens <= 10 or num_tokens > args.max_seq_len: @@ -157,6 +166,7 @@ def keep_conversation(entry): prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist())) conversation_ids.append(conversation_id) + loss_masks.append(loss_mask) print( f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)" @@ -202,10 +212,11 @@ def keep_conversation(entry): # max_tokens=1: we only need a single forward pass over the prompt tokens. outputs = llm.generate(prompts, SamplingParams(max_tokens=1)) - # Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the - # vLLM path does not compute). + # Save in the same format as compute_hidden_states_hf.py, including loss_mask. 
num_success = 0 - for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"): + for conv_id, loss_mask, output in tqdm( + zip(conversation_ids, loss_masks, outputs), total=len(outputs), desc="Saving" + ): hidden_states_path = output.kv_transfer_params.get("hidden_states_path") if hidden_states_path is None: print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping") @@ -232,6 +243,7 @@ def keep_conversation(entry): "input_ids": token_ids.cpu(), "hidden_states": output_hidden_states, "aux_hidden_states": aux_hidden_states, + "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(), "conversation_id": conv_id, }, f, From f63a3879c9affcdf701d743f877c2f27acac8073 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Jun 2026 12:10:52 -0700 Subject: [PATCH 09/31] fix: tolerate missing dist metadata in modelopt.__version__ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit modelopt/__init__.py did __version__ = importlib.metadata.version('nvidia-modelopt') with no guard, so importing modelopt crashes with PackageNotFoundError whenever the source tree is on the path without dist metadata — e.g. the launcher mounts the modelopt source into a vLLM/TRT-LLM container's site-packages rather than pip-installing it. That broke any modelopt import in those containers: collect_hidden_states' resolve_aux_layers (DFlash/EAGLE presets) and specdec_bench (whose guard only caught ModuleNotFoundError, not PackageNotFoundError). Fall back to a sentinel version instead. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- modelopt/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modelopt/__init__.py b/modelopt/__init__.py index 14907827956..0f5517c1924 100644 --- a/modelopt/__init__.py +++ b/modelopt/__init__.py @@ -15,6 +15,14 @@ """Nvidia Model Optimizer (modelopt).""" +from importlib.metadata import PackageNotFoundError from importlib.metadata import version as _version -__version__ = _version("nvidia-modelopt") +try: + __version__ = _version("nvidia-modelopt") +except PackageNotFoundError: + # No dist metadata — e.g. the modelopt source tree is mounted directly into a + # vLLM / TRT-LLM container's site-packages (as the launcher does) instead of being + # pip-installed. Importing modelopt must not crash in that case; downstream tools + # (specdec_bench, collect_hidden_states) only need the package, not its version. + __version__ = "0.0.0+unknown" From b9baedecfe6123e1bee18aa6f27518cc133f785e Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Jun 2026 12:27:02 -0700 Subject: [PATCH 10/31] fix: standalone aux-layer resolution in vLLM dump + --block_size in DFLASH bench Two fixes for the DFlash launcher paths in a stock vLLM container: - compute_hidden_states_vllm.py: resolve_aux_layers('dflash') imports modelopt.torch.speculative.plugins, which drags in the full modelopt.torch init chain (omegaconf, ...) that the vLLM container lacks. Resolve the 'dflash' preset / explicit layer list inline so the dump needs no modelopt at all. - examples/.../specdec_bench.yaml: pass --block_size 8. run.py maps speculative_num_draft_tokens=args.block_size; for DFLASH this must be set or num_speculative_tokens is None and vLLM's max_num_seqs=concurrency*None crashes. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 31 +++++++++++++++++-- .../MiniMax-M2.7-DFlash/specdec_bench.yaml | 2 ++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index 624d915f962..cece13d6e9b 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -32,7 +32,6 @@ add_answer_only_loss_args, add_aux_layers_args, load_chat_template, - resolve_aux_layers, tokenize_with_loss_mask, verify_generation_tags, ) @@ -45,6 +44,34 @@ ) +def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]: + """Resolve aux-layer ids without importing modelopt. + + This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the + 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which + pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM + container does not have, so the import fails. Resolve the 'dflash' preset inline + (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe + default) and accept an explicit comma-separated int list. Keep in sync with modelopt. 
+ """ + spec = aux_layers.strip().lower() + if spec == "dflash": + num_draft = 5 + if num_draft == 1: + return [num_hidden_layers // 2] + start = min(1, num_hidden_layers - 1) + end = max(start, num_hidden_layers - 3) + span = end - start + return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) + ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) + if not ids: + raise ValueError( + f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " + "'dflash' preset or an explicit comma-separated layer-id list are supported." + ) + return ids + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="""Collect hidden states from conversations using vLLM's native extractor.""" @@ -120,7 +147,7 @@ def keep_conversation(entry): num_hidden_layers = getattr(config, "num_hidden_layers", None) if num_hidden_layers is None: raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}") - aux_layer_ids = resolve_aux_layers(args, num_hidden_layers) + aux_layer_ids = _resolve_aux_layers_standalone(args.aux_layers, num_hidden_layers) # The trailing entry is the final output hidden state; the rest are aux layers. 
extract_layer_ids = [*aux_layer_ids, num_hidden_layers] print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)") diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml index 25107e67917..63ea3c47095 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml @@ -34,6 +34,7 @@ pipeline: - --engine VLLM - --speculative_algorithm DFLASH - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --block_size 8 - --tp_size 4 - --ep_size 4 - --concurrency 32 @@ -63,6 +64,7 @@ pipeline: - --engine VLLM - --speculative_algorithm DFLASH - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --block_size 8 - --tp_size 4 - --ep_size 4 - --concurrency 8 From 79634e976d818bd187ecc576b169010845ff478b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 10:00:34 -0700 Subject: [PATCH 11/31] feat: DFlash export injects long-context RoPE (target rope_theta + YaRN) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DFlash drafts train on a short window (e.g. 4096) but draft for the target at long context. Exporting with the draft's minimal rope (rope_theta=1e4, no scaling) collapses long-context acceptance — MiniMax-M2.7 32K acceptance length was 1.17 (near the 1.0 no-speculation floor). Mirror published Eagle3 drafts (nvidia/Kimi-K2.6-Eagle3): inherit the target's rope_theta and inject a YaRN rope_scaling (type=yarn, factor=max_position/ original, original_max_position_embeddings=4096, betas/mscale=1.0), auto-enabled when the target context exceeds the draft's training window. For MiniMax-M2.7 this exports rope_theta=5e6 + factor 48, which recovered 32K AL from 1.17 to 2.62 (≈ its short-context level). 
Override the window via model.dflash_export_rope_original_max. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../torch/export/plugins/hf_spec_export.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 54d6e493c25..607ef6044fa 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -376,11 +376,12 @@ def _export_config(self): "initializer_range": getattr(base_config, "initializer_range", 0.02), "attention_bias": getattr(draft_config, "attention_bias", False), "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), - "rope_theta": getattr( - draft_config, "rope_theta", getattr(base_config, "rope_theta", 1000000.0) - ), - # DFlash draft uses standard Qwen3 RoPE, not M-RoPE from multimodal models. - # z-lab uses null; vLLM handles null rope_scaling correctly. + # Inherit the target's rope_theta — the draft drafts for the base model, so its + # RoPE base must match it. (DFlash trains with a minimal rope; the real + # long-context RoPE is applied here at export.) + "rope_theta": getattr(base_config, "rope_theta", None) + or getattr(draft_config, "rope_theta", 1000000.0), + # YaRN long-context scaling is injected below (see the rope_scaling block). "rope_scaling": None, "tie_word_embeddings": False, "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace( @@ -395,6 +396,25 @@ def _export_config(self): else: config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers + # Long-context RoPE (YaRN). The draft trains on a short window but must draft for + # the target at long context, so — mirroring published Eagle3 drafts such as + # nvidia/Kimi-K2.6-Eagle3 — export a YaRN rope_scaling that extends the training + # window to the target's full context. 
Auto-enabled when the target's + # max_position_embeddings exceeds the draft's training window; override the window + # via model.dflash_export_rope_original_max (defaults to 4096, the usual seq len). + yarn_original_max = int(getattr(self.model, "dflash_export_rope_original_max", 4096)) + target_max = config.get("max_position_embeddings") or 0 + if target_max > yarn_original_max: + config["rope_scaling"] = { + "type": "yarn", + "factor": float(target_max) / float(yarn_original_max), + "original_max_position_embeddings": yarn_original_max, + "beta_fast": 1.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, + } + return config def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): From b54c862a14ac3aa953b15f61950640ebb228d003 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 10:01:40 -0700 Subject: [PATCH 12/31] fix: vLLM hidden-state dump accepts 'messages' key (not just 'conversations') compute_hidden_states_vllm.py read entry['conversations'] and KeyError'd on OpenAI-style data keyed by 'messages' (e.g. the MiniMax-M2.7 synthetic data). Accept either key, matching the lenient loaders elsewhere. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../collect_hidden_states/compute_hidden_states_vllm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index cece13d6e9b..b79f0bfb7c2 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -173,7 +173,9 @@ def keep_conversation(entry): for entry in dataset: conversation_id = entry.get("conversation_id", entry.get("uuid")) - conversations = entry["conversations"] + # Accept either the "conversations" or OpenAI-style "messages" key (the + # MiniMax synthetic data uses "messages"). + conversations = entry.get("conversations") or entry.get("messages") if not conversations or not isinstance(conversations, list): num_invalid += 1 continue From e00aaba8c29da9905517bf60d2ebb0443de00212 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 10:17:43 -0700 Subject: [PATCH 13/31] fix: vLLM hidden-state dump stages on /dev/shm with sync lock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The connector wrote intermediate safetensors to the (lustre) output dir with use_synchronization_lock=False; the client then read each file back immediately and hit FileNotFoundError on a .lock the lock-less producer never wrote. Stage on local tmpfs (/dev/shm, per-DP-rank) and enable the sync lock so the reader waits for the producer — matching the validated dump path. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index b79f0bfb7c2..caf97ad4d3b 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -207,7 +207,11 @@ def keep_conversation(entry): # Initialize vLLM with the native hidden-state extractor. tp = args.tp if args.tp is not None else torch.cuda.device_count() - storage_path = output_dir / ".vllm_hidden_states" + # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre) + # output dir: the producer writes one file per request and the client reads it back + # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so + # parallel shards don't collide. + storage_path = Path("/dev/shm") / f"vllm_hidden_states_dp{args.dp_rank}" storage_path.mkdir(parents=True, exist_ok=True) llm = LLM( @@ -233,7 +237,10 @@ def keep_conversation(entry): kv_role="kv_producer", kv_connector_extra_config={ "shared_storage_path": str(storage_path), - "use_synchronization_lock": False, # batch generation, no concurrent readers + # The client reads each request's safetensors right after generation; the + # lock makes the producer signal completion so the reader doesn't race the + # writer (without it the reader looks for a .lock the producer never wrote). 
+ "use_synchronization_lock": True, }, ), ) From 9dbd693bda7808545721d76b75bcf8f06e397e9b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 10:32:47 -0700 Subject: [PATCH 14/31] review: address CodeRabbit nitpicks + revert modelopt version guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - modelopt/__init__.py: revert the PackageNotFoundError guard (per @kevalmorabia97 — modelopt should be pip-installed so its deps are present; the dump no longer imports modelopt anyway, using inline aux-layer resolution). - compute_hidden_states_vllm.py: validate explicit --aux-layers ids are in [0, num_hidden_layers); skip-with-warning when loss_mask is shorter than the dumped hidden states (avoids silent slice-to-self misalignment); add TODO to reuse common.resolve_aux_layers once decoupled from modelopt.torch. - fsdp2_buffer_patch.py: clip_grad_norm empty-grad path returns a CUDA tensor to match the normal path's device. - launcher/core.py: copy additional_parameters into a fresh dict before the requeue mutation so it doesn't leak into the shared slurm_config. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 21 +++++++++++++++++++ .../fsdp2_buffer_patch.py | 5 ++++- modelopt/__init__.py | 10 +-------- tools/launcher/core.py | 4 +++- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index caf97ad4d3b..0a721fb5ad9 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -53,6 +53,9 @@ def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> l container does not have, so the import fails. 
Resolve the 'dflash' preset inline (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe default) and accept an explicit comma-separated int list. Keep in sync with modelopt. + + TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy + ``modelopt.torch`` import chain so it can be reused directly in a vLLM container. """ spec = aux_layers.strip().lower() if spec == "dflash": @@ -64,6 +67,13 @@ def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> l span = end - start return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) + # Match the shared helper's contract: ids must be valid layer indices. + out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers] + if out_of_range: + raise ValueError( + f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers}) " + f"for a {num_hidden_layers}-layer model." + ) if not ids: raise ValueError( f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " @@ -272,6 +282,17 @@ def keep_conversation(entry): else: aux_hidden_states = torch.empty(0) + # loss_mask is sliced to the dumped length below; a shorter loss_mask would slice + # to itself and silently misalign with the hidden states, so guard explicitly. 
+ n_hs = output_hidden_states.shape[0] + if loss_mask.shape[0] < n_hs: + print( + f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden " + f"states ({n_hs}); skipping to avoid misalignment" + ) + num_error += 1 + continue + output_file = output_dir / f"{conv_id}.pt" with open(output_file, "wb") as f: torch.save( diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index 3de31a0ee3b..96d2b89fbbe 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -267,7 +267,10 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): grads = [p.grad for p in parameters if p.grad is not None] if len(grads) == 0: - return torch.tensor(0.0) + # Match the device of the normal return path (GPU when training on CUDA) so + # callers don't hit a device mismatch on the empty-grad case. + dev = "cuda" if torch.cuda.is_available() else "cpu" + return torch.tensor(0.0, device=dev) # Shard DTensors hold partial data — need all_reduce for global norm. # Replicate DTensors and regular tensors already hold full data. diff --git a/modelopt/__init__.py b/modelopt/__init__.py index 0f5517c1924..14907827956 100644 --- a/modelopt/__init__.py +++ b/modelopt/__init__.py @@ -15,14 +15,6 @@ """Nvidia Model Optimizer (modelopt).""" -from importlib.metadata import PackageNotFoundError from importlib.metadata import version as _version -try: - __version__ = _version("nvidia-modelopt") -except PackageNotFoundError: - # No dist metadata — e.g. the modelopt source tree is mounted directly into a - # vLLM / TRT-LLM container's site-packages (as the launcher does) instead of being - # pip-installed. Importing modelopt must not crash in that case; downstream tools - # (specdec_bench, collect_hidden_states) only need the package, not its version. 
- __version__ = "0.0.0+unknown" +__version__ = _version("nvidia-modelopt") diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 6eb9473aa71..69f516ccceb 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -297,7 +297,9 @@ def build_slurm_executor( retries=0, packager=packager, srun_args=slurm_config.srun_args, - additional_parameters=getattr(slurm_config, "additional_parameters", None) or {}, + # Copy into a fresh dict so the requeue mutation below doesn't leak back into + # the shared slurm_config.additional_parameters. + additional_parameters=dict(getattr(slurm_config, "additional_parameters", None) or {}), ) if getattr(slurm_config, "requeue", False): executor.additional_parameters["requeue"] = True From 13167609b3cd2e6ad47eaa995e143eb1386fc739 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 10:35:31 -0700 Subject: [PATCH 15/31] style: ruff format fsdp2_buffer_patch.py (license header + line wrapping) Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../fsdp2_buffer_patch.py | 72 +++++++++++++------ 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index 96d2b89fbbe..01a31fe7eec 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -98,7 +113,6 @@ import torch - # Dtype encoding for the broadcast dtype-sync step. _DTYPE_TO_CODE = { torch.float32: 0, @@ -115,10 +129,11 @@ def apply(): import accelerate.utils.fsdp_utils as fsdp_utils from torch.distributed.tensor import DTensor - _orig = fsdp_utils.fsdp2_load_full_state_dict # noqa: F841 + _orig = fsdp_utils.fsdp2_load_full_state_dict def _patched(accelerator, model, full_sd, cpu_offload=False): import time + import torch.distributed as dist from torch.distributed.tensor import distribute_tensor @@ -129,11 +144,15 @@ def _patched(accelerator, model, full_sd, cpu_offload=False): n_buffer = n_total - n_dtensor if accelerator.is_main_process: - print(f"[fsdp2_buffer_patch] State dict: {n_total} entries " - f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}") + print( + f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}" + ) else: - print(f"[fsdp2_buffer_patch] State dict: {n_total} entries " - f"({n_dtensor} DTensor, {n_buffer} buffer)") + print( + f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer)" + ) t0 = time.time() # --- Step 0: broadcast dtype codes from rank 0 --- @@ -144,13 +163,15 @@ def _patched(accelerator, model, full_sd, cpu_offload=False): # each broadcast tensor. 
if accelerator.is_main_process: dtype_codes = torch.tensor( - [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) - for name in meta_sharded_sd.keys()], - dtype=torch.int32, device=accelerator.device, + [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd.keys()], + dtype=torch.int32, + device=accelerator.device, ) else: dtype_codes = torch.empty( - n_total, dtype=torch.int32, device=accelerator.device, + n_total, + dtype=torch.int32, + device=accelerator.device, ) dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD) broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes] @@ -191,8 +212,11 @@ def _finish(st, contig, dtype, offload): if not is_dtensor: # Persistent buffer — broadcast raw, no distribute_tensor if accelerator.is_main_process: - t = full_sd[param_name].detach().to( - device=accelerator.device, dtype=bcast_dtype) + t = ( + full_sd[param_name] + .detach() + .to(device=accelerator.device, dtype=bcast_dtype) + ) else: t = torch.empty( sharded_param.size(), @@ -205,8 +229,11 @@ def _finish(st, contig, dtype, offload): device_mesh = sharded_param.device_mesh if accelerator.is_main_process: - ft = full_sd[param_name].detach().to( - device=device_mesh.device_type, dtype=bcast_dtype) + ft = ( + full_sd[param_name] + .detach() + .to(device=device_mesh.device_type, dtype=bcast_dtype) + ) if isinstance(ft, DTensor): ft = ft.to_local() else: @@ -227,11 +254,15 @@ def _finish(st, contig, dtype, offload): sharded_sd[param_name] = _finish(st, contig, final_dtype, cpu_offload) elapsed = time.time() - t0 - print(f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, " - f"loading {len(sharded_sd)} entries into model...") + print( + f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, " + f"loading {len(sharded_sd)} entries into model..." 
+ ) model.load_state_dict(sharded_sd, assign=True) - print(f"[fsdp2_buffer_patch] State dict loaded successfully " - f"({time.time() - t0:.1f}s total)") + print( + f"[fsdp2_buffer_patch] State dict loaded successfully " + f"({time.time() - t0:.1f}s total)" + ) return model fsdp_utils.fsdp2_load_full_state_dict = _patched @@ -325,5 +356,6 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): def patch_accelerator(accelerator): """Replace accelerator's clip_grad_norm_ with FSDP2-safe version.""" accelerator.clip_grad_norm_ = _clip_grad_norm - print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " - "for FSDP2 DTensor compatibility") + print( + "[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility" + ) From 45f3a1e84791333b5f7f8fc05ac1c337c7016deb Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 11:28:16 -0700 Subject: [PATCH 16/31] review: converge DFlash export RoPE to a config field; gate export callback on FSDP2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses @hguo-nv's PR feedback: - YaRN export rope_scaling now comes from a config field (dflash_export_rope_scaling on DFlashConfig, copied onto the model in modify(), read by DFlashExporter), mirroring eagle's eagle_export_rope_scaling convention — replacing the auto-derive from max_position_embeddings. Default {} disables injection. The MiniMax recipes set the validated YaRN config (factor=48, original_max=4096, betas/mscale=1.0). - DFlashExportCallback is now appended only under FSDP2 (dp_shard_size>1 or PATCH_FSDP2_BUFFERS=1). Under DDP, checkpoint-* dirs are already full and the launcher script's post-run export_hf_checkpoint.py handles them, so the per-step draft-gather callback is unnecessary. 
- Offline MiniMax recipe switched from FSDP2 to plain DDP: offline training uses a lightweight FakeBaseModel (embeddings + lm_head only, not the 229B base), which fits under DDP and bypasses the FSDP2 buffer/dtype patches entirely. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 24 +++-- .../torch/export/plugins/hf_spec_export.py | 23 ++--- modelopt/torch/speculative/config.py | 14 +++ .../torch/speculative/dflash/dflash_model.py | 1 + .../MiniMax-M2.7-DFlash/_smoke_offline.yaml | 87 +++++++++++++++++++ .../MiniMax-M2.7-DFlash/_smoke_offline6.yaml | 87 +++++++++++++++++++ .../MiniMax-M2.7-DFlash/_smoke_offline7.yaml | 87 +++++++++++++++++++ .../MiniMax-M2.7-DFlash/_smoke_offline8.yaml | 87 +++++++++++++++++++ .../MiniMax-M2.7-DFlash/_thru32k_test.yaml | 32 +++++++ .../MiniMax-M2.7-DFlash/_thru32k_yarn.yaml | 32 +++++++ .../hf_offline_dflash.yaml | 17 +++- .../MiniMax-M2.7-DFlash/hf_online_dflash.yaml | 10 +++ 12 files changed, 478 insertions(+), 23 deletions(-) create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml create mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 832ee0f72a3..8d368ddfe1c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -295,12 +295,26 @@ def train(): training_args.ignore_data_skip = True callbacks.append(StreamingResumeCallback()) - # DFlash: export the draft submodule after every checkpoint save. 
Under - # FSDP2 SHARDED_STATE_DICT, checkpoint-* dirs hold only distributed shards; - # this callback gathers just the small draft module and writes a deployable - # exported-checkpoint-{step}/ that vLLM (and the AL tests) can load. + # DFlash: export the draft submodule after every checkpoint save — but only under + # FSDP2. With FSDP2 SHARDED_STATE_DICT, checkpoint-* dirs hold only distributed + # shards that the post-training export_hf_checkpoint.py pass can't read, so this + # callback gathers just the small draft module per save and writes a deployable + # exported-checkpoint-{step}/. Under DDP (e.g. offline FakeBaseModel training, or + # any single-device recipe) checkpoints are already full and the launcher script's + # post-run export handles them, so the callback is unnecessary overhead. + # FSDP2 is active via either route: native ParallelismConfig (dp_shard_size > 1) or + # the accelerate-config fallback used for transformers 4.57.x (PATCH_FSDP2_BUFFERS). if isinstance(recipe, ModelOptDFlashRecipe): - callbacks.append(DFlashExportCallback()) + uses_fsdp2 = (getattr(training_args, "dp_shard_size", 1) or 1) > 1 or os.environ.get( + "PATCH_FSDP2_BUFFERS" + ) == "1" + if uses_fsdp2: + callbacks.append(DFlashExportCallback()) + else: + print_rank_0( + "DFlash: non-FSDP2 run detected — skipping per-step DFlashExportCallback; " + "checkpoints are full and will be exported post-training by the launcher script." + ) trainer = EagleTrainerWithAccLog( model=model, diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 607ef6044fa..b35888dc1d2 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -398,22 +398,13 @@ def _export_config(self): # Long-context RoPE (YaRN). 
The draft trains on a short window but must draft for # the target at long context, so — mirroring published Eagle3 drafts such as - # nvidia/Kimi-K2.6-Eagle3 — export a YaRN rope_scaling that extends the training - # window to the target's full context. Auto-enabled when the target's - # max_position_embeddings exceeds the draft's training window; override the window - # via model.dflash_export_rope_original_max (defaults to 4096, the usual seq len). - yarn_original_max = int(getattr(self.model, "dflash_export_rope_original_max", 4096)) - target_max = config.get("max_position_embeddings") or 0 - if target_max > yarn_original_max: - config["rope_scaling"] = { - "type": "yarn", - "factor": float(target_max) / float(yarn_original_max), - "original_max_position_embeddings": yarn_original_max, - "beta_fast": 1.0, - "beta_slow": 1.0, - "mscale": 1.0, - "mscale_all_dim": 1.0, - } + # nvidia/Kimi-K2.6-Eagle3 — inject a YaRN rope_scaling that extends the training + # window to the target's full context. Sourced from the config field + # dflash_export_rope_scaling (set in the recipe YAML), matching the eagle + # eagle_export_rope_scaling convention. Empty dict (default) disables injection. + export_rope_scaling = getattr(self.model, "dflash_export_rope_scaling", None) + if export_rope_scaling: + config["rope_scaling"] = export_rope_scaling return config diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 6b2c9396ce7..11b591ae781 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -118,6 +118,20 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) + dflash_export_rope_scaling: dict = ModeloptField( + default={}, + description=( + "The rope_scaling config to inject into the exported HuggingFace draft config. 
" + "The DFlash draft trains on a short window but must draft for the target at long " + "context, so — mirroring published Eagle3 drafts such as nvidia/Kimi-K2.6-Eagle3 — " + "a YaRN rope_scaling is injected at export to extend the training window to the " + "target's full context. Example: " + '{"type": "yarn", "factor": 48.0, "original_max_position_embeddings": 4096, ' + '"beta_fast": 1.0, "beta_slow": 1.0, "mscale": 1.0, "mscale_all_dim": 1.0}. ' + "Set to empty dict {} (default) to disable rope scaling injection at export." + ), + ) + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index a99e93c816a..702d9812482 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -35,3 +35,4 @@ def modify(self, config): self.dflash_num_anchors = config.dflash_num_anchors self.dflash_report_acc = config.dflash_report_acc self.dflash_use_torch_compile = config.dflash_use_torch_compile + self.dflash_export_rope_scaling = config.dflash_export_rope_scaling diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml new file mode 100644 index 00000000000..924e1bcfbb4 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml @@ -0,0 +1,87 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. +# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). 
+# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline_smoke5 +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/smoke_hs5 + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - --max-seq-len 4096 + - --tp 4 + - --debug-max-num-conversations 16 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/smoke_hs5 + - data.chat_template=modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - training.output_dir=/scratchspace/smoke_offline5 + - training.max_steps=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml new file mode 100644 index 00000000000..0ae74f85fcc --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml @@ -0,0 +1,87 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). 
+# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. +# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). +# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline_smoke6 +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/smoke_hs6 + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - --max-seq-len 4096 + - --tp 4 + - --debug-max-num-conversations 16 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/smoke_hs6 + - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - training.output_dir=/scratchspace/smoke_offline6 + - training.max_steps=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml new file mode 100644 index 00000000000..3186a9fdc58 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml @@ -0,0 +1,87 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. 
+# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). +# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline_smoke7 +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/smoke_hs7 + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - --max-seq-len 4096 + - --tp 4 + - --debug-max-num-conversations 16 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/smoke_hs7 + - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - training.output_dir=/scratchspace/smoke_offline7 + - training.max_steps=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml new file mode 100644 index 00000000000..1fe11b90e66 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml @@ -0,0 +1,87 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. 
+# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). +# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline_smoke8 +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/smoke_hs8 + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - --max-seq-len 4096 + - --tp 4 + - --debug-max-num-conversations 16 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/smoke_hs8 + - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja + - training.output_dir=/scratchspace/smoke_offline8 + - training.max_steps=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml new file mode 100644 index 00000000000..d33adb35557 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml @@ -0,0 +1,32 @@ +job_name: MiniMax-M2.7-DFlash_thru32k_test_v2 +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + task_1: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k + - --engine VLLM + - 
--speculative_algorithm DFLASH + - --draft_model_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/exported-checkpoint-20400 + - --block_size 8 + - --tp_size 4 + - --ep_size 4 + - --concurrency 8 + - --num_requests 80 + - --output_length 4096 + - --max_seq_len 40960 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/thru32k_results_v2 + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: slurm_factory + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml new file mode 100644 index 00000000000..8aa634bbd9c --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml @@ -0,0 +1,32 @@ +job_name: MiniMax-M2.7-DFlash_thru32k_test_yarn +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + task_1: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k + - --engine VLLM + - --speculative_algorithm DFLASH + - --draft_model_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/exported-checkpoint-20400-yarn + - --block_size 8 + - --tp_size 4 + - --ep_size 4 + - --concurrency 8 + - --num_requests 80 + - --output_length 4096 + - --max_seq_len 40960 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/thru32k_results_yarn + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: slurm_factory + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly diff 
--git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml index d81de3c8fbc..a63891cd14d 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml @@ -73,11 +73,24 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 + # Long-context RoPE injected at export (YaRN), same as the online recipe — the + # offline path also exports a draft and must match the target's long context. + # factor = 196608/4096 = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. + - dflash.dflash_export_rope_scaling.type=yarn + - dflash.dflash_export_rope_scaling.factor=48.0 + - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 + - dflash.dflash_export_rope_scaling.beta_fast=1.0 + - dflash.dflash_export_rope_scaling.beta_slow=1.0 + - dflash.dflash_export_rope_scaling.mscale=1.0 + - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 environment: - NUM_NODES: "8" + # Offline training uses a lightweight FakeBaseModel (embeddings + lm_head only, + # not the 229B base), so it fits comfortably under plain DDP — no FSDP needed. + # Leaving ACCELERATE_CONFIG unset makes the launcher fall back to DDP, which also + # bypasses the FSDP2 buffer/dtype patches entirely. OVERRIDE_TRANSFORMERS is still + # required: the MiniMax config (loaded for the fake base's dims) needs 4.57.1. 
- OVERRIDE_TRANSFORMERS: "4.57.1" - - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" - MIXED_PRECISION: "no" slurm_config: _factory_: "slurm_factory" diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml index bfd4bfc9847..74bd28acbed 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -55,6 +55,16 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 + # Long-context RoPE injected at export (YaRN). DFlash trains on a short window + # (~4096) but drafts for MiniMax-M2.7 at up to 196608 tokens; factor = 196608/4096 + # = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. Validated: 32k AL 1.17 -> 2.62. + - dflash.dflash_export_rope_scaling.type=yarn + - dflash.dflash_export_rope_scaling.factor=48.0 + - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 + - dflash.dflash_export_rope_scaling.beta_fast=1.0 + - dflash.dflash_export_rope_scaling.beta_slow=1.0 + - dflash.dflash_export_rope_scaling.mscale=1.0 + - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 environment: - NUM_NODES: "8" - OVERRIDE_TRANSFORMERS: "4.57.1" From 035bad7059b288ebee6cf0e536de8b5b6a4028b7 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 11:32:37 -0700 Subject: [PATCH 17/31] review: avoid embedding resize for DFlash mask token; reuse an existing reserved row MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses @hguo-nv's question (How do we ship the model after resizing its word embedding?). The DFlash draft ships no embeddings — masked positions are embedded via the base/target embed_tokens, and at deployment vLLM reuses the target's table. 
Appending a '<|mask|>' row by resizing is therefore unsafe: with the base frozen the row is never trained, it is never exported, and it is absent from the target at inference. - main.py: when no mask token is configured, reuse an existing reserved embedding row (embedding padded past the used vocab) instead of resizing, so train==deploy by construction. Resize remains only as a last resort, now with a loud warning. - MiniMax recipes: pin dflash_mask_token_id=200054 (an existing reserved row; the embedding has 200064 rows, real tokens are 0..200053). Matches the released checkpoint's behavior while making it deterministic and resize-free. - doc/dflash.md: document the mask-id / embedding-sharing constraint. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/doc/dflash.md | 13 ++++- examples/speculative_decoding/main.py | 48 ++++++++++++++----- .../hf_offline_dflash.yaml | 5 ++ .../MiniMax-M2.7-DFlash/hf_online_dflash.yaml | 7 +++ 4 files changed, 61 insertions(+), 12 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 44db5d39e72..c48e8918c0b 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -164,9 +164,20 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | -| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | +| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | +| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `training.answer_only_loss` | false | Mask loss on 
non-assistant tokens | +> **Note on `dflash_mask_token_id`:** masked positions are embedded with the **base/target +> model's** `embed_tokens` (the draft ships no embeddings — it reuses the target's at +> deployment). So the mask id must be a token that **physically exists in the target +> embedding**. If left unset and the tokenizer has no mask token, training prefers an +> existing reserved row (when the embedding is padded past the used vocab) over resizing, +> because a resized-in row would be neither trained (the base is frozen) nor exported, and +> absent from the target at inference. For production models, pin `dflash_mask_token_id` to +> an existing reserved token id — e.g. MiniMax-M2.7 uses `200054` (its embedding has 200064 +> rows; tokens 0..200053 are real, 200054+ are reserved). + > **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the > tokenizer's chat template must include `{% generation %}` / `{% endgeneration %}` tags > around assistant content. 
HuggingFace uses these tags to produce `assistant_masks` via diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 8d368ddfe1c..4e96e7fedb0 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -241,17 +241,43 @@ def train(): if recipe.dflash.dflash_mask_token_id is None: recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) if recipe.dflash.dflash_mask_token_id is None: - mask_token = "<|mask|>" - tokenizer.add_special_tokens({"mask_token": mask_token}) - orig_dtype = model.dtype - model.resize_token_embeddings(len(tokenizer)) - if model.dtype != orig_dtype: - model.to(orig_dtype) - recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id - print_rank_0( - f"Added {mask_token} (ID={tokenizer.mask_token_id}), " - f"resized embeddings to {len(tokenizer)}" - ) + # The DFlash draft ships NO embeddings — masked positions are embedded via + # the base/target embed_tokens at mask_token_id, and at deployment vLLM + # reuses the *target's* embed table. So the mask id must be a row that + # physically exists in the target embedding. Resizing to append a new + # "<|mask|>" row is unsafe: with the base model frozen the row is never + # trained, it is never exported (the draft has no embeddings), and it is + # absent from the target at inference. Prefer an existing reserved row — + # many tokenizers leave the embedding padded past the used vocab — so that + # train and deploy resolve the identical frozen embedding by construction. + embed = model.get_input_embeddings() + n_phys = embed.weight.shape[0] + n_used = len(tokenizer) + if n_phys > n_used: + recipe.dflash.dflash_mask_token_id = n_used + print_rank_0( + f"DFlash: no mask token configured; reusing existing reserved " + f"embedding row {n_used} as mask_token_id (embedding has {n_phys} " + f"rows, tokenizer vocab {n_used}). 
No resize — the row already " + f"exists in the target and is frozen, so train==deploy." + ) + else: + mask_token = "<|mask|>" + tokenizer.add_special_tokens({"mask_token": mask_token}) + orig_dtype = model.dtype + model.resize_token_embeddings(len(tokenizer)) + if model.dtype != orig_dtype: + model.to(orig_dtype) + recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id + print_rank_0( + f"WARNING: DFlash added {mask_token} (ID={tokenizer.mask_token_id}) " + f"and resized embeddings to {len(tokenizer)}. The DFlash draft does " + f"NOT export embeddings and the base model is frozen, so this new " + f"row is neither trained nor shipped. At deployment vLLM must find " + f"this id in the TARGET model's embed_tokens — ensure the target " + f"vocab physically contains index {tokenizer.mask_token_id}, or pin " + f"dflash.dflash_mask_token_id to an existing reserved token id." + ) dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) else: diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml index a63891cd14d..9b2ba6d4729 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml @@ -73,6 +73,11 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 + # Mask token for block-diffusion — an existing reserved row in MiniMax-M2.7's + # 200064-row embedding (real tokens are 0..200053). The draft ships no embeddings, + # so the mask id must already exist in the target embed_tokens; 200054 makes + # train==deploy by construction and avoids resizing. See the online recipe. 
+ - dflash.dflash_mask_token_id=200054 # Long-context RoPE injected at export (YaRN), same as the online recipe — the # offline path also exports a draft and must match the target's long context. # factor = 196608/4096 = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml index 74bd28acbed..b6db61708cb 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -55,6 +55,13 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 + # Mask token for block-diffusion. The DFlash draft ships no embeddings — it reuses + # the target's embed_tokens at deploy — so the mask id MUST be a token that already + # physically exists in the target embedding (not one added via resize, which would + # neither be trained, with the base frozen, nor exported). MiniMax-M2.7's embedding + # is 200064 rows while only 0..200053 are real tokens, so 200054 is an existing + # reserved row: using it makes train==deploy by construction and avoids resizing. + - dflash.dflash_mask_token_id=200054 # Long-context RoPE injected at export (YaRN). DFlash trains on a short window # (~4096) but drafts for MiniMax-M2.7 at up to 196608 tokens; factor = 196608/4096 # = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. Validated: 32k AL 1.17 -> 2.62. From 645cd5bcb703d1f2fd2d64a099e923627c508ce1 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 11:36:21 -0700 Subject: [PATCH 18/31] review: document applicability/scope of fsdp2_buffer_patch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses @hguo-nv's question (what range needs this patch?). 
Adds an Applicability section clarifying it is needed only for FSDP2-via-accelerate-config + cpu_ram_efficient loading on transformers 4.57.x — currently only MiniMax-M2.7 — and never for the transformers-5.x native ParallelismConfig path used by Qwen/Llama/Nemotron. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/fsdp2_buffer_patch.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index 01a31fe7eec..ef19cd5a218 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -18,6 +18,18 @@ """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling. +Applicability (scope of this patch) +----------------------------------- +This is **not** needed for FSDP2 in general. It is required only for the narrow +combination of: **FSDP2 configured via an accelerate YAML config** (not torch-native +``ParallelismConfig``) **with** ``cpu_ram_efficient_loading=True``. Today that path is +forced by **models that require transformers 4.57.x** (their ``trust_remote_code`` code +predates transformers 5.x ``ParallelismConfig`` support) **and** are too large to load on +every rank — currently only **MiniMax-M2.7** (229B MoE). Models that run on transformers +5.x (Qwen, Llama, Nemotron, ...) use native ``ParallelismConfig`` (``dp_shard_size > 1``), +which handles buffers/dtypes correctly and never enters ``fsdp2_load_full_state_dict`` — +they need none of this. Gated off by default; activate with ``PATCH_FSDP2_BUFFERS=1``. 
+ Problem ------- accelerate's ``fsdp2_load_full_state_dict`` (called during model preparation From 426fadfa3158b4cb377145da2296fe3e6a646b6f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 11:52:36 -0700 Subject: [PATCH 19/31] test: cover DFlash export rope-scaling field + mask-token resolution Addresses @hguo-nv's request to add tests for the new patches/callbacks. - Extract the mask-token-id resolution from main.py into a pure helper resolve_dflash_mask_token_id() in modeling_dflash.py (preference order: configured id -> tokenizer mask id -> existing reserved row -> resize), and unit-test all four branches (tests/.../plugins/test_hf_dflash.py). - Add DFlashExporter rope-scaling unit tests mirroring the eagle ones (tests/.../export/test_hf_spec_rope_export.py): YaRN injected from the dflash_export_rope_scaling config field, empty dict disables it, rope_theta inherits the base/target config. The DFlashExportCallback gather and fsdp2_buffer_patch state-dict path are inherently FSDP2/distributed and remain covered by the dflash regression suite (tests/regression/torch/speculative/test_dflash*.py). 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 71 +++++++++---------- .../speculative/plugins/modeling_dflash.py | 32 +++++++++ .../torch/export/test_hf_spec_rope_export.py | 67 ++++++++++++++++- .../speculative/plugins/test_hf_dflash.py | 43 +++++++++++ 4 files changed, 173 insertions(+), 40 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 4e96e7fedb0..a5e225d9696 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -58,6 +58,7 @@ from modelopt.torch.speculative.plugins.hf_training_args import ( TrainingArguments as SpecTrainingArgs, ) +from modelopt.torch.speculative.plugins.modeling_dflash import resolve_dflash_mask_token_id from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master, local_rank @@ -238,46 +239,40 @@ def train(): # Load draft vocab cache mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache) elif isinstance(recipe, ModelOptDFlashRecipe): - if recipe.dflash.dflash_mask_token_id is None: - recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) - if recipe.dflash.dflash_mask_token_id is None: - # The DFlash draft ships NO embeddings — masked positions are embedded via - # the base/target embed_tokens at mask_token_id, and at deployment vLLM - # reuses the *target's* embed table. So the mask id must be a row that - # physically exists in the target embedding. Resizing to append a new - # "<|mask|>" row is unsafe: with the base model frozen the row is never - # trained, it is never exported (the draft has no embeddings), and it is - # absent from the target at inference. 
Prefer an existing reserved row — - # many tokenizers leave the embedding padded past the used vocab — so that - # train and deploy resolve the identical frozen embedding by construction. - embed = model.get_input_embeddings() - n_phys = embed.weight.shape[0] - n_used = len(tokenizer) - if n_phys > n_used: - recipe.dflash.dflash_mask_token_id = n_used - print_rank_0( - f"DFlash: no mask token configured; reusing existing reserved " - f"embedding row {n_used} as mask_token_id (embedding has {n_phys} " - f"rows, tokenizer vocab {n_used}). No resize — the row already " - f"exists in the target and is frozen, so train==deploy." - ) - else: - mask_token = "<|mask|>" - tokenizer.add_special_tokens({"mask_token": mask_token}) - orig_dtype = model.dtype - model.resize_token_embeddings(len(tokenizer)) - if model.dtype != orig_dtype: - model.to(orig_dtype) - recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id + # Resolve the mask token id without resizing the embedding when possible. + # The DFlash draft ships no embeddings (it reuses the base/target embed_tokens + # at deploy), so the mask id must already exist in the target embedding; see + # resolve_dflash_mask_token_id for the full rationale. + mask_id, needs_resize = resolve_dflash_mask_token_id( + configured_id=recipe.dflash.dflash_mask_token_id, + tokenizer_mask_id=getattr(tokenizer, "mask_token_id", None), + num_embedding_rows=model.get_input_embeddings().weight.shape[0], + tokenizer_len=len(tokenizer), + ) + if not needs_resize: + if mask_id != recipe.dflash.dflash_mask_token_id: print_rank_0( - f"WARNING: DFlash added {mask_token} (ID={tokenizer.mask_token_id}) " - f"and resized embeddings to {len(tokenizer)}. The DFlash draft does " - f"NOT export embeddings and the base model is frozen, so this new " - f"row is neither trained nor shipped. 
At deployment vLLM must find " - f"this id in the TARGET model's embed_tokens — ensure the target " - f"vocab physically contains index {tokenizer.mask_token_id}, or pin " - f"dflash.dflash_mask_token_id to an existing reserved token id." + f"DFlash: using mask_token_id={mask_id} (existing embedding row; " + f"no resize — train==deploy by construction)." ) + recipe.dflash.dflash_mask_token_id = mask_id + else: + mask_token = "<|mask|>" + tokenizer.add_special_tokens({"mask_token": mask_token}) + orig_dtype = model.dtype + model.resize_token_embeddings(len(tokenizer)) + if model.dtype != orig_dtype: + model.to(orig_dtype) + recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id + print_rank_0( + f"WARNING: DFlash added {mask_token} (ID={tokenizer.mask_token_id}) " + f"and resized embeddings to {len(tokenizer)}. The DFlash draft does " + f"NOT export embeddings and the base model is frozen, so this new row " + f"is neither trained nor shipped. At deployment vLLM must find this id " + f"in the TARGET model's embed_tokens — ensure the target vocab " + f"physically contains index {tokenizer.mask_token_id}, or pin " + f"dflash.dflash_mask_token_id to an existing reserved token id." + ) dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) else: diff --git a/modelopt/torch/speculative/plugins/modeling_dflash.py b/modelopt/torch/speculative/plugins/modeling_dflash.py index 31ddcbf0cf9..4cd9da09c5c 100644 --- a/modelopt/torch/speculative/plugins/modeling_dflash.py +++ b/modelopt/torch/speculative/plugins/modeling_dflash.py @@ -69,6 +69,38 @@ def build_target_layer_ids(num_target_layers, num_draft_layers): return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] +def resolve_dflash_mask_token_id( + configured_id, tokenizer_mask_id, num_embedding_rows, tokenizer_len +): + """Decide the DFlash mask token id, avoiding an embedding resize when possible. 
+ + The DFlash draft ships no embeddings of its own — masked positions are embedded via + the base/target ``embed_tokens``, and at deployment the draft reuses the target's + table. So the mask id must be a row that physically exists in the target embedding. + Appending a new ``<|mask|>`` row by resizing is unsafe in general: with the base model + frozen the row is never trained, it is never exported, and it is absent from the + target at inference. We therefore prefer, in order: + + 1. an explicitly configured id, + 2. the tokenizer's own mask token id, + 3. an existing reserved row (when the embedding is padded past the used vocab), + + and only fall back to a resize as a last resort. + + Returns: + (mask_token_id, needs_resize). When ``needs_resize`` is True, ``mask_token_id`` is + None and the caller must add a special token + resize the embeddings (last resort). + """ + if configured_id is not None: + return configured_id, False + if tokenizer_mask_id is not None: + return tokenizer_mask_id, False + if num_embedding_rows > tokenizer_len: + # First unused row that already physically exists in the (padded) target embedding. + return tokenizer_len, False + return None, True + + def apply_rotary_pos_emb(q, k, cos, sin): """Apply RoPE. Q uses last q_len positions, K uses all positions.""" cos = cos.unsqueeze(1) # [B, 1, seq, dim] diff --git a/tests/unit/torch/export/test_hf_spec_rope_export.py b/tests/unit/torch/export/test_hf_spec_rope_export.py index 171fe6263c6..720082bc617 100644 --- a/tests/unit/torch/export/test_hf_spec_rope_export.py +++ b/tests/unit/torch/export/test_hf_spec_rope_export.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for EAGLE export rope scaling logic in hf_spec_export.py.""" +"""Unit tests for EAGLE/DFlash export rope scaling logic in hf_spec_export.py.""" +from types import SimpleNamespace from unittest.mock import MagicMock -from modelopt.torch.export.plugins.hf_spec_export import EagleExporter +import torch + +from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter, EagleExporter DEFAULT_ROPE_SCALING = { "rope_type": "yarn", @@ -76,3 +79,63 @@ def test_rope_theta_fallback_from_rope_scaling(): """rope_theta is populated from rope_scaling when not available as top-level attr.""" config = _make_exporter(rope_type="default", rope_theta=500000)._export_config() assert config["rope_theta"] == 500000 + + +# --------------------------------------------------------------------------- +# DFlash export rope scaling (config-field convergence, mirrors the eagle style) +# --------------------------------------------------------------------------- + +DFLASH_YARN = { + "type": "yarn", + "factor": 48.0, + "original_max_position_embeddings": 4096, + "beta_fast": 1.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, +} + + +def _make_dflash_exporter(dflash_export_rope_scaling=None, base_rope_theta=5000000.0): + base_config = SimpleNamespace( + hidden_size=128, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=256, + vocab_size=1000, + max_position_embeddings=196608, + initializer_range=0.02, + num_hidden_layers=8, + rope_theta=base_rope_theta, + torch_dtype=torch.bfloat16, + ) + draft_config = SimpleNamespace(num_hidden_layers=2) + model = SimpleNamespace( + config=base_config, + dflash_config=draft_config, + dflash_block_size=8, + mask_token_id=999, + target_layer_ids=[1, 3, 5, 7], + dflash_export_rope_scaling=dflash_export_rope_scaling, + ) + exporter = DFlashExporter.__new__(DFlashExporter) + exporter.model = model + return exporter + + +def test_dflash_yarn_rope_injected_from_config_field(): + """YaRN rope_scaling from 
dflash_export_rope_scaling is injected verbatim.""" + config = _make_dflash_exporter(dflash_export_rope_scaling=DFLASH_YARN)._export_config() + assert config["rope_scaling"] == DFLASH_YARN + + +def test_dflash_rope_not_injected_when_field_empty(): + """Empty dict (default) disables rope scaling injection.""" + config = _make_dflash_exporter(dflash_export_rope_scaling={})._export_config() + assert config["rope_scaling"] is None + + +def test_dflash_rope_theta_inherits_base(): + """rope_theta is inherited from the target/base config (draft drafts for the base).""" + config = _make_dflash_exporter(base_rope_theta=5000000.0)._export_config() + assert config["rope_theta"] == 5000000.0 diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index e35ac698e76..59f17ee23f6 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -40,6 +40,7 @@ HFDFlashModel, build_target_layer_ids, ) +from modelopt.torch.speculative.plugins.modeling_dflash import resolve_dflash_mask_token_id from modelopt.torch.speculative.utils import AcceptanceRateValidation from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator @@ -467,3 +468,45 @@ def test_multi_turn_masks_only_assistant(self, tiny_tokenizer): # User/system content should NOT appear in unmasked tokens assert "You are helpful" not in decoded assert "How are you?" not in decoded + + +class TestResolveMaskTokenId: + """Mask-token-id resolution: avoid resizing embeddings when possible. + + The DFlash draft ships no embeddings — masked positions are embedded via the + base/target embed_tokens and reused at deploy — so the mask id must already exist + in the target embedding. resolve_dflash_mask_token_id encodes that preference order. 
+ """ + + def test_explicit_configured_id_wins(self): + mask_id, needs_resize = resolve_dflash_mask_token_id( + configured_id=200054, + tokenizer_mask_id=7, + num_embedding_rows=200064, + tokenizer_len=200054, + ) + assert (mask_id, needs_resize) == (200054, False) + + def test_tokenizer_mask_id_used_when_unconfigured(self): + mask_id, needs_resize = resolve_dflash_mask_token_id( + configured_id=None, tokenizer_mask_id=42, num_embedding_rows=1000, tokenizer_len=999 + ) + assert (mask_id, needs_resize) == (42, False) + + def test_reuses_existing_reserved_row_when_padded(self): + # Embedding padded past the used vocab (e.g. MiniMax-M2.7: 200064 rows, 200054 used) + # -> reuse the first reserved row instead of resizing. + mask_id, needs_resize = resolve_dflash_mask_token_id( + configured_id=None, + tokenizer_mask_id=None, + num_embedding_rows=200064, + tokenizer_len=200054, + ) + assert (mask_id, needs_resize) == (200054, False) + + def test_needs_resize_when_not_padded(self): + # No spare rows -> caller must add a token + resize (last resort). + mask_id, needs_resize = resolve_dflash_mask_token_id( + configured_id=None, tokenizer_mask_id=None, num_embedding_rows=1000, tokenizer_len=1000 + ) + assert (mask_id, needs_resize) == (None, True) From bc97213e39fa093276fe1aaef56d4cd33beb3afd Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 11:56:57 -0700 Subject: [PATCH 20/31] review: rename PATCH_FSDP2_BUFFERS -> PATCH_FSDP2_BUFFERS_TF457; trim dflash doc Per @hguo-nv: - Rename the activation env var to make its scope explicit (the patch is specific to the FSDP2-via-accelerate-config path on transformers 4.57.x); updated docstring, main.py gate, and the online recipe. - Trim the dflash_mask_token_id doc note to the user-facing essentials (drop the implementation/rationale prose). 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/doc/dflash.md | 13 ++++--------- examples/speculative_decoding/fsdp2_buffer_patch.py | 6 +++--- examples/speculative_decoding/main.py | 8 ++++---- .../MiniMax-M2.7-DFlash/hf_online_dflash.yaml | 4 ++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index c48e8918c0b..37bbe268f70 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -168,15 +168,10 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `training.answer_only_loss` | false | Mask loss on non-assistant tokens | -> **Note on `dflash_mask_token_id`:** masked positions are embedded with the **base/target -> model's** `embed_tokens` (the draft ships no embeddings — it reuses the target's at -> deployment). So the mask id must be a token that **physically exists in the target -> embedding**. If left unset and the tokenizer has no mask token, training prefers an -> existing reserved row (when the embedding is padded past the used vocab) over resizing, -> because a resized-in row would be neither trained (the base is frozen) nor exported, and -> absent from the target at inference. For production models, pin `dflash_mask_token_id` to -> an existing reserved token id — e.g. MiniMax-M2.7 uses `200054` (its embedding has 200064 -> rows; tokens 0..200053 are real, 200054+ are reserved). +> **Note on `dflash_mask_token_id`:** the draft reuses the target's `embed_tokens`, so the +> mask id must be a token that exists in the target embedding. Pin it to an existing +> reserved token id — e.g. MiniMax-M2.7 uses `200054` (embedding is 200064 rows; tokens +> 0..200053 are real, 200054+ reserved). 
> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the > tokenizer's chat template must include `{% generation %}` / `{% endgeneration %}` tags diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index ef19cd5a218..24a39f091bf 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -28,7 +28,7 @@ every rank — currently only **MiniMax-M2.7** (229B MoE). Models that run on transformers 5.x (Qwen, Llama, Nemotron, ...) use native ``ParallelismConfig`` (``dp_shard_size > 1``), which handles buffers/dtypes correctly and never enters ``fsdp2_load_full_state_dict`` — -they need none of this. Gated off by default; activate with ``PATCH_FSDP2_BUFFERS=1``. +they need none of this. Gated off by default; activate with ``PATCH_FSDP2_BUFFERS_TF457=1``. Problem ------- @@ -105,7 +105,7 @@ Nemotron, etc.) use ``ParallelismConfig`` natively by setting ``dp_shard_size > 1`` in the training args. That code path handles buffers correctly and does not go through ``fsdp2_load_full_state_dict`` at all. -No accelerate config file, no ``PATCH_FSDP2_BUFFERS``, no +No accelerate config file, no ``PATCH_FSDP2_BUFFERS_TF457``, no ``OVERRIDE_TRANSFORMERS`` needed. When to remove @@ -119,7 +119,7 @@ Activation ---------- -Set ``PATCH_FSDP2_BUFFERS=1`` in the environment to activate. Off by default. +Set ``PATCH_FSDP2_BUFFERS_TF457=1`` in the environment to activate. Off by default. Only needed in MiniMax-M2.7 pipeline YAMLs. 
""" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index a5e225d9696..3494ef5f586 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -66,7 +66,7 @@ torch.manual_seed(0) mto.enable_huggingface_checkpointing() -if os.environ.get("PATCH_FSDP2_BUFFERS") == "1": +if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": import fsdp2_buffer_patch fsdp2_buffer_patch.apply() @@ -324,10 +324,10 @@ def train(): # any single-device recipe) checkpoints are already full and the launcher script's # post-run export handles them, so the callback is unnecessary overhead. # FSDP2 is active via either route: native ParallelismConfig (dp_shard_size > 1) or - # the accelerate-config fallback used for transformers 4.57.x (PATCH_FSDP2_BUFFERS). + # the accelerate-config fallback used for transformers 4.57.x (PATCH_FSDP2_BUFFERS_TF457). if isinstance(recipe, ModelOptDFlashRecipe): uses_fsdp2 = (getattr(training_args, "dp_shard_size", 1) or 1) > 1 or os.environ.get( - "PATCH_FSDP2_BUFFERS" + "PATCH_FSDP2_BUFFERS_TF457" ) == "1" if uses_fsdp2: callbacks.append(DFlashExportCallback()) @@ -345,7 +345,7 @@ def train(): **data_module, ) - if os.environ.get("PATCH_FSDP2_BUFFERS") == "1": + if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": fsdp2_buffer_patch.patch_accelerator(trainer.accelerator) # Manually enable this to return loss in eval diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml index b6db61708cb..b2881c8cd6b 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -9,7 +9,7 @@ # - trust_remote_code model whose code requires transformers 4.57.x, so FSDP2 is # configured via an accelerate config (accelerate_fsdp2_hybrid.yaml) rather than # transformers-native 
ParallelismConfig. Hence OVERRIDE_TRANSFORMERS + ACCELERATE_CONFIG -# + dp_shard_size=1 (keeps main.py from building a ParallelismConfig) + PATCH_FSDP2_BUFFERS. +# + dp_shard_size=1 (keeps main.py from building a ParallelismConfig) + PATCH_FSDP2_BUFFERS_TF457. # - 229B in FP8 needs cpu_ram_efficient_loading (set in the accelerate config) and # MIXED_PRECISION=no (the recipe / checkpoint already carry the right dtypes). # - The DFlash draft is Qwen3-architecture (5 layers); the target/draft architectures @@ -76,7 +76,7 @@ pipeline: - NUM_NODES: "8" - OVERRIDE_TRANSFORMERS: "4.57.1" - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" + - PATCH_FSDP2_BUFFERS_TF457: "1" - MIXED_PRECISION: "no" slurm_config: _factory_: "slurm_factory" From c7fec79e2d7f3bd6b11b8382272a1bf9e89b7376 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Jun 2026 12:18:34 -0700 Subject: [PATCH 21/31] =?UTF-8?q?fix:=20CI=20failures=20=E2=80=94=20drop?= =?UTF-8?q?=20bogus=20num=5Ferror,=20ruff=20nits,=20remove=20scratch=20YAM?= =?UTF-8?q?Ls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - compute_hidden_states_vllm.py: the loss_mask-length guard incremented a non-existent num_error counter (mypy name-defined / ruff F821); just skip. - fsdp2_buffer_patch.py: ruff SIM118 (drop .keys()) + C416 (list() not comprehension). - Remove six scratch pipeline YAMLs (_smoke_offline*, _thru32k_*) accidentally committed by a broad 'git add -A' — they also failed yamlfmt. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 1 - .../fsdp2_buffer_patch.py | 4 +- .../MiniMax-M2.7-DFlash/_smoke_offline.yaml | 87 ------------------- .../MiniMax-M2.7-DFlash/_smoke_offline6.yaml | 87 ------------------- .../MiniMax-M2.7-DFlash/_smoke_offline7.yaml | 87 ------------------- .../MiniMax-M2.7-DFlash/_smoke_offline8.yaml | 87 ------------------- .../MiniMax-M2.7-DFlash/_thru32k_test.yaml | 32 ------- .../MiniMax-M2.7-DFlash/_thru32k_yarn.yaml | 32 ------- 8 files changed, 2 insertions(+), 415 deletions(-) delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml delete mode 100644 tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index 0a721fb5ad9..45129131293 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -290,7 +290,6 @@ def keep_conversation(entry): f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden " f"states ({n_hs}); skipping to avoid misalignment" ) - num_error += 1 continue output_file = output_dir / f"{conv_id}.pt" diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index 24a39f091bf..c5dfcfa62dd 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ 
b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -175,7 +175,7 @@ def _patched(accelerator, model, full_sd, cpu_offload=False): # each broadcast tensor. if accelerator.is_main_process: dtype_codes = torch.tensor( - [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd.keys()], + [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd], dtype=torch.int32, device=accelerator.device, ) @@ -304,7 +304,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] - parameters = [p for p in parameters] # materialize generator + parameters = list(parameters) # materialize generator max_norm = float(max_norm) norm_type = float(norm_type) diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml deleted file mode 100644 index 924e1bcfbb4..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). -# -# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model -# forward at training time instead): -# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. -# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + -# embed_tokens, not the full 229B base). -# -# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because -# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The -# dump disables prefix caching so every token's hidden state is emitted and the dumped -# sequence lengths line up with input_ids/loss_mask. 
-# -# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) -# -# Usage: -# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes - -job_name: MiniMax-M2.7-DFlash_offline_smoke5 -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - - # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). - task_0: - script: common/eagle3/dump_offline_data_vllm.sh - args: - - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 - - --output-dir /scratchspace/smoke_hs5 - # Must match the draft model's num_hidden_layers (recipe default: 5). - - --aux-layers dflash - - --answer-only-loss - - --chat-template modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja - - --max-seq-len 4096 - - --tp 4 - - --debug-max-num-conversations 16 - environment: - - HF_MODEL_CKPT: <> - - TRUST_REMOTE_CODE: "1" - slurm_config: - _factory_: "slurm_factory" - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly - - # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids - # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
- task_1: - script: common/specdec/dflash_online_training.sh - args: - - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml - - model.model_name_or_path=<> - - model.trust_remote_code=true - - model.use_fake_base_for_offline=true - - data.mode=offline - - data.offline_data_path=/scratchspace/smoke_hs5 - - data.chat_template=modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja - - training.output_dir=/scratchspace/smoke_offline5 - - training.max_steps=10 - - training.per_device_train_batch_size=2 - - training.learning_rate=1.2e-3 - - training.warmup_steps=100 - - training.training_seq_len=4096 - - training.logging_steps=100 - - training.save_steps=400 - - training.disable_tqdm=true - - training.dp_shard_size=1 - - training.answer_only_loss=true - - training.ddp_timeout=3600 - - training.bf16=false - - dflash.dflash_self_logit_distillation=true - - dflash.dflash_block_size=8 - - dflash.dflash_num_anchors=512 - - dflash.dflash_loss_decay_factor=4.0 - - dflash.dflash_architecture_config.num_hidden_layers=5 - environment: - - NUM_NODES: "8" - - OVERRIDE_TRANSFORMERS: "4.57.1" - - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" - - MIXED_PRECISION: "no" - slurm_config: - _factory_: "slurm_factory" - nodes: 8 - ntasks_per_node: 1 - gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml deleted file mode 100644 index 0ae74f85fcc..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline6.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). 
-# -# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model -# forward at training time instead): -# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. -# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + -# embed_tokens, not the full 229B base). -# -# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because -# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The -# dump disables prefix caching so every token's hidden state is emitted and the dumped -# sequence lengths line up with input_ids/loss_mask. -# -# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) -# -# Usage: -# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes - -job_name: MiniMax-M2.7-DFlash_offline_smoke6 -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - - # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). - task_0: - script: common/eagle3/dump_offline_data_vllm.sh - args: - - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 - - --output-dir /scratchspace/smoke_hs6 - # Must match the draft model's num_hidden_layers (recipe default: 5). - - --aux-layers dflash - - --answer-only-loss - - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - --max-seq-len 4096 - - --tp 4 - - --debug-max-num-conversations 16 - environment: - - HF_MODEL_CKPT: <> - - TRUST_REMOTE_CODE: "1" - slurm_config: - _factory_: "slurm_factory" - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly - - # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids - # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
- task_1: - script: common/specdec/dflash_online_training.sh - args: - - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml - - model.model_name_or_path=<> - - model.trust_remote_code=true - - model.use_fake_base_for_offline=true - - data.mode=offline - - data.offline_data_path=/scratchspace/smoke_hs6 - - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - training.output_dir=/scratchspace/smoke_offline6 - - training.max_steps=10 - - training.per_device_train_batch_size=2 - - training.learning_rate=1.2e-3 - - training.warmup_steps=100 - - training.training_seq_len=4096 - - training.logging_steps=100 - - training.save_steps=400 - - training.disable_tqdm=true - - training.dp_shard_size=1 - - training.answer_only_loss=true - - training.ddp_timeout=3600 - - training.bf16=false - - dflash.dflash_self_logit_distillation=true - - dflash.dflash_block_size=8 - - dflash.dflash_num_anchors=512 - - dflash.dflash_loss_decay_factor=4.0 - - dflash.dflash_architecture_config.num_hidden_layers=5 - environment: - - NUM_NODES: "8" - - OVERRIDE_TRANSFORMERS: "4.57.1" - - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" - - MIXED_PRECISION: "no" - slurm_config: - _factory_: "slurm_factory" - nodes: 8 - ntasks_per_node: 1 - gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml deleted file mode 100644 index 3186a9fdc58..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline7.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). -# -# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model -# forward at training time instead): -# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. 
-# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + -# embed_tokens, not the full 229B base). -# -# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because -# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The -# dump disables prefix caching so every token's hidden state is emitted and the dumped -# sequence lengths line up with input_ids/loss_mask. -# -# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) -# -# Usage: -# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes - -job_name: MiniMax-M2.7-DFlash_offline_smoke7 -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - - # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). - task_0: - script: common/eagle3/dump_offline_data_vllm.sh - args: - - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 - - --output-dir /scratchspace/smoke_hs7 - # Must match the draft model's num_hidden_layers (recipe default: 5). - - --aux-layers dflash - - --answer-only-loss - - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - --max-seq-len 4096 - - --tp 4 - - --debug-max-num-conversations 16 - environment: - - HF_MODEL_CKPT: <> - - TRUST_REMOTE_CODE: "1" - slurm_config: - _factory_: "slurm_factory" - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly - - # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids - # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
- task_1: - script: common/specdec/dflash_online_training.sh - args: - - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml - - model.model_name_or_path=<> - - model.trust_remote_code=true - - model.use_fake_base_for_offline=true - - data.mode=offline - - data.offline_data_path=/scratchspace/smoke_hs7 - - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - training.output_dir=/scratchspace/smoke_offline7 - - training.max_steps=10 - - training.per_device_train_batch_size=2 - - training.learning_rate=1.2e-3 - - training.warmup_steps=100 - - training.training_seq_len=4096 - - training.logging_steps=100 - - training.save_steps=400 - - training.disable_tqdm=true - - training.dp_shard_size=1 - - training.answer_only_loss=true - - training.ddp_timeout=3600 - - training.bf16=false - - dflash.dflash_self_logit_distillation=true - - dflash.dflash_block_size=8 - - dflash.dflash_num_anchors=512 - - dflash.dflash_loss_decay_factor=4.0 - - dflash.dflash_architecture_config.num_hidden_layers=5 - environment: - - NUM_NODES: "8" - - OVERRIDE_TRANSFORMERS: "4.57.1" - - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" - - MIXED_PRECISION: "no" - slurm_config: - _factory_: "slurm_factory" - nodes: 8 - ntasks_per_node: 1 - gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml deleted file mode 100644 index 1fe11b90e66..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_smoke_offline8.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). -# -# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model -# forward at training time instead): -# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. 
-# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + -# embed_tokens, not the full 229B base). -# -# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because -# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The -# dump disables prefix caching so every token's hidden state is emitted and the dumped -# sequence lengths line up with input_ids/loss_mask. -# -# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) -# -# Usage: -# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes - -job_name: MiniMax-M2.7-DFlash_offline_smoke8 -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - - # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). - task_0: - script: common/eagle3/dump_offline_data_vllm.sh - args: - - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 - - --output-dir /scratchspace/smoke_hs8 - # Must match the draft model's num_hidden_layers (recipe default: 5). - - --aux-layers dflash - - --answer-only-loss - - --chat-template services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - --max-seq-len 4096 - - --tp 4 - - --debug-max-num-conversations 16 - environment: - - HF_MODEL_CKPT: <> - - TRUST_REMOTE_CODE: "1" - slurm_config: - _factory_: "slurm_factory" - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly - - # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids - # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
- task_1: - script: common/specdec/dflash_online_training.sh - args: - - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml - - model.model_name_or_path=<> - - model.trust_remote_code=true - - model.use_fake_base_for_offline=true - - data.mode=offline - - data.offline_data_path=/scratchspace/smoke_hs8 - - data.chat_template=services/pipeline/dflash/chat_template_minimax-m2.7.jinja - - training.output_dir=/scratchspace/smoke_offline8 - - training.max_steps=10 - - training.per_device_train_batch_size=2 - - training.learning_rate=1.2e-3 - - training.warmup_steps=100 - - training.training_seq_len=4096 - - training.logging_steps=100 - - training.save_steps=400 - - training.disable_tqdm=true - - training.dp_shard_size=1 - - training.answer_only_loss=true - - training.ddp_timeout=3600 - - training.bf16=false - - dflash.dflash_self_logit_distillation=true - - dflash.dflash_block_size=8 - - dflash.dflash_num_anchors=512 - - dflash.dflash_loss_decay_factor=4.0 - - dflash.dflash_architecture_config.num_hidden_layers=5 - environment: - - NUM_NODES: "8" - - OVERRIDE_TRANSFORMERS: "4.57.1" - - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml - - PATCH_FSDP2_BUFFERS: "1" - - MIXED_PRECISION: "no" - slurm_config: - _factory_: "slurm_factory" - nodes: 8 - ntasks_per_node: 1 - gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml deleted file mode 100644 index d33adb35557..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_test.yaml +++ /dev/null @@ -1,32 +0,0 @@ -job_name: MiniMax-M2.7-DFlash_thru32k_test_v2 -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - task_1: - script: common/specdec_bench/run.sh - args: - - --dataset speed - - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k - - --engine VLLM - - 
--speculative_algorithm DFLASH - - --draft_model_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/exported-checkpoint-20400 - - --block_size 8 - - --tp_size 4 - - --ep_size 4 - - --concurrency 8 - - --num_requests 80 - - --output_length 4096 - - --max_seq_len 40960 - - --trust_remote_code - - --aa_timing - - --show_progress - - --save_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/thru32k_results_v2 - environment: - - HF_MODEL_CKPT: <> - - HF_LOCAL: /hf-local - slurm_config: - _factory_: slurm_factory - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml deleted file mode 100644 index 8aa634bbd9c..00000000000 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/_thru32k_yarn.yaml +++ /dev/null @@ -1,32 +0,0 @@ -job_name: MiniMax-M2.7-DFlash_thru32k_test_yarn -pipeline: - global_vars: - hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 - task_1: - script: common/specdec_bench/run.sh - args: - - --dataset speed - - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k - - --engine VLLM - - --speculative_algorithm DFLASH - - --draft_model_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/exported-checkpoint-20400-yarn - - --block_size 8 - - --tp_size 4 - - --ep_size 4 - - --concurrency 8 - - --num_requests 80 - - --output_length 4096 - - --max_seq_len 40960 - - --trust_remote_code - - --aa_timing - - --show_progress - - --save_dir /lustre/fsw/portfolios/coreai/users/yeyu/experiments/dflash_minimax_m2.7_training_lr1.2e-3/thru32k_results_yarn - environment: - - HF_MODEL_CKPT: <> - - HF_LOCAL: /hf-local - slurm_config: - _factory_: slurm_factory - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 8 - container: vllm/vllm-openai:nightly 
From 4df2d605d1b1930a5b27353ab6f0caa17645eefd Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 11 Jun 2026 09:17:42 -0700 Subject: [PATCH 22/31] fix: launcher tests for the new requeue path in build_slurm_executor The requeue block I added to build_slurm_executor reads slurm_config.requeue; the existing build_slurm_executor tests pass a bare MagicMock, so .requeue was a truthy mock and the block ran against a mocked executor (max(MagicMock, 3) -> TypeError). Set requeue=False on those fixtures (matches the real SlurmConfig default) and add a dedicated test for requeue=True asserting the additional parameter + retries bump. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- tools/launcher/tests/test_slurm_executor.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tools/launcher/tests/test_slurm_executor.py b/tools/launcher/tests/test_slurm_executor.py index 900616136e3..6baec571e3b 100644 --- a/tools/launcher/tests/test_slurm_executor.py +++ b/tools/launcher/tests/test_slurm_executor.py @@ -34,6 +34,7 @@ def test_scratch_and_modelopt_mounts(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="test-host", port=22, account="test_account", @@ -77,6 +78,7 @@ def test_scratch_path_uses_experiment_title(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="host", port=22, account="acct", @@ -112,6 +114,7 @@ def test_tunnel_created_with_correct_params(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="login.cluster.com", port=30022, account="acct", @@ -150,6 +153,7 @@ def test_executor_params(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="h", port=22, account="my_acct", @@ -195,6 +199,7 @@ def test_none_container_mounts_handled(self, mock_tunnel, 
mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="h", port=22, account="a", @@ -222,3 +227,42 @@ def test_none_container_mounts_handled(self, mock_tunnel, mock_executor): # Should not crash; mounts should still include scratch + modelopt + title mounts = mock_executor.call_args[1]["container_mounts"] assert len(mounts) >= 3 + + @patch("core.run.SlurmExecutor") + @patch("core.run.SSHTunnel") + def test_requeue_sets_param_and_bumps_retries(self, mock_tunnel, mock_executor): + mock_tunnel.return_value = MagicMock() + executor = mock_executor.return_value + executor.retries = 0 + executor.additional_parameters = {} + + slurm_config = MagicMock( + requeue=True, + host="h", + port=22, + account="a", + partition="b", + container="c", + modelopt_install_path="/m", + container_mounts=[], + srun_args=[], + nodes=1, + ntasks_per_node=1, + gpus_per_node=1, + array=None, + ) + + build_slurm_executor( + user="u", + identity=None, + slurm_config=slurm_config, + experiment_id="e", + job_dir="/j", + task_name="t", + packager=MagicMock(), + ) + + # requeue=True flags the additional parameter and bumps retries above 0 so + # nemo-run's sbatch wrapper actually issues `scontrol requeue` on preemption. + assert executor.additional_parameters["requeue"] is True + assert executor.retries == 3 From e82ef5694c29a057c5ebf824da5b51611edfe936 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 11 Jun 2026 10:12:31 -0700 Subject: [PATCH 23/31] review: revert mask-token resize/helper; pin id in recipe (per @hguo-nv) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit @hguo-nv asked why the mask-token change is needed and when we'd ever resize the embedding. The answer: we don't — the original code (recipe id, else tokenizer mask_token_id, else raise) already works because the MiniMax recipes pin dflash_mask_token_id=200054 (an existing reserved row). 
So: - Revert main.py to the original mask-token resolution (no resize branch). - Remove the resolve_dflash_mask_token_id helper and its unit tests. - Trim the over-documented mask/rope/DDP comments in the MiniMax recipes to terse, example-appropriate one-liners. The export RoPE config-field tests (test_hf_spec_rope_export.py) are unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 41 +++--------------- .../speculative/plugins/modeling_dflash.py | 32 -------------- .../speculative/plugins/test_hf_dflash.py | 43 ------------------- .../hf_offline_dflash.yaml | 17 +++----- .../MiniMax-M2.7-DFlash/hf_online_dflash.yaml | 12 ++---- 5 files changed, 15 insertions(+), 130 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3494ef5f586..7b97548011e 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -58,7 +58,6 @@ from modelopt.torch.speculative.plugins.hf_training_args import ( TrainingArguments as SpecTrainingArgs, ) -from modelopt.torch.speculative.plugins.modeling_dflash import resolve_dflash_mask_token_id from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master, local_rank @@ -239,39 +238,13 @@ def train(): # Load draft vocab cache mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache) elif isinstance(recipe, ModelOptDFlashRecipe): - # Resolve the mask token id without resizing the embedding when possible. - # The DFlash draft ships no embeddings (it reuses the base/target embed_tokens - # at deploy), so the mask id must already exist in the target embedding; see - # resolve_dflash_mask_token_id for the full rationale. 
- mask_id, needs_resize = resolve_dflash_mask_token_id( - configured_id=recipe.dflash.dflash_mask_token_id, - tokenizer_mask_id=getattr(tokenizer, "mask_token_id", None), - num_embedding_rows=model.get_input_embeddings().weight.shape[0], - tokenizer_len=len(tokenizer), - ) - if not needs_resize: - if mask_id != recipe.dflash.dflash_mask_token_id: - print_rank_0( - f"DFlash: using mask_token_id={mask_id} (existing embedding row; " - f"no resize — train==deploy by construction)." - ) - recipe.dflash.dflash_mask_token_id = mask_id - else: - mask_token = "<|mask|>" - tokenizer.add_special_tokens({"mask_token": mask_token}) - orig_dtype = model.dtype - model.resize_token_embeddings(len(tokenizer)) - if model.dtype != orig_dtype: - model.to(orig_dtype) - recipe.dflash.dflash_mask_token_id = tokenizer.mask_token_id - print_rank_0( - f"WARNING: DFlash added {mask_token} (ID={tokenizer.mask_token_id}) " - f"and resized embeddings to {len(tokenizer)}. The DFlash draft does " - f"NOT export embeddings and the base model is frozen, so this new row " - f"is neither trained nor shipped. At deployment vLLM must find this id " - f"in the TARGET model's embed_tokens — ensure the target vocab " - f"physically contains index {tokenizer.mask_token_id}, or pin " - f"dflash.dflash_mask_token_id to an existing reserved token id." + # Fall back to tokenizer.mask_token_id when not set in the recipe; require one of the two. + if recipe.dflash.dflash_mask_token_id is None: + recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) + if recipe.dflash.dflash_mask_token_id is None: + raise ValueError( + "dflash.dflash_mask_token_id is required: set it in the recipe YAML " + "or use a tokenizer that defines mask_token_id." 
) dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) diff --git a/modelopt/torch/speculative/plugins/modeling_dflash.py b/modelopt/torch/speculative/plugins/modeling_dflash.py index 4cd9da09c5c..31ddcbf0cf9 100644 --- a/modelopt/torch/speculative/plugins/modeling_dflash.py +++ b/modelopt/torch/speculative/plugins/modeling_dflash.py @@ -69,38 +69,6 @@ def build_target_layer_ids(num_target_layers, num_draft_layers): return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] -def resolve_dflash_mask_token_id( - configured_id, tokenizer_mask_id, num_embedding_rows, tokenizer_len -): - """Decide the DFlash mask token id, avoiding an embedding resize when possible. - - The DFlash draft ships no embeddings of its own — masked positions are embedded via - the base/target ``embed_tokens``, and at deployment the draft reuses the target's - table. So the mask id must be a row that physically exists in the target embedding. - Appending a new ``<|mask|>`` row by resizing is unsafe in general: with the base model - frozen the row is never trained, it is never exported, and it is absent from the - target at inference. We therefore prefer, in order: - - 1. an explicitly configured id, - 2. the tokenizer's own mask token id, - 3. an existing reserved row (when the embedding is padded past the used vocab), - - and only fall back to a resize as a last resort. - - Returns: - (mask_token_id, needs_resize). When ``needs_resize`` is True, ``mask_token_id`` is - None and the caller must add a special token + resize the embeddings (last resort). - """ - if configured_id is not None: - return configured_id, False - if tokenizer_mask_id is not None: - return tokenizer_mask_id, False - if num_embedding_rows > tokenizer_len: - # First unused row that already physically exists in the (padded) target embedding. - return tokenizer_len, False - return None, True - - def apply_rotary_pos_emb(q, k, cos, sin): """Apply RoPE. 
Q uses last q_len positions, K uses all positions.""" cos = cos.unsqueeze(1) # [B, 1, seq, dim] diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 59f17ee23f6..e35ac698e76 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -40,7 +40,6 @@ HFDFlashModel, build_target_layer_ids, ) -from modelopt.torch.speculative.plugins.modeling_dflash import resolve_dflash_mask_token_id from modelopt.torch.speculative.utils import AcceptanceRateValidation from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator @@ -468,45 +467,3 @@ def test_multi_turn_masks_only_assistant(self, tiny_tokenizer): # User/system content should NOT appear in unmasked tokens assert "You are helpful" not in decoded assert "How are you?" not in decoded - - -class TestResolveMaskTokenId: - """Mask-token-id resolution: avoid resizing embeddings when possible. - - The DFlash draft ships no embeddings — masked positions are embedded via the - base/target embed_tokens and reused at deploy — so the mask id must already exist - in the target embedding. resolve_dflash_mask_token_id encodes that preference order. - """ - - def test_explicit_configured_id_wins(self): - mask_id, needs_resize = resolve_dflash_mask_token_id( - configured_id=200054, - tokenizer_mask_id=7, - num_embedding_rows=200064, - tokenizer_len=200054, - ) - assert (mask_id, needs_resize) == (200054, False) - - def test_tokenizer_mask_id_used_when_unconfigured(self): - mask_id, needs_resize = resolve_dflash_mask_token_id( - configured_id=None, tokenizer_mask_id=42, num_embedding_rows=1000, tokenizer_len=999 - ) - assert (mask_id, needs_resize) == (42, False) - - def test_reuses_existing_reserved_row_when_padded(self): - # Embedding padded past the used vocab (e.g. MiniMax-M2.7: 200064 rows, 200054 used) - # -> reuse the first reserved row instead of resizing. 
- mask_id, needs_resize = resolve_dflash_mask_token_id( - configured_id=None, - tokenizer_mask_id=None, - num_embedding_rows=200064, - tokenizer_len=200054, - ) - assert (mask_id, needs_resize) == (200054, False) - - def test_needs_resize_when_not_padded(self): - # No spare rows -> caller must add a token + resize (last resort). - mask_id, needs_resize = resolve_dflash_mask_token_id( - configured_id=None, tokenizer_mask_id=None, num_embedding_rows=1000, tokenizer_len=1000 - ) - assert (mask_id, needs_resize) == (None, True) diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml index 9b2ba6d4729..dc3dba4a65a 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml @@ -73,14 +73,9 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 - # Mask token for block-diffusion — an existing reserved row in MiniMax-M2.7's - # 200064-row embedding (real tokens are 0..200053). The draft ships no embeddings, - # so the mask id must already exist in the target embed_tokens; 200054 makes - # train==deploy by construction and avoids resizing. See the online recipe. + # Mask token id (an existing reserved row in MiniMax-M2.7's embedding). - dflash.dflash_mask_token_id=200054 - # Long-context RoPE injected at export (YaRN), same as the online recipe — the - # offline path also exports a draft and must match the target's long context. - # factor = 196608/4096 = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. + # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48). 
- dflash.dflash_export_rope_scaling.type=yarn - dflash.dflash_export_rope_scaling.factor=48.0 - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 @@ -90,11 +85,9 @@ pipeline: - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 environment: - NUM_NODES: "8" - # Offline training uses a lightweight FakeBaseModel (embeddings + lm_head only, - # not the 229B base), so it fits comfortably under plain DDP — no FSDP needed. - # Leaving ACCELERATE_CONFIG unset makes the launcher fall back to DDP, which also - # bypasses the FSDP2 buffer/dtype patches entirely. OVERRIDE_TRANSFORMERS is still - # required: the MiniMax config (loaded for the fake base's dims) needs 4.57.1. + # Offline training uses a lightweight FakeBaseModel, so plain DDP suffices (no + # ACCELERATE_CONFIG / FSDP2 patches). OVERRIDE_TRANSFORMERS pins 4.57.1 for the + # MiniMax config load. - OVERRIDE_TRANSFORMERS: "4.57.1" - MIXED_PRECISION: "no" slurm_config: diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml index b2881c8cd6b..92c450ee550 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -55,16 +55,10 @@ pipeline: - dflash.dflash_num_anchors=512 - dflash.dflash_loss_decay_factor=4.0 - dflash.dflash_architecture_config.num_hidden_layers=5 - # Mask token for block-diffusion. The DFlash draft ships no embeddings — it reuses - # the target's embed_tokens at deploy — so the mask id MUST be a token that already - # physically exists in the target embedding (not one added via resize, which would - # neither be trained, with the base frozen, nor exported). MiniMax-M2.7's embedding - # is 200064 rows while only 0..200053 are real tokens, so 200054 is an existing - # reserved row: using it makes train==deploy by construction and avoids resizing. + # Mask token id. 
The draft reuses the target's embed_tokens, so this must be an id + # that exists in the target embedding; 200054 is a reserved row in MiniMax-M2.7. - dflash.dflash_mask_token_id=200054 - # Long-context RoPE injected at export (YaRN). DFlash trains on a short window - # (~4096) but drafts for MiniMax-M2.7 at up to 196608 tokens; factor = 196608/4096 - # = 48. Mirrors nvidia/Kimi-K2.6-Eagle3. Validated: 32k AL 1.17 -> 2.62. + # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48). - dflash.dflash_export_rope_scaling.type=yarn - dflash.dflash_export_rope_scaling.factor=48.0 - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 From 58466572788184b72c032358c013bb75bbcb2b6f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Jun 2026 10:50:32 -0700 Subject: [PATCH 24/31] review: clarify rope_theta rationale (KV injection); trim rope_scaling comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per @hguo-nv: - rope_theta: the base-first choice is intentional for DFlash — the draft injects the target's KV into every layer, so its RoPE base must match the target's. The draft arch config carries no rope_theta, so draft-first fell back to the 1e6 default (a mismatch with the target's 5e6). Tightened the comment to state this. - Trimmed the verbose rope_scaling comment to the essentials. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- modelopt/torch/export/plugins/hf_spec_export.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index b35888dc1d2..f8efe0035af 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -376,9 +376,9 @@ def _export_config(self): "initializer_range": getattr(base_config, "initializer_range", 0.02), "attention_bias": getattr(draft_config, "attention_bias", False), "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), - # Inherit the target's rope_theta — the draft drafts for the base model, so its - # RoPE base must match it. (DFlash trains with a minimal rope; the real - # long-context RoPE is applied here at export.) + # Inherit the target's rope_theta: DFlash injects the target's KV into every + # draft layer, so the draft's RoPE base must match the target's. (The draft + # arch config carries no rope_theta of its own.) "rope_theta": getattr(base_config, "rope_theta", None) or getattr(draft_config, "rope_theta", 1000000.0), # YaRN long-context scaling is injected below (see the rope_scaling block). @@ -396,12 +396,8 @@ def _export_config(self): else: config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers - # Long-context RoPE (YaRN). The draft trains on a short window but must draft for - # the target at long context, so — mirroring published Eagle3 drafts such as - # nvidia/Kimi-K2.6-Eagle3 — inject a YaRN rope_scaling that extends the training - # window to the target's full context. Sourced from the config field - # dflash_export_rope_scaling (set in the recipe YAML), matching the eagle - # eagle_export_rope_scaling convention. Empty dict (default) disables injection. 
+ # Inject the export-time YaRN rope_scaling from the dflash_export_rope_scaling + # config field (empty dict disables). Mirrors eagle's eagle_export_rope_scaling. export_rope_scaling = getattr(self.model, "dflash_export_rope_scaling", None) if export_rope_scaling: config["rope_scaling"] = export_rope_scaling From 49123793a6ac1f8a00ccd4b2ab93cc7f9b845de8 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Jun 2026 11:00:34 -0700 Subject: [PATCH 25/31] review: gate DFlashExportCallback on SHARDED_STATE_DICT; reuse exporter extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses @hguo-nv's points 1-2 on the export callback: 1. The callback now self-skips when the checkpoint was saved as a full state dict (model.safetensors / pytorch_model.bin present) — i.e. it only does work under FSDP2 SHARDED_STATE_DICT, where checkpoints are distributed shards the post-run export can't read. The check is from the on-disk checkpoint format (identical across ranks, before the collective gather). main.py now appends it unconditionally for DFlash since it self-gates (DDP / single-device / FSDP2-full all self-skip), removing the coarse dp_shard_size/env heuristic. 2. Extraction now delegates to DFlashExporter._extract_state_dict (the same logic the normal export path uses) for the common prefixed-key layout, with a fallback for the already-stripped submodule-gather keys. Config generation already delegated to the exporter; this removes the remaining duplication. The gather (get_model_state_dict on the dflash_module submodule) — the load-bearing FSDP2 part that produced the release — is unchanged. The full split into separate reusable gather/export callbacks is deferred: offline DFlash uses DDP (full checkpoints, no callback needed) and eagle has its own export path, so there is no current consumer. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 52 ++++++++++++++------ examples/speculative_decoding/main.py | 25 +++------- 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 71a9f3da223..82619a00771 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -250,12 +250,29 @@ def on_save(self, args, state, control, **kwargs): return control step = state.global_step + # This callback is only needed under FSDP2 SHARDED_STATE_DICT, where the + # checkpoint holds distributed shards (pytorch_model_fsdp_0/) and no consolidated + # weights. When the checkpoint was saved as a full state dict (model.safetensors / + # pytorch_model.bin), the post-run export_hf_checkpoint.py pass can read it + # directly, so skip the gather here. The decision is made from the on-disk + # checkpoint format — identical across ranks (shared FS), and before the collective + # gather below, so all ranks take the same branch. + ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}") + if any( + os.path.exists(os.path.join(ckpt_dir, f)) + for f in ( + "model.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model.bin.index.json", + ) + ): + return control + export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}") - # All ranks participate in state_dict gather (FSDP2 collective op). - # Use get_model_state_dict to get the full (ungathered) weights regardless - # of fsdp_state_dict_type setting. Only the dflash_module submodule is - # gathered (~328 MB for MiniMax-M2.7), not the full 229B base model. + # All ranks participate in the state_dict gather (FSDP2 collective op). Only the + # dflash_module submodule is gathered (~328 MB for MiniMax-M2.7), not the 229B base. 
try: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, @@ -274,18 +291,22 @@ def on_save(self, args, state, control, **kwargs): # Non-distributed / single-GPU fallback raw_sd = model.state_dict() - # Extract dflash_module keys and strip prefix - drafter_sd = {} - for key, value in raw_sd.items(): - if "dflash_module." in key: - export_key = key.split("dflash_module.", 1)[1] - if "rotary_emb" not in export_key: - drafter_sd[export_key] = value.cpu() if value.device.type != "cpu" else value - elif not any(prefix in key for prefix in ("model.", "lm_head.", "embed_tokens.")): - # Keys already without prefix (from submodule state_dict) - if "rotary_emb" not in key: - drafter_sd[key] = value.cpu() if value.device.type != "cpu" else value + # Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary + # buffers) for the common full-model key layout. Some PyTorch versions return the + # submodule gather with keys already stripped of the prefix — handle that directly. + exporter = model.get_exporter() + drafter_sd = exporter._extract_state_dict(raw_sd) + if not drafter_sd: + drafter_sd = { + k: v + for k, v in raw_sd.items() + if "rotary_emb" not in k + and not any(p in k for p in ("model.", "lm_head.", "embed_tokens.")) + } del raw_sd + # Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the + # single-GPU fallback may return CUDA tensors). 
+ drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()} if not drafter_sd: print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export") @@ -297,7 +318,6 @@ def on_save(self, args, state, control, **kwargs): os.makedirs(export_dir, exist_ok=True) save_file(drafter_sd, os.path.join(export_dir, "model.safetensors")) - exporter = model.get_exporter() config = exporter._export_config() with open(os.path.join(export_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3379afb4754..bd645d15f25 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -285,26 +285,13 @@ def train(): # map-style, so HF Trainer's resume skips consumed indices at the batch-sampler # level (accelerate.skip_first_batches) without re-fetching them, landing at the # exact data position. Setting it True would restart the data order from the top. - # DFlash: export the draft submodule after every checkpoint save — but only under - # FSDP2. With FSDP2 SHARDED_STATE_DICT, checkpoint-* dirs hold only distributed - # shards that the post-training export_hf_checkpoint.py pass can't read, so this - # callback gathers just the small draft module per save and writes a deployable - # exported-checkpoint-{step}/. Under DDP (e.g. offline FakeBaseModel training, or - # any single-device recipe) checkpoints are already full and the launcher script's - # post-run export handles them, so the callback is unnecessary overhead. - # FSDP2 is active via either route: native ParallelismConfig (dp_shard_size > 1) or - # the accelerate-config fallback used for transformers 4.57.x (PATCH_FSDP2_BUFFERS_TF457). + # DFlash: export the draft submodule after each checkpoint save. 
The callback only + # does work under FSDP2 SHARDED_STATE_DICT (where checkpoint-* dirs hold distributed + # shards the post-training export_hf_checkpoint.py pass can't read); for full + # checkpoints — DDP, single-device, or FSDP2 FULL_STATE_DICT — it self-skips, since + # the launcher's post-run export handles those. So it's safe to add unconditionally. if isinstance(recipe, ModelOptDFlashRecipe): - uses_fsdp2 = (getattr(training_args, "dp_shard_size", 1) or 1) > 1 or os.environ.get( - "PATCH_FSDP2_BUFFERS_TF457" - ) == "1" - if uses_fsdp2: - callbacks.append(DFlashExportCallback()) - else: - print_rank_0( - "DFlash: non-FSDP2 run detected — skipping per-step DFlashExportCallback; " - "checkpoints are full and will be exported post-training by the launcher script." - ) + callbacks.append(DFlashExportCallback()) trainer = EagleTrainerWithAccLog( model=model, From 01778c8f94a5e760539ae5fc5774d50fed9a200a Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Jun 2026 19:11:09 -0700 Subject: [PATCH 26/31] review: gate export callback by fsdp_state_dict_type, rename it, move dtype diag Implements @hguo-nv's agreed asks on the DFlash export callback: - Rename DFlashExportCallback -> DFlashFSDP2ShardedSDExportCallback to make its applicable range (FSDP2 SHARDED_STATE_DICT only) explicit in the name. - Gate it in main.py by reading the live accelerator FSDP state dict type (trainer.accelerator.state.fsdp_plugin.state_dict_type) instead of the callback's on-disk checkpoint-filename heuristic; the callback is added only for SHARDED_STATE_DICT (full-state-dict runs use the launcher's post-run export). Removed the in-callback early-exit accordingly. - Move the per-rank dtype diagnostic out of main.py into fsdp2_buffer_patch.log_param_dtypes(), gated behind DFLASH_LOG_PARAM_DTYPES=1 (no-op by default), since it only exists to verify the FSDP2 dtype sync. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 34 +++++------------ .../fsdp2_buffer_patch.py | 20 ++++++++++ examples/speculative_decoding/main.py | 38 ++++++++++--------- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 82619a00771..be54889be07 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -230,12 +230,17 @@ def on_step_begin(self, args, state, control, **kwargs): return control -class DFlashExportCallback(TrainerCallback): - """Export DFlash draft module after each checkpoint save. +class DFlashFSDP2ShardedSDExportCallback(TrainerCallback): + """Export the DFlash draft module after each checkpoint save, for FSDP2 sharded runs. - Under FSDP2 SHARDED_STATE_DICT, checkpoints only contain distributed shards - (pytorch_model_fsdp_0/), not model.safetensors. This callback extracts the - small draft module weights and saves them in deployment format after each save. + Applicable range: this is needed only under FSDP2 ``SHARDED_STATE_DICT``, where the + checkpoint holds distributed shards (``pytorch_model_fsdp_0/``) and no consolidated + ``model.safetensors`` — so the post-training ``export_hf_checkpoint.py`` pass can't read + it. It gathers just the small draft submodule and writes the deployable export. + + Gating is the caller's responsibility: ``main.py`` adds this callback only when the + accelerator's FSDP state dict type is ``SHARDED_STATE_DICT`` (full-state-dict runs — + DDP, single-device, FSDP2 FULL_STATE_DICT — use the launcher's post-run export instead). 
""" def on_save(self, args, state, control, **kwargs): @@ -250,25 +255,6 @@ def on_save(self, args, state, control, **kwargs): return control step = state.global_step - # This callback is only needed under FSDP2 SHARDED_STATE_DICT, where the - # checkpoint holds distributed shards (pytorch_model_fsdp_0/) and no consolidated - # weights. When the checkpoint was saved as a full state dict (model.safetensors / - # pytorch_model.bin), the post-run export_hf_checkpoint.py pass can read it - # directly, so skip the gather here. The decision is made from the on-disk - # checkpoint format — identical across ranks (shared FS), and before the collective - # gather below, so all ranks take the same branch. - ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{step}") - if any( - os.path.exists(os.path.join(ckpt_dir, f)) - for f in ( - "model.safetensors", - "model.safetensors.index.json", - "pytorch_model.bin", - "pytorch_model.bin.index.json", - ) - ): - return control - export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}") # All ranks participate in the state_dict gather (FSDP2 collective op). Only the diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index c5dfcfa62dd..2a34e45edc7 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -371,3 +371,23 @@ def patch_accelerator(accelerator): print( "[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility" ) + + +def log_param_dtypes(model): + """Debug aid: log per-rank parameter dtype counts (gated by DFLASH_LOG_PARAM_DTYPES=1). + + Used to verify the FSDP2 dtype synchronization above — after ``fully_shard()`` params + are DTensors whose dtype lives on ``_local_tensor``. Off by default; this is purely + diagnostic and has no effect on training. 
+ """ + import os + + if os.environ.get("DFLASH_LOG_PARAM_DTYPES") != "1": + return + rank = int(os.environ.get("RANK", 0)) + dtypes = {} + for name, p in model.named_parameters(): + dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) + dtypes.setdefault(dt_key, []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})") diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index bd645d15f25..7fc7c61215c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -33,10 +33,11 @@ import dataclasses import os +import fsdp2_buffer_patch import torch import transformers from eagle_utils import ( - DFlashExportCallback, + DFlashFSDP2ShardedSDExportCallback, EagleTrainerWithAccLog, EagleTrainingPlot, LoRAWarmupCallback, @@ -66,8 +67,6 @@ mto.enable_huggingface_checkpointing() if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": - import fsdp2_buffer_patch - fsdp2_buffer_patch.apply() @@ -285,13 +284,6 @@ def train(): # map-style, so HF Trainer's resume skips consumed indices at the batch-sampler # level (accelerate.skip_first_batches) without re-fetching them, landing at the # exact data position. Setting it True would restart the data order from the top. - # DFlash: export the draft submodule after each checkpoint save. The callback only - # does work under FSDP2 SHARDED_STATE_DICT (where checkpoint-* dirs hold distributed - # shards the post-training export_hf_checkpoint.py pass can't read); for full - # checkpoints — DDP, single-device, or FSDP2 FULL_STATE_DICT — it self-skips, since - # the launcher's post-run export handles those. So it's safe to add unconditionally. 
- if isinstance(recipe, ModelOptDFlashRecipe): - callbacks.append(DFlashExportCallback()) trainer = EagleTrainerWithAccLog( model=model, @@ -304,6 +296,23 @@ def train(): if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": fsdp2_buffer_patch.patch_accelerator(trainer.accelerator) + # DFlash: export the draft submodule after each checkpoint save — but only under FSDP2 + # SHARDED_STATE_DICT, where checkpoints are distributed shards the post-training + # export_hf_checkpoint.py pass can't read. Gate by reading the live FSDP state dict + # type off the accelerator; full-state-dict runs (DDP, single-device, FSDP2 + # FULL_STATE_DICT) use the launcher's post-run export instead. + if isinstance(recipe, ModelOptDFlashRecipe): + fsdp_plugin = getattr(trainer.accelerator.state, "fsdp_plugin", None) + sd_type = str(getattr(fsdp_plugin, "state_dict_type", "") or "") + if "SHARDED_STATE_DICT" in sd_type: + trainer.add_callback(DFlashFSDP2ShardedSDExportCallback()) + print_rank_0("DFlash: FSDP2 SHARDED_STATE_DICT — enabling per-save draft export.") + else: + print_rank_0( + f"DFlash: checkpoints use {sd_type or 'a full state dict'}; relying on the " + "launcher's post-run export (no per-save export callback added)." + ) + # Manually enable this to return loss in eval trainer.can_return_loss = True # Make sure label_smoother is None @@ -311,13 +320,8 @@ def train(): "label_smoother is not supported in speculative decoding!" ) - rank = int(os.environ.get("RANK", 0)) - dtypes = {} - for name, p in trainer.model.named_parameters(): - dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) - dtypes.setdefault(dt_key, []).append(name) - for dt, names in dtypes.items(): - print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})") + # Diagnostic (no-op unless DFLASH_LOG_PARAM_DTYPES=1): verifies FSDP2 dtype sync. 
+ fsdp2_buffer_patch.log_param_dtypes(trainer.model) print_rank_0("Start training...") trainer.train(resume_from_checkpoint=checkpoint) From 786b7184a3248952a93ac89aa0e36b158aea055b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 15 Jun 2026 14:10:09 -0700 Subject: [PATCH 27/31] fix: enforce draft rope_theta/rope_type from base model (not setdefault) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses @hguo-nv: the modify() inheritance was a setdefault, so a rope_theta provided in dflash_architecture_config would NOT be overridden — the draft would train with that value while the exporter writes the base's rope_theta, causing a train/inference mismatch. DFlash injects the target's KV into every draft layer, so the draft's RoPE base must match the target's. Enforce rope_theta/rope_type/rope_interleaved from the base model (overwrite any user value, with a warning); keep the other architecture attrs as setdefault. This makes training and export agree by construction. rope_scaling stays excluded (added only at export via dflash_export_rope_scaling). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../torch/speculative/plugins/hf_dflash.py | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 1760cb2072d..9e54ee531ce 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -132,26 +132,45 @@ def modify(self, config): self.dflash_config.hidden_size = base_config.hidden_size self.dflash_config.vocab_size = base_config.vocab_size - # Inherit architecture settings from base model when not specified by user. - # Static defaults (hidden_act, attention_bias, etc.) are in dflash/default_config.py. - # NOTE: rope_scaling is intentionally excluded. 
DFlash draft uses Qwen3 - # RotaryEmbedding which only supports standard RoPE. Inheriting M-RoPE - # config from multimodal models (e.g. Qwen3.5) would be incorrect. - _base_model_attrs = [ + # Inherit architecture settings from base model when not specified by user + # (setdefault). Static defaults (hidden_act, attention_bias, etc.) are in + # dflash/default_config.py. + _setdefault_attrs = [ "max_position_embeddings", "intermediate_size", "num_attention_heads", "num_key_value_heads", - "rope_theta", - "rope_type", - "rope_interleaved", "rms_norm_eps", ] - for attr in _base_model_attrs: + for attr in _setdefault_attrs: if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: if hasattr(base_config, attr): setattr(self.dflash_config, attr, getattr(base_config, attr)) + # RoPE base settings are ENFORCED to match the base model (not setdefault): the + # DFlash draft injects the target's KV into every layer, so its RoPE base must + # match the target's for the injected positions to align — and the exporter writes + # the base model's rope_theta. Letting dflash_architecture_config override these + # would make training (draft rope) and inference (base rope) disagree, so we + # overwrite any user value and warn. (rope_scaling is intentionally NOT inherited: + # DFlash uses standard Qwen3 RotaryEmbedding; the long-context YaRN scaling is + # added only at export via dflash_export_rope_scaling.) 
+ for attr in ("rope_theta", "rope_type", "rope_interleaved"): + if not hasattr(base_config, attr): + continue + base_val = getattr(base_config, attr) + user_val = getattr(self.dflash_config, attr, None) + if user_val is not None and user_val != base_val: + logger.warning( + "DFlash: ignoring dflash_architecture_config.%s=%r and enforcing the " + "base model's value %r — the draft injects the target's KV, so its RoPE " + "base must match the target's.", + attr, + user_val, + base_val, + ) + setattr(self.dflash_config, attr, base_val) + self.dflash_config.head_dim = getattr( self.dflash_config, "head_dim", From ff7dae6132f8710b7bbbb2e0a712a62717c35417 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 15 Jun 2026 14:14:20 -0700 Subject: [PATCH 28/31] review: extract _is_hf_format_checkpoint helper for the resume-format check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per @hguo-nv: tidy the inline checkpoint-format probe in main.py into a named helper. Kept it distinct from the export-callback's save-mode gate on purpose — this inspects the *resume* checkpoint's on-disk format (a property of existing bytes, checked pre-trainer), whereas the gate reads the *current run's* fsdp_state_dict_type (post-trainer); the two answer different questions and can differ across runs, so they shouldn't share one FSDP2+ShardedSD condition. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 7fc7c61215c..bf247cc3db4 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -150,6 +150,23 @@ def init_distributed_env(training_args: transformers.TrainingArguments) -> None: ) +def _is_hf_format_checkpoint(checkpoint: str | None) -> bool: + """True if the checkpoint dir holds consolidated HF weights (from_pretrained-loadable). + + FSDP2 SHARDED_STATE_DICT checkpoints contain only distributed shards + (``pytorch_model_fsdp_*/``), no ``model.safetensors`` — those return False, signalling + the caller to load the base model and resume via the Trainer instead. This inspects the + on-disk format of the *resume* checkpoint, which is a property of the existing bytes and + is independent of the current run's save mode (the two can differ across runs), so it's + intentionally separate from the save-time FSDP state-dict-type gate used for the export + callback. + """ + if not checkpoint: + return False + hf_files = ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json") + return any(os.path.isfile(os.path.join(checkpoint, f)) for f in hf_files) + + def train(): config_path, dry_run, overrides = _parse_cli() recipe = load_recipe(config_path, overrides=overrides) @@ -186,12 +203,10 @@ def train(): use_offline_training = recipe.data.mode != "online" - # Check if checkpoint has HF-format model files (compatible with from_pretrained). - # FSDP distributed checkpoints (pytorch_model_fsdp_*) don't — load base model instead. 
- _hf_ckpt_files = ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json") - checkpoint_is_hf = checkpoint and any( - os.path.isfile(os.path.join(checkpoint, f)) for f in _hf_ckpt_files - ) + # Resume path depends on the existing checkpoint's on-disk format: consolidated HF + # weights load via from_pretrained; FSDP sharded checkpoints load the base model and + # resume through the Trainer. + checkpoint_is_hf = _is_hf_format_checkpoint(checkpoint) if checkpoint_is_hf: assert checkpoint is not None # guaranteed by checkpoint_is_hf From 82814228ee840097c027eef97d1517771856faf3 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 15 Jun 2026 14:29:43 -0700 Subject: [PATCH 29/31] review(claude-bot): fix clip_grad_norm deadlock + silent fallbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses the Claude review's CRITICAL + silent-fallback findings: - CRITICAL: _clip_grad_norm no longer early-returns on an empty-grad rank before the collective — under FSDP2 sharding / MoE+LoRA co-train a rank can legitimately have no grads (e.g. an expert with no tokens), and returning early while other ranks all_reduce deadlocks the job. Every rank now reaches the (guarded) all_reduce; empty-grad ranks contribute a zero norm and clip nothing. - fsdp2_buffer_patch dtype-sync: raise on an unmapped dtype instead of silently coercing to fp32 (which would cast on the wire or trip an NCCL size mismatch). - DFlash export callback: warn loudly on the two fallback paths — the full-model gather when get_model_state_dict lacks submodules=, and the denylist heuristic when the prefix-based extraction finds nothing — so a slow/OOM gather or a malformed export isn't silent. - hf_spec_export rope_theta: use 'is not None' instead of 'or' so a base rope_theta of 0 isn't swallowed (parity with the eagle path). - Cleanups: drop the dead _orig capture; remove the duplicated num_hidden_layers doc row. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- examples/speculative_decoding/doc/dflash.md | 1 - examples/speculative_decoding/eagle_utils.py | 17 +++++++- .../fsdp2_buffer_patch.py | 40 ++++++++++++++----- .../torch/export/plugins/hf_spec_export.py | 7 +++- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 37bbe268f70..0150e0884e1 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -163,7 +163,6 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | -| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `training.answer_only_loss` | false | Mask loss on non-assistant tokens | diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index be54889be07..6d88632022b 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -271,7 +271,14 @@ def on_save(self, args, state, control, **kwargs): model, submodules={model.dflash_module}, options=options ) except TypeError: - # Older PyTorch without submodules parameter — gather full model + # Older PyTorch without the submodules= parameter: this gathers the FULL + # model (the entire base, e.g. ~229B for MiniMax-M2.7), defeating the + # submodule-only design and risking OOM. Warn loudly — upgrade PyTorch. 
+ print_rank_0( + "WARNING: DFlash export: get_model_state_dict lacks submodules= on this " + "PyTorch — gathering the FULL base model (slow / may OOM). Upgrade PyTorch " + "for the submodule-only gather." + ) raw_sd = get_model_state_dict(model, options=options) except ImportError: # Non-distributed / single-GPU fallback @@ -283,6 +290,14 @@ def on_save(self, args, state, control, **kwargs): exporter = model.get_exporter() drafter_sd = exporter._extract_state_dict(raw_sd) if not drafter_sd: + # Fallback for the already-stripped-key layout: a denylist heuristic rather than + # the prefix-based extractor. Warn, since a key-naming change upstream could let + # a malformed draft ship and only fail at vLLM load time. + print_rank_0( + "WARNING: DFlash export: prefix-based extraction found no dflash_module keys; " + "falling back to a denylist heuristic on already-stripped keys. Verify the " + "exported draft loads in vLLM." + ) drafter_sd = { k: v for k, v in raw_sd.items() diff --git a/examples/speculative_decoding/fsdp2_buffer_patch.py b/examples/speculative_decoding/fsdp2_buffer_patch.py index 2a34e45edc7..388959c6b78 100644 --- a/examples/speculative_decoding/fsdp2_buffer_patch.py +++ b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -141,7 +141,19 @@ def apply(): import accelerate.utils.fsdp_utils as fsdp_utils from torch.distributed.tensor import DTensor - _orig = fsdp_utils.fsdp2_load_full_state_dict + def _dtype_code(dt): + """Map a dtype to its broadcast sync code, raising on anything unmapped. + + Silently coercing an unknown dtype to fp32 would cast data on the wire (or + make NCCL refuse on an element-size mismatch), so fail loudly instead. + """ + code = _DTYPE_TO_CODE.get(dt) + if code is None or code < 0: + raise ValueError( + f"fsdp2_buffer_patch: unsupported dtype {dt} in the broadcast " + f"dtype-sync; add it to _DTYPE_TO_CODE." 
+ ) + return code def _patched(accelerator, model, full_sd, cpu_offload=False): import time @@ -175,7 +187,7 @@ def _patched(accelerator, model, full_sd, cpu_offload=False): # each broadcast tensor. if accelerator.is_main_process: dtype_codes = torch.tensor( - [_DTYPE_TO_CODE.get(full_sd[name].dtype, 0) for name in meta_sharded_sd], + [_dtype_code(full_sd[name].dtype) for name in meta_sharded_sd], dtype=torch.int32, device=accelerator.device, ) @@ -309,15 +321,19 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): norm_type = float(norm_type) grads = [p.grad for p in parameters if p.grad is not None] - if len(grads) == 0: - # Match the device of the normal return path (GPU when training on CUDA) so - # callers don't hit a device mismatch on the empty-grad case. - dev = "cuda" if torch.cuda.is_available() else "cpu" - return torch.tensor(0.0, device=dev) + # Every rank MUST reach the all_reduce below: under FSDP2 sharding (and especially + # MoE + LoRA co-training) a rank can legitimately have no grads on a step — e.g. an + # expert that received no tokens, so the shard it owns gets no gradient. Early- + # returning here while other ranks call all_reduce would deadlock the job. So we + # never short-circuit before the collective; an empty-grad rank simply contributes a + # zero local norm and clips nothing. + if grads: + device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Shard DTensors hold partial data — need all_reduce for global norm. # Replicate DTensors and regular tensors already hold full data. 
- device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device sharded_norm_p = torch.tensor(0.0, device=device) local_norm_p = torch.tensor(0.0, device=device) @@ -341,14 +357,18 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): local_norm_p += n.pow(norm_type) n_regular += 1 - dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM) + # Symmetric across ranks: reached on every rank regardless of whether this rank had + # grads (see note above). Guard for the non-distributed case where local == global. + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM) total_norm = (sharded_norm_p + local_norm_p).pow(1.0 / norm_type) clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0) # Debug: log computation breakdown on first 5 calls (no collectives — safe). _clip_grad_norm_call_count += 1 - if _clip_grad_norm_call_count <= 5 and dist.get_rank() == 0: + _rank0 = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0 + if _clip_grad_norm_call_count <= 5 and _rank0: print( f"[clip_grad_norm debug] call={_clip_grad_norm_call_count} " f"total_norm={total_norm.item():.6f} " diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index f8efe0035af..95b2de864f4 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -379,8 +379,11 @@ def _export_config(self): # Inherit the target's rope_theta: DFlash injects the target's KV into every # draft layer, so the draft's RoPE base must match the target's. (The draft # arch config carries no rope_theta of its own.) 
- "rope_theta": getattr(base_config, "rope_theta", None) - or getattr(draft_config, "rope_theta", 1000000.0), + "rope_theta": ( + getattr(base_config, "rope_theta", None) + if getattr(base_config, "rope_theta", None) is not None + else getattr(draft_config, "rope_theta", 1000000.0) + ), # YaRN long-context scaling is injected below (see the rope_scaling block). "rope_scaling": None, "tie_word_embeddings": False, From a8c67e8744428cee0079f1003f1a5457ba01747f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 15 Jun 2026 14:32:51 -0700 Subject: [PATCH 30/31] review(claude-bot): parameterize draft depth + harden /dev/shm staging in offline dump Addresses the remaining offline-dump-path findings (relevant to the upcoming M3 offline work): - compute_hidden_states_vllm.py: the standalone 'dflash' aux-layer resolver no longer hard-codes num_draft=5; added --num-draft-layers (default 5) so the dumped aux layers match the recipe's dflash_architecture_config.num_hidden_layers. A mismatch would silently mis-align the dump with what the draft consumes at training time. - compute_hidden_states_vllm.py: the connector staging dir is now overridable via DFLASH_HS_STAGING_DIR (default /dev/shm) for containers where /dev/shm is unmapped or undersized, and is removed on exit (atexit) so a crash doesn't strand RAM-backed files. - specdec_bench.yaml: clarified that exported-checkpoint-final comes from the launcher's post-run export, while the per-save callback writes exported-checkpoint-. 
Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../compute_hidden_states_vllm.py | 34 +++++++++++++++---- .../MiniMax-M2.7-DFlash/specdec_bench.yaml | 8 +++-- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index 45129131293..a14558f8a9e 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -25,6 +25,9 @@ """ import argparse +import atexit +import os +import shutil from pathlib import Path import torch @@ -44,22 +47,26 @@ ) -def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]: +def _resolve_aux_layers_standalone( + aux_layers: str, num_hidden_layers: int, num_draft: int = 5 +) -> list[int]: """Resolve aux-layer ids without importing modelopt. This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM container does not have, so the import fails. Resolve the 'dflash' preset inline - (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe - default) and accept an explicit comma-separated int list. Keep in sync with modelopt. + (mirroring ``modeling_dflash.build_target_layer_ids`` for ``num_draft`` draft layers) + and accept an explicit comma-separated int list. ``num_draft`` MUST match the recipe's + ``dflash.dflash_architecture_config.num_hidden_layers`` (pass --num-draft-layers) or the + dumped aux layers silently mis-align with what the draft consumes at training time. + Keep in sync with modelopt. 
TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy ``modelopt.torch`` import chain so it can be reused directly in a vLLM container. """ spec = aux_layers.strip().lower() if spec == "dflash": - num_draft = 5 if num_draft == 1: return [num_hidden_layers // 2] start = min(1, num_hidden_layers - 1) @@ -107,6 +114,13 @@ def parse_args() -> argparse.Namespace: "--debug-max-num-conversations", type=int, default=None, help="Limit conversations." ) add_aux_layers_args(parser) + parser.add_argument( + "--num-draft-layers", + type=int, + default=5, + help="DFlash draft depth, for resolving the 'dflash' --aux-layers preset. MUST match " + "the recipe's dflash.dflash_architecture_config.num_hidden_layers (default: 5).", + ) add_answer_only_loss_args(parser) return parser.parse_args() @@ -157,7 +171,9 @@ def keep_conversation(entry): num_hidden_layers = getattr(config, "num_hidden_layers", None) if num_hidden_layers is None: raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}") - aux_layer_ids = _resolve_aux_layers_standalone(args.aux_layers, num_hidden_layers) + aux_layer_ids = _resolve_aux_layers_standalone( + args.aux_layers, num_hidden_layers, num_draft=args.num_draft_layers + ) # The trailing entry is the final output hidden state; the rest are aux layers. extract_layer_ids = [*aux_layer_ids, num_hidden_layers] print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)") @@ -220,9 +236,13 @@ def keep_conversation(entry): # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre) # output dir: the producer writes one file per request and the client reads it back # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so - # parallel shards don't collide. - storage_path = Path("/dev/shm") / f"vllm_hidden_states_dp{args.dp_rank}" + # parallel shards don't collide. 
Overridable via DFLASH_HS_STAGING_DIR for containers + # where /dev/shm is unmapped or undersized; cleaned up on exit so a crash doesn't strand + # RAM-backed files until the node reboots. + staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm") + storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}" storage_path.mkdir(parents=True, exist_ok=True) + atexit.register(lambda p=storage_path: shutil.rmtree(p, ignore_errors=True)) llm = LLM( model=args.model, diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml index 63ea3c47095..ad538805fb6 100644 --- a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml @@ -16,9 +16,13 @@ # dir containing model.safetensors), e.g. the output of hf_online_dflash.yaml. # # Usage: -# Edit the two --draft_model_dir args to point at your exported draft checkpoint -# (an exported-checkpoint-* dir with model.safetensors), then: +# Edit the two --draft_model_dir args to point at your exported draft checkpoint, then: # uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml --yes +# +# NOTE: the placeholder below is `exported-checkpoint-final` — that dir is written by the +# training launcher's post-run export (dflash_online_training.sh) for the final model. The +# per-save DFlashFSDP2ShardedSDExportCallback instead writes `exported-checkpoint-` +# dirs; to bench a specific intermediate step, point at one of those. 
job_name: MiniMax-M2.7-DFlash_specdec_bench pipeline: From 69ad872f66c0b846524611ec4ac55eb384606480 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 15 Jun 2026 14:40:19 -0700 Subject: [PATCH 31/31] =?UTF-8?q?fix:=20mypy=20=E2=80=94=20pass=20shutil.r?= =?UTF-8?q?mtree=20args=20via=20atexit.register=20instead=20of=20a=20lambd?= =?UTF-8?q?a?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mypy couldn't infer the cleanup lambda's type ('Cannot infer type of lambda'). atexit forwards *args/**kwargs to the callback, so register shutil.rmtree directly with its args. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Ye Yu --- .../collect_hidden_states/compute_hidden_states_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py index a14558f8a9e..77441f8f858 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py @@ -242,7 +242,7 @@ def keep_conversation(entry): staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm") storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}" storage_path.mkdir(parents=True, exist_ok=True) - atexit.register(lambda p=storage_path: shutil.rmtree(p, ignore_errors=True)) + atexit.register(shutil.rmtree, storage_path, ignore_errors=True) llm = LLM( model=args.model,