def _resolve_aux_layers_standalone(
    aux_layers: str, num_hidden_layers: int, num_draft: int = 5
) -> list[int]:
    """Resolve aux-layer ids without importing modelopt.

    This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
    'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
    pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
    container does not have, so the import fails. Resolve the 'dflash' preset inline
    (mirroring ``modeling_dflash.build_target_layer_ids`` for ``num_draft`` draft layers)
    and accept an explicit comma-separated int list. ``num_draft`` MUST match the recipe's
    ``dflash.dflash_architecture_config.num_hidden_layers`` (pass --num-draft-layers) or the
    dumped aux layers silently mis-align with what the draft consumes at training time.
    Keep in sync with modelopt.

    TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy
    ``modelopt.torch`` import chain so it can be reused directly in a vLLM container.

    Args:
        aux_layers: The ``'dflash'`` preset (case-insensitive, surrounding whitespace
            ignored) or a comma-separated list of integer layer ids.
        num_hidden_layers: Number of layers in the target model; explicit ids must lie
            in ``[0, num_hidden_layers)``.
        num_draft: Draft depth used to spread the 'dflash' preset ids; ignored for an
            explicit id list.

    Returns:
        Sorted, de-duplicated layer ids.

    Raises:
        ValueError: If ``aux_layers`` is empty or is neither 'dflash' nor an integer
            list, if an explicit id is out of range, or if ``num_draft`` is below 1 for
            the 'dflash' preset.
    """
    spec = aux_layers.strip().lower()
    if spec == "dflash":
        if num_draft < 1:
            # range(num_draft) would be empty and the preset would silently resolve to
            # "no aux layers"; that is never what the recipe means.
            raise ValueError(
                f"--num-draft-layers must be >= 1 for the 'dflash' preset, got {num_draft}."
            )
        if num_draft == 1:
            return [num_hidden_layers // 2]
        start = min(1, num_hidden_layers - 1)
        end = max(start, num_hidden_layers - 3)
        span = end - start
        return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})

    unsupported_msg = (
        f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
        "'dflash' preset or an explicit comma-separated layer-id list are supported."
    )
    tokens = [t.strip() for t in aux_layers.split(",") if t.strip()]
    if not tokens:
        raise ValueError(unsupported_msg)
    try:
        ids = sorted({int(t) for t in tokens})
    except ValueError as err:
        # e.g. the 'eagle' preset: surface the actionable message instead of int()'s
        # bare "invalid literal" error.
        raise ValueError(unsupported_msg) from err

    # Match the shared helper's contract: ids must be valid layer indices.
    out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers]
    if out_of_range:
        raise ValueError(
            f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers}) "
            f"for a {num_hidden_layers}-layer model."
        )
    return ids
MUST match " + "the recipe's dflash.dflash_architecture_config.num_hidden_layers (default: 5).", + ) + add_answer_only_loss_args(parser) return parser.parse_args() @@ -112,7 +171,9 @@ def keep_conversation(entry): num_hidden_layers = getattr(config, "num_hidden_layers", None) if num_hidden_layers is None: raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}") - aux_layer_ids = resolve_aux_layers(args, num_hidden_layers) + aux_layer_ids = _resolve_aux_layers_standalone( + args.aux_layers, num_hidden_layers, num_draft=args.num_draft_layers + ) # The trailing entry is the final output hidden state; the rest are aux layers. extract_layer_ids = [*aux_layer_ids, num_hidden_layers] print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)") @@ -121,34 +182,37 @@ def keep_conversation(entry): tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + override_template = load_chat_template(args.chat_template) + if override_template is not None: + tokenizer.chat_template = override_template if tokenizer.chat_template is not None: tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + if args.answer_only_loss: + verify_generation_tags(tokenizer.chat_template) # Prepare prompts for vLLM prompts = [] conversation_ids = [] + loss_masks = [] num_skipped_too_long = 0 num_invalid = 0 for entry in dataset: conversation_id = entry.get("conversation_id", entry.get("uuid")) - conversations = entry["conversations"] + # Accept either the "conversations" or OpenAI-style "messages" key (the + # MiniMax synthetic data uses "messages"). 
+ conversations = entry.get("conversations") or entry.get("messages") if not conversations or not isinstance(conversations, list): num_invalid += 1 continue - tokenized = tokenizer.apply_chat_template( - conversations, return_tensors="pt", add_generation_prompt=False + # One apply_chat_template call yields aligned input_ids + loss_mask. With + # --answer-only-loss the mask comes from the template's {% generation %} tags; + # otherwise it is all-ones. Same tokens are sent to vLLM, so the dumped hidden + # states line up with this loss_mask 1:1 (prefix caching is disabled below). + input_ids, loss_mask = tokenize_with_loss_mask( + tokenizer, conversations, args.answer_only_loss ) - # transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids - if hasattr(tokenized, "input_ids"): - input_ids = tokenized.input_ids - elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized: - input_ids = tokenized["input_ids"] - else: - input_ids = tokenized - if not hasattr(input_ids, "shape"): - input_ids = torch.tensor(input_ids) input_ids = input_ids.squeeze(0) num_tokens = input_ids.shape[0] if num_tokens <= 10 or num_tokens > args.max_seq_len: @@ -157,6 +221,7 @@ def keep_conversation(entry): prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist())) conversation_ids.append(conversation_id) + loss_masks.append(loss_mask) print( f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)" @@ -168,8 +233,16 @@ def keep_conversation(entry): # Initialize vLLM with the native hidden-state extractor. tp = args.tp if args.tp is not None else torch.cuda.device_count() - storage_path = output_dir / ".vllm_hidden_states" + # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre) + # output dir: the producer writes one file per request and the client reads it back + # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so + # parallel shards don't collide. 
Overridable via DFLASH_HS_STAGING_DIR for containers + # where /dev/shm is unmapped or undersized; cleaned up on exit so a crash doesn't strand + # RAM-backed files until the node reboots. + staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm") + storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}" storage_path.mkdir(parents=True, exist_ok=True) + atexit.register(shutil.rmtree, storage_path, ignore_errors=True) llm = LLM( model=args.model, @@ -177,6 +250,11 @@ def keep_conversation(entry): max_model_len=args.max_seq_len, trust_remote_code=args.trust_remote_code, enable_chunked_prefill=False, # required by extract_hidden_states + # With prefix caching on, vLLM serves shared prefixes from cache in block-sized + # chunks and the hidden-state connector only emits the freshly-computed suffix, so + # the dumped hidden_states come out short by N*block_size vs the full input_ids / + # loss_mask. Disabling it forces a full prefill so every token's state is dumped. + enable_prefix_caching=False, speculative_config={ "method": "extract_hidden_states", "num_speculative_tokens": 1, @@ -189,7 +267,10 @@ def keep_conversation(entry): kv_role="kv_producer", kv_connector_extra_config={ "shared_storage_path": str(storage_path), - "use_synchronization_lock": False, # batch generation, no concurrent readers + # The client reads each request's safetensors right after generation; the + # lock makes the producer signal completion so the reader doesn't race the + # writer (without it the reader looks for a .lock the producer never wrote). + "use_synchronization_lock": True, }, ), ) @@ -197,10 +278,11 @@ def keep_conversation(entry): # max_tokens=1: we only need a single forward pass over the prompt tokens. outputs = llm.generate(prompts, SamplingParams(max_tokens=1)) - # Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the - # vLLM path does not compute). 
+ # Save in the same format as compute_hidden_states_hf.py, including loss_mask. num_success = 0 - for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"): + for conv_id, loss_mask, output in tqdm( + zip(conversation_ids, loss_masks, outputs), total=len(outputs), desc="Saving" + ): hidden_states_path = output.kv_transfer_params.get("hidden_states_path") if hidden_states_path is None: print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping") @@ -220,6 +302,16 @@ def keep_conversation(entry): else: aux_hidden_states = torch.empty(0) + # loss_mask is sliced to the dumped length below; a shorter loss_mask would slice + # to itself and silently misalign with the hidden states, so guard explicitly. + n_hs = output_hidden_states.shape[0] + if loss_mask.shape[0] < n_hs: + print( + f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden " + f"states ({n_hs}); skipping to avoid misalignment" + ) + continue + output_file = output_dir / f"{conv_id}.pt" with open(output_file, "wb") as f: torch.save( @@ -227,6 +319,7 @@ def keep_conversation(entry): "input_ids": token_ids.cpu(), "hidden_states": output_hidden_states, "aux_hidden_states": aux_hidden_states, + "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(), "conversation_id": conv_id, }, f, diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 44db5d39e72..0150e0884e1 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -163,10 +163,15 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | +| 
class DFlashFSDP2ShardedSDExportCallback(TrainerCallback):
    """Export the DFlash draft module after each checkpoint save, for FSDP2 sharded runs.

    Applicable range: this is needed only under FSDP2 ``SHARDED_STATE_DICT``, where the
    checkpoint holds distributed shards (``pytorch_model_fsdp_0/``) and no consolidated
    ``model.safetensors`` — so the post-training ``export_hf_checkpoint.py`` pass can't read
    it. It gathers just the small draft submodule and writes the deployable export.

    Gating is the caller's responsibility: ``main.py`` adds this callback only when the
    accelerator's FSDP state dict type is ``SHARDED_STATE_DICT`` (full-state-dict runs —
    DDP, single-device, FSDP2 FULL_STATE_DICT — use the launcher's post-run export instead).
    """

    @staticmethod
    def _gather_raw_state_dict(model):
        """Gather the draft submodule's full state dict; every rank must call this."""
        import inspect

        try:
            from torch.distributed.checkpoint.state_dict import (
                StateDictOptions,
                get_model_state_dict,
            )
        except ImportError:
            # Non-distributed / single-GPU fallback
            return model.state_dict()

        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        # Decide from the signature rather than catching TypeError around the call: a
        # TypeError raised *inside* a supported gather must surface as-is, not trigger a
        # silent retry that gathers the whole base model.
        if "submodules" in inspect.signature(get_model_state_dict).parameters:
            return get_model_state_dict(
                model, submodules={model.dflash_module}, options=options
            )

        # No submodules= parameter on this PyTorch: this gathers the FULL model (the
        # entire base, e.g. ~229B for MiniMax-M2.7), defeating the submodule-only design
        # and risking OOM. Warn loudly.
        print_rank_0(
            "WARNING: DFlash export: get_model_state_dict lacks submodules= on this "
            "PyTorch — gathering the FULL base model (slow / may OOM). Upgrade PyTorch "
            "for the submodule-only gather."
        )
        return get_model_state_dict(model, options=options)

    @staticmethod
    def _select_drafter_tensors(exporter, raw_sd):
        """Reduce a gathered state dict to contiguous CPU draft tensors ready to save."""
        # Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary
        # buffers) for the common full-model key layout. Some PyTorch versions return the
        # submodule gather with keys already stripped of the prefix — handle that directly.
        drafter_sd = exporter._extract_state_dict(raw_sd)
        if not drafter_sd:
            # Fallback for the already-stripped-key layout: a denylist heuristic rather than
            # the prefix-based extractor. Warn, since a key-naming change upstream could let
            # a malformed draft ship and only fail at vLLM load time.
            print_rank_0(
                "WARNING: DFlash export: prefix-based extraction found no dflash_module keys; "
                "falling back to a denylist heuristic on already-stripped keys. Verify the "
                "exported draft loads in vLLM."
            )
            denied = ("model.", "lm_head.", "embed_tokens.")
            drafter_sd = {
                k: v
                for k, v in raw_sd.items()
                if "rotary_emb" not in k and not any(p in k for p in denied)
            }
        # Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the
        # single-GPU fallback may return CUDA tensors). save_file also rejects
        # non-contiguous tensors; .contiguous() is a no-op when already contiguous.
        return {
            k: (v.cpu() if v.device.type != "cpu" else v).contiguous()
            for k, v in drafter_sd.items()
        }

    @staticmethod
    def _write_export(export_dir, drafter_sd, exporter):
        """Write ``model.safetensors`` and ``config.json`` into ``export_dir``."""
        import json
        import os

        from safetensors.torch import save_file

        os.makedirs(export_dir, exist_ok=True)
        save_file(drafter_sd, os.path.join(export_dir, "model.safetensors"))

        config = exporter._export_config()
        with open(os.path.join(export_dir, "config.json"), "w") as f:
            json.dump(config, f, indent=2)

    def on_save(self, args, state, control, **kwargs):
        """Export DFlash draft module weights + config after checkpoint save.

        Returns ``control`` unchanged in every case. Models without a ``dflash_module``
        attribute are skipped; a failure while writing files is reported as a warning
        rather than raised.
        """
        import os

        model = kwargs["model"]
        if not hasattr(model, "dflash_module"):
            return control

        step = state.global_step
        export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}")

        # All ranks participate in the state_dict gather (FSDP2 collective op). Only the
        # dflash_module submodule is gathered (~328 MB for MiniMax-M2.7), not the 229B base.
        raw_sd = self._gather_raw_state_dict(model)
        exporter = model.get_exporter()
        drafter_sd = self._select_drafter_tensors(exporter, raw_sd)
        del raw_sd  # drop the gathered copy before writing

        if not drafter_sd:
            print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export")
            return control

        # Only rank 0 writes files
        if is_master():
            try:
                self._write_export(export_dir, drafter_sd, exporter)
                total_mb = sum(v.nbytes for v in drafter_sd.values()) / 1024 / 1024
                print_rank_0(
                    f"Exported DFlash draft ({len(drafter_sd)} tensors, {total_mb:.0f}MB) "
                    f"to {export_dir}"
                )
            except Exception as e:
                # Best effort by design: a failed export must not abort training.
                print_rank_0(f"Warning: DFlash export failed at step {step}: {e}")

        return control
b/examples/speculative_decoding/fsdp2_buffer_patch.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling. + +Applicability (scope of this patch) +----------------------------------- +This is **not** needed for FSDP2 in general. It is required only for the narrow +combination of: **FSDP2 configured via an accelerate YAML config** (not torch-native +``ParallelismConfig``) **with** ``cpu_ram_efficient_loading=True``. Today that path is +forced by **models that require transformers 4.57.x** (their ``trust_remote_code`` code +predates transformers 5.x ``ParallelismConfig`` support) **and** are too large to load on +every rank — currently only **MiniMax-M2.7** (229B MoE). Models that run on transformers +5.x (Qwen, Llama, Nemotron, ...) use native ``ParallelismConfig`` (``dp_shard_size > 1``), +which handles buffers/dtypes correctly and never enters ``fsdp2_load_full_state_dict`` — +they need none of this. Gated off by default; activate with ``PATCH_FSDP2_BUFFERS_TF457=1``. 
+ +Problem +------- +accelerate's ``fsdp2_load_full_state_dict`` (called during model preparation +when ``cpu_ram_efficient_loading=True``) iterates ``model.state_dict()`` and +unconditionally accesses ``.device_mesh`` on every entry, assuming they are all +DTensors. After ``fully_shard()``, **parameters** become DTensors but +**persistent buffers** (e.g., rotary-embedding ``inv_freq``) remain plain +``torch.Tensor``. This crashes with:: + + AttributeError: 'Tensor' object has no attribute 'device_mesh' + +Additionally, ``cpu_ram_efficient_loading`` causes a dtype divergence: rank 0 +loads the model on CPU (inheriting the checkpoint's dtype, e.g. bfloat16) while +other ranks use ``meta`` device (defaulting to float32 for newly-added modules +like the DFlash head). After ``fully_shard()``, the DTensor dtypes differ +across ranks for these modules. Since ``dist.broadcast()`` requires matching +dtypes and element sizes on all ranks, broadcasting a bfloat16 tensor (2 +bytes/elem) to a float32 receive buffer (4 bytes/elem) causes an NCCL deadlock. + +Why we need FSDP2 via accelerate config (not ParallelismConfig) +--------------------------------------------------------------- +MiniMax-M2.7's ``trust_remote_code`` model code requires **transformers 4.57.x**. +Transformers' native FSDP2 support via ``ParallelismConfig`` requires +**transformers 5.x**. So we fall back to configuring FSDP2 through an +``accelerate`` YAML config file (``accelerate_fsdp2.yaml``), which works with +transformers 4.57.x. We set ``dp_shard_size=1`` to prevent ``main.py`` from +creating a ``ParallelismConfig``, letting the accelerate config handle sharding. + +Why we need cpu_ram_efficient_loading +------------------------------------- +MiniMax-M2.7 is a 229B MoE model (~230 GB in FP8). Each GB200 node has 4 GPUs +and ~800 GB system RAM. 
Without ``cpu_ram_efficient_loading``, all 4 ranks per +node load the model to CPU simultaneously (4 × 230 GB ≈ 920 GB), exceeding +system RAM and triggering OOM kills. With ``cpu_ram_efficient_loading``, only +rank 0 loads the model; other ranks initialize on ``meta`` device. The weights +are then broadcast via ``fsdp2_load_full_state_dict`` — which is where the bug +hits. + +What this patch does +-------------------- +1. Before accessing ``.device_mesh``, checks ``isinstance(entry, DTensor)``. + For non-DTensor entries (persistent buffers), broadcasts the raw tensor from + rank 0 without calling ``distribute_tensor()``. + +2. All ranks iterate ``model.state_dict()`` (post-shard) in the same order so + broadcast calls match 1-to-1. Rank 0 looks up the full parameter **by key + name** from the pre-shard state dict — never by positional ``zip``, because + ``model.to("meta")`` + ``fully_shard()`` can reorder keys. + +3. **Dtype synchronization**: rank 0 broadcasts a dtype code for each entry + before the main loop. All ranks then use the same dtype for their broadcast + tensors. This fixes the dtype divergence caused by rank 0 loading in + bfloat16 while other ranks default to float32 for newly-added modules. + +Accelerate config constraints (for reference) +---------------------------------------------- +``accelerate_fsdp2.yaml`` also requires: + +- ``fsdp_use_orig_params: true`` — without this, FSDP flattens all params into + FlatParameter, losing per-parameter ``requires_grad`` flags. The DFlash head + can't train because its gradients mix with frozen base model zeros. +- ``fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule`` — + DFlash head params at the model root must be in the wrap policy so they become + DTensors. Without this, ``fsdp2_load_full_state_dict`` also crashes. 
+- ``fsdp_sync_module_states: true`` — accelerate's launch validator requires it + when ``cpu_ram_efficient_loading`` is enabled, even though FSDP2 ignores it at + runtime (sets it to None with a warning). + +Does NOT affect models on transformers 5.x +------------------------------------------- +This entire workaround exists ONLY because MiniMax-M2.7 requires +transformers 4.57.x. Models that support transformers 5.x (Qwen, Llama, +Nemotron, etc.) use ``ParallelismConfig`` natively by setting +``dp_shard_size > 1`` in the training args. That code path handles buffers +correctly and does not go through ``fsdp2_load_full_state_dict`` at all. +No accelerate config file, no ``PATCH_FSDP2_BUFFERS_TF457``, no +``OVERRIDE_TRANSFORMERS`` needed. + +When to remove +-------------- +This patch can be removed when EITHER of these happens: + +1. MiniMax updates ``trust_remote_code`` for transformers 5.x, allowing native + ``ParallelismConfig`` (which handles this correctly). +2. Upstream accelerate fixes ``fsdp2_load_full_state_dict`` to skip non-DTensor + entries. Track: https://github.com/huggingface/accelerate + +Activation +---------- +Set ``PATCH_FSDP2_BUFFERS_TF457=1`` in the environment to activate. Off by default. +Only needed in MiniMax-M2.7 pipeline YAMLs. +""" + +import torch + +# Dtype encoding for the broadcast dtype-sync step. +_DTYPE_TO_CODE = { + torch.float32: 0, + torch.bfloat16: 1, + torch.float16: 2, + torch.float8_e4m3fn: 3 if hasattr(torch, "float8_e4m3fn") else -1, +} +_CODE_TO_DTYPE = {v: k for k, v in _DTYPE_TO_CODE.items() if v >= 0} + + +def apply(): + """Patch fsdp2_load_full_state_dict if the buffer bug is present.""" + try: + import accelerate.utils.fsdp_utils as fsdp_utils + from torch.distributed.tensor import DTensor + + def _dtype_code(dt): + """Map a dtype to its broadcast sync code, raising on anything unmapped. 
+ + Silently coercing an unknown dtype to fp32 would cast data on the wire (or + make NCCL refuse on an element-size mismatch), so fail loudly instead. + """ + code = _DTYPE_TO_CODE.get(dt) + if code is None or code < 0: + raise ValueError( + f"fsdp2_buffer_patch: unsupported dtype {dt} in the broadcast " + f"dtype-sync; add it to _DTYPE_TO_CODE." + ) + return code + + def _patched(accelerator, model, full_sd, cpu_offload=False): + import time + + import torch.distributed as dist + from torch.distributed.tensor import distribute_tensor + + meta_sharded_sd = model.state_dict() + sharded_sd = {} + n_total = len(meta_sharded_sd) + n_dtensor = sum(1 for v in meta_sharded_sd.values() if isinstance(v, DTensor)) + n_buffer = n_total - n_dtensor + + if accelerator.is_main_process: + print( + f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}" + ) + else: + print( + f"[fsdp2_buffer_patch] State dict: {n_total} entries " + f"({n_dtensor} DTensor, {n_buffer} buffer)" + ) + t0 = time.time() + + # --- Step 0: broadcast dtype codes from rank 0 --- + # cpu_ram_efficient_loading causes rank 0 to load in bfloat16 while + # other ranks default to float32 for newly-added modules (DFlash). + # After fully_shard(), DTensor dtypes diverge across ranks. + # Broadcast rank 0's dtypes so all ranks use the same dtype for + # each broadcast tensor. 
+ if accelerator.is_main_process: + dtype_codes = torch.tensor( + [_dtype_code(full_sd[name].dtype) for name in meta_sharded_sd], + dtype=torch.int32, + device=accelerator.device, + ) + else: + dtype_codes = torch.empty( + n_total, + dtype=torch.int32, + device=accelerator.device, + ) + dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD) + broadcast_dtypes = [_CODE_TO_DTYPE[c.item()] for c in dtype_codes] + del dtype_codes + + # Infer dtype/contiguity for cast — copied from upstream + def _infer_parameter_dtype(mdl, param_name, empty_param): + try: + old_param = mdl.get_parameter_or_buffer(param_name) + except AttributeError: + base, local = param_name.rsplit(".", 1) + old_param = getattr(mdl.get_submodule(base), local) + is_f8 = hasattr(torch, "float8_e4m3fn") and empty_param.dtype == torch.float8_e4m3fn + casting_dtype = ( + old_param.dtype if (empty_param.dtype.is_floating_point and not is_f8) else None + ) + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _finish(st, contig, dtype, offload): + if dtype is not None: + st = st.to(dtype=dtype) + if contig: + st = st.contiguous() + if offload: + st = st.to("cpu") + return st + + # --- Step 1: broadcast all entries --- + # All ranks iterate meta_sharded_sd in the same order to ensure + # identical broadcast sequences. Rank 0 looks up the full parameter + # by name — never positional zip (model.to("meta") + fully_shard() + # can reorder keys). broadcast_dtypes[idx] is used for the tensor + # dtype on ALL ranks to prevent NCCL deadlocks from dtype divergence. 
+ for idx, (param_name, sharded_param) in enumerate(meta_sharded_sd.items()): + is_dtensor = isinstance(sharded_param, DTensor) + bcast_dtype = broadcast_dtypes[idx] + + if not is_dtensor: + # Persistent buffer — broadcast raw, no distribute_tensor + if accelerator.is_main_process: + t = ( + full_sd[param_name] + .detach() + .to(device=accelerator.device, dtype=bcast_dtype) + ) + else: + t = torch.empty( + sharded_param.size(), + device=accelerator.device, + dtype=bcast_dtype, + ) + dist.broadcast(t, src=0, group=dist.group.WORLD) + sharded_sd[param_name] = t + continue + + device_mesh = sharded_param.device_mesh + if accelerator.is_main_process: + ft = ( + full_sd[param_name] + .detach() + .to(device=device_mesh.device_type, dtype=bcast_dtype) + ) + if isinstance(ft, DTensor): + ft = ft.to_local() + else: + ft = torch.empty( + sharded_param.size(), + device=device_mesh.device_type, + dtype=bcast_dtype, + ) + dist.broadcast(ft, src=0, group=dist.group.WORLD) + st = distribute_tensor(ft, device_mesh, sharded_param.placements) + contig, _ = _infer_parameter_dtype(model, param_name, ft) + # Use bcast_dtype (from rank 0) instead of the model's local + # param dtype — with cpu_ram_efficient_loading, non-rank-0 + # processes have fp32 meta-device params for DFlash, and + # _infer_parameter_dtype would incorrectly cast bf16 back to fp32. + is_f8 = hasattr(torch, "float8_e4m3fn") and bcast_dtype == torch.float8_e4m3fn + final_dtype = None if is_f8 else bcast_dtype + sharded_sd[param_name] = _finish(st, contig, final_dtype, cpu_offload) + + elapsed = time.time() - t0 + print( + f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, " + f"loading {len(sharded_sd)} entries into model..." 
+ ) + model.load_state_dict(sharded_sd, assign=True) + print( + f"[fsdp2_buffer_patch] State dict loaded successfully " + f"({time.time() - t0:.1f}s total)" + ) + return model + + fsdp_utils.fsdp2_load_full_state_dict = _patched + print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") + except Exception as e: + print(f"[fsdp2_buffer_patch] Patch skipped: {e}") + + +_clip_grad_norm_call_count = 0 + + +def _clip_grad_norm(parameters, max_norm, norm_type=2): + """Clip gradient norms for FSDP2 DTensor parameters. + + Bypasses DTensor dispatch (which deadlocks with partially-frozen models + on the accelerate FSDP2 path) by extracting local tensor shards and + doing an explicit all_reduce for the global norm. + + Handles Shard (need all_reduce) and Replicate/regular (already global) + placements. Safe for DFlash-only and LoRA co-training. + """ + global _clip_grad_norm_call_count + import torch.distributed as dist + from torch.distributed.tensor import DTensor + from torch.distributed.tensor.placement_types import Shard + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + parameters = list(parameters) # materialize generator + max_norm = float(max_norm) + norm_type = float(norm_type) + + grads = [p.grad for p in parameters if p.grad is not None] + # Every rank MUST reach the all_reduce below: under FSDP2 sharding (and especially + # MoE + LoRA co-training) a rank can legitimately have no grads on a step — e.g. an + # expert that received no tokens, so the shard it owns gets no gradient. Early- + # returning here while other ranks call all_reduce would deadlock the job. So we + # never short-circuit before the collective; an empty-grad rank simply contributes a + # zero local norm and clips nothing. 
+ if grads: + device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Shard DTensors hold partial data — need all_reduce for global norm. + # Replicate DTensors and regular tensors already hold full data. + sharded_norm_p = torch.tensor(0.0, device=device) + local_norm_p = torch.tensor(0.0, device=device) + + n_sharded = 0 + n_replicate = 0 + n_regular = 0 + + for g in grads: + if isinstance(g, DTensor): + is_sharded = any(isinstance(p, Shard) for p in g.placements) + t = g._local_tensor.detach().to(torch.float32) + n = torch.linalg.vector_norm(t, norm_type) + if is_sharded: + sharded_norm_p += n.pow(norm_type) + n_sharded += 1 + else: + local_norm_p += n.pow(norm_type) + n_replicate += 1 + else: + n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type) + local_norm_p += n.pow(norm_type) + n_regular += 1 + + # Symmetric across ranks: reached on every rank regardless of whether this rank had + # grads (see note above). Guard for the non-distributed case where local == global. + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM) + total_norm = (sharded_norm_p + local_norm_p).pow(1.0 / norm_type) + + clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0) + + # Debug: log computation breakdown on first 5 calls (no collectives — safe). 
+ _clip_grad_norm_call_count += 1 + _rank0 = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0 + if _clip_grad_norm_call_count <= 5 and _rank0: + print( + f"[clip_grad_norm debug] call={_clip_grad_norm_call_count} " + f"total_norm={total_norm.item():.6f} " + f"sharded_norm_p={sharded_norm_p.item():.6f} local_norm_p={local_norm_p.item():.6f} " + f"grads={len(grads)} (sharded={n_sharded} replicate={n_replicate} regular={n_regular}) " + f"max_norm={max_norm} clip_coef={clip_coef.item():.6f}" + ) + for g in grads: + if isinstance(g, DTensor): + g._local_tensor.mul_(clip_coef) + else: + g.mul_(clip_coef) + + return total_norm + + +def patch_accelerator(accelerator): + """Replace accelerator's clip_grad_norm_ with FSDP2-safe version.""" + accelerator.clip_grad_norm_ = _clip_grad_norm + print( + "[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility" + ) + + +def log_param_dtypes(model): + """Debug aid: log per-rank parameter dtype counts (gated by DFLASH_LOG_PARAM_DTYPES=1). + + Used to verify the FSDP2 dtype synchronization above — after ``fully_shard()`` params + are DTensors whose dtype lives on ``_local_tensor``. Off by default; this is purely + diagnostic and has no effect on training. + """ + import os + + if os.environ.get("DFLASH_LOG_PARAM_DTYPES") != "1": + return + rank = int(os.environ.get("RANK", 0)) + dtypes = {} + for name, p in model.named_parameters(): + dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) + dtypes.setdefault(dt_key, []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. 
{names[0]})") diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index f62b099121d..bf247cc3db4 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -33,9 +33,11 @@ import dataclasses import os +import fsdp2_buffer_patch import torch import transformers from eagle_utils import ( + DFlashFSDP2ShardedSDExportCallback, EagleTrainerWithAccLog, EagleTrainingPlot, LoRAWarmupCallback, @@ -64,6 +66,9 @@ torch.manual_seed(0) mto.enable_huggingface_checkpointing() +if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": + fsdp2_buffer_patch.apply() + # HF-compatible TrainingArguments with our speculative-decoding extensions, auto-derived # from :class:`SpecTrainingArgs` so its field set can't drift from the Pydantic recipe schema. @@ -145,6 +150,23 @@ def init_distributed_env(training_args: transformers.TrainingArguments) -> None: ) +def _is_hf_format_checkpoint(checkpoint: str | None) -> bool: + """True if the checkpoint dir holds consolidated HF weights (from_pretrained-loadable). + + FSDP2 SHARDED_STATE_DICT checkpoints contain only distributed shards + (``pytorch_model_fsdp_*/``), no ``model.safetensors`` — those return False, signalling + the caller to load the base model and resume via the Trainer instead. This inspects the + on-disk format of the *resume* checkpoint, which is a property of the existing bytes and + is independent of the current run's save mode (the two can differ across runs), so it's + intentionally separate from the save-time FSDP state-dict-type gate used for the export + callback. 
+ """ + if not checkpoint: + return False + hf_files = ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json") + return any(os.path.isfile(os.path.join(checkpoint, f)) for f in hf_files) + + def train(): config_path, dry_run, overrides = _parse_cli() recipe = load_recipe(config_path, overrides=overrides) @@ -181,7 +203,13 @@ def train(): use_offline_training = recipe.data.mode != "online" - if checkpoint: + # Resume path depends on the existing checkpoint's on-disk format: consolidated HF + # weights load via from_pretrained; FSDP sharded checkpoints load the base model and + # resume through the Trainer. + checkpoint_is_hf = _is_hf_format_checkpoint(checkpoint) + + if checkpoint_is_hf: + assert checkpoint is not None # guaranteed by checkpoint_is_hf with patch_transformers5_params_loading(): model = load_vlm_or_llm( checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code @@ -190,6 +218,11 @@ def train(): checkpoint, trust_remote_code=recipe.model.trust_remote_code ) else: + if checkpoint: + print_rank_0( + f"Checkpoint {checkpoint} is not in HF format (FSDP distributed checkpoint). " + f"Loading base model and resuming via Trainer." + ) model_name_or_path = recipe.model.model_name_or_path if model_name_or_path is None: raise ValueError( @@ -245,20 +278,6 @@ def train(): ) return - # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast - # them. We iterate named_buffers and reassign via the owning module to - # keep the module tree consistent. Parameters are left on CPU — the HF - # Trainer will move them during init. 
- if torch.cuda.is_available(): - _target_dev = torch.device("cuda", 0) - for name, buf in list(model.named_buffers()): - if buf.device.type == "cpu": - parts = name.split(".") - mod = model - for p in parts[:-1]: - mod = getattr(mod, p) - setattr(mod, parts[-1], buf.to(_target_dev)) - print_rank_0("Loading dataset...") is_dflash = isinstance(recipe, ModelOptDFlashRecipe) data_module = make_speculative_data_module( @@ -289,6 +308,26 @@ def train(): **data_module, ) + if os.environ.get("PATCH_FSDP2_BUFFERS_TF457") == "1": + fsdp2_buffer_patch.patch_accelerator(trainer.accelerator) + + # DFlash: export the draft submodule after each checkpoint save — but only under FSDP2 + # SHARDED_STATE_DICT, where checkpoints are distributed shards the post-training + # export_hf_checkpoint.py pass can't read. Gate by reading the live FSDP state dict + # type off the accelerator; full-state-dict runs (DDP, single-device, FSDP2 + # FULL_STATE_DICT) use the launcher's post-run export instead. + if isinstance(recipe, ModelOptDFlashRecipe): + fsdp_plugin = getattr(trainer.accelerator.state, "fsdp_plugin", None) + sd_type = str(getattr(fsdp_plugin, "state_dict_type", "") or "") + if "SHARDED_STATE_DICT" in sd_type: + trainer.add_callback(DFlashFSDP2ShardedSDExportCallback()) + print_rank_0("DFlash: FSDP2 SHARDED_STATE_DICT — enabling per-save draft export.") + else: + print_rank_0( + f"DFlash: checkpoints use {sd_type or 'a full state dict'}; relying on the " + "launcher's post-run export (no per-save export callback added)." + ) + # Manually enable this to return loss in eval trainer.can_return_loss = True # Make sure label_smoother is None @@ -296,6 +335,9 @@ def train(): "label_smoother is not supported in speculative decoding!" ) + # Diagnostic (no-op unless DFLASH_LOG_PARAM_DTYPES=1): verifies FSDP2 dtype sync. 
+ fsdp2_buffer_patch.log_param_dtypes(trainer.model) + print_rank_0("Start training...") trainer.train(resume_from_checkpoint=checkpoint) trainer.save_state() diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 54d6e493c25..95b2de864f4 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -376,11 +376,15 @@ def _export_config(self): "initializer_range": getattr(base_config, "initializer_range", 0.02), "attention_bias": getattr(draft_config, "attention_bias", False), "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), - "rope_theta": getattr( - draft_config, "rope_theta", getattr(base_config, "rope_theta", 1000000.0) + # Inherit the target's rope_theta: DFlash injects the target's KV into every + # draft layer, so the draft's RoPE base must match the target's. (The draft + # arch config carries no rope_theta of its own.) + "rope_theta": ( + getattr(base_config, "rope_theta", None) + if getattr(base_config, "rope_theta", None) is not None + else getattr(draft_config, "rope_theta", 1000000.0) ), - # DFlash draft uses standard Qwen3 RoPE, not M-RoPE from multimodal models. - # z-lab uses null; vLLM handles null rope_scaling correctly. + # YaRN long-context scaling is injected below (see the rope_scaling block). "rope_scaling": None, "tie_word_embeddings": False, "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace( @@ -395,6 +399,12 @@ def _export_config(self): else: config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers + # Inject the export-time YaRN rope_scaling from the dflash_export_rope_scaling + # config field (empty dict disables). Mirrors eagle's eagle_export_rope_scaling. 
+ export_rope_scaling = getattr(self.model, "dflash_export_rope_scaling", None) + if export_rope_scaling: + config["rope_scaling"] = export_rope_scaling + return config def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 7649b2d0357..8da7b6ec93e 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -132,6 +132,20 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) + dflash_export_rope_scaling: dict = ModeloptField( + default={}, + description=( + "The rope_scaling config to inject into the exported HuggingFace draft config. " + "The DFlash draft trains on a short window but must draft for the target at long " + "context, so — mirroring published Eagle3 drafts such as nvidia/Kimi-K2.6-Eagle3 — " + "a YaRN rope_scaling is injected at export to extend the training window to the " + "target's full context. Example: " + '{"type": "yarn", "factor": 48.0, "original_max_position_embeddings": 4096, ' + '"beta_fast": 1.0, "beta_slow": 1.0, "mscale": 1.0, "mscale_all_dim": 1.0}. ' + "Set to empty dict {} (default) to disable rope scaling injection at export." 
+ ), + ) + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index a99e93c816a..702d9812482 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -35,3 +35,4 @@ def modify(self, config): self.dflash_num_anchors = config.dflash_num_anchors self.dflash_report_acc = config.dflash_report_acc self.dflash_use_torch_compile = config.dflash_use_torch_compile + self.dflash_export_rope_scaling = config.dflash_export_rope_scaling diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 1760cb2072d..9e54ee531ce 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -132,26 +132,45 @@ def modify(self, config): self.dflash_config.hidden_size = base_config.hidden_size self.dflash_config.vocab_size = base_config.vocab_size - # Inherit architecture settings from base model when not specified by user. - # Static defaults (hidden_act, attention_bias, etc.) are in dflash/default_config.py. - # NOTE: rope_scaling is intentionally excluded. DFlash draft uses Qwen3 - # RotaryEmbedding which only supports standard RoPE. Inheriting M-RoPE - # config from multimodal models (e.g. Qwen3.5) would be incorrect. - _base_model_attrs = [ + # Inherit architecture settings from base model when not specified by user + # (setdefault). Static defaults (hidden_act, attention_bias, etc.) are in + # dflash/default_config.py. 
+ _setdefault_attrs = [ "max_position_embeddings", "intermediate_size", "num_attention_heads", "num_key_value_heads", - "rope_theta", - "rope_type", - "rope_interleaved", "rms_norm_eps", ] - for attr in _base_model_attrs: + for attr in _setdefault_attrs: if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: if hasattr(base_config, attr): setattr(self.dflash_config, attr, getattr(base_config, attr)) + # RoPE base settings are ENFORCED to match the base model (not setdefault): the + # DFlash draft injects the target's KV into every layer, so its RoPE base must + # match the target's for the injected positions to align — and the exporter writes + # the base model's rope_theta. Letting dflash_architecture_config override these + # would make training (draft rope) and inference (base rope) disagree, so we + # overwrite any user value and warn. (rope_scaling is intentionally NOT inherited: + # DFlash uses standard Qwen3 RotaryEmbedding; the long-context YaRN scaling is + # added only at export via dflash_export_rope_scaling.) 
+ for attr in ("rope_theta", "rope_type", "rope_interleaved"): + if not hasattr(base_config, attr): + continue + base_val = getattr(base_config, attr) + user_val = getattr(self.dflash_config, attr, None) + if user_val is not None and user_val != base_val: + logger.warning( + "DFlash: ignoring dflash_architecture_config.%s=%r and enforcing the " + "base model's value %r — the draft injects the target's KV, so its RoPE " + "base must match the target's.", + attr, + user_val, + base_val, + ) + setattr(self.dflash_config, attr, base_val) + self.dflash_config.head_dim = getattr( self.dflash_config, "head_dim", diff --git a/tests/unit/torch/export/test_hf_spec_rope_export.py b/tests/unit/torch/export/test_hf_spec_rope_export.py index 171fe6263c6..720082bc617 100644 --- a/tests/unit/torch/export/test_hf_spec_rope_export.py +++ b/tests/unit/torch/export/test_hf_spec_rope_export.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for EAGLE export rope scaling logic in hf_spec_export.py.""" +"""Unit tests for EAGLE/DFlash export rope scaling logic in hf_spec_export.py.""" +from types import SimpleNamespace from unittest.mock import MagicMock -from modelopt.torch.export.plugins.hf_spec_export import EagleExporter +import torch + +from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter, EagleExporter DEFAULT_ROPE_SCALING = { "rope_type": "yarn", @@ -76,3 +79,63 @@ def test_rope_theta_fallback_from_rope_scaling(): """rope_theta is populated from rope_scaling when not available as top-level attr.""" config = _make_exporter(rope_type="default", rope_theta=500000)._export_config() assert config["rope_theta"] == 500000 + + +# --------------------------------------------------------------------------- +# DFlash export rope scaling (config-field convergence, mirrors the eagle style) +# --------------------------------------------------------------------------- + +DFLASH_YARN = { + "type": "yarn", + "factor": 48.0, + "original_max_position_embeddings": 4096, + "beta_fast": 1.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, +} + + +def _make_dflash_exporter(dflash_export_rope_scaling=None, base_rope_theta=5000000.0): + base_config = SimpleNamespace( + hidden_size=128, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=256, + vocab_size=1000, + max_position_embeddings=196608, + initializer_range=0.02, + num_hidden_layers=8, + rope_theta=base_rope_theta, + torch_dtype=torch.bfloat16, + ) + draft_config = SimpleNamespace(num_hidden_layers=2) + model = SimpleNamespace( + config=base_config, + dflash_config=draft_config, + dflash_block_size=8, + mask_token_id=999, + target_layer_ids=[1, 3, 5, 7], + dflash_export_rope_scaling=dflash_export_rope_scaling, + ) + exporter = DFlashExporter.__new__(DFlashExporter) + exporter.model = model + return exporter + + +def test_dflash_yarn_rope_injected_from_config_field(): + """YaRN rope_scaling from 
dflash_export_rope_scaling is injected verbatim.""" + config = _make_dflash_exporter(dflash_export_rope_scaling=DFLASH_YARN)._export_config() + assert config["rope_scaling"] == DFLASH_YARN + + +def test_dflash_rope_not_injected_when_field_empty(): + """Empty dict (default) disables rope scaling injection.""" + config = _make_dflash_exporter(dflash_export_rope_scaling={})._export_config() + assert config["rope_scaling"] is None + + +def test_dflash_rope_theta_inherits_base(): + """rope_theta is inherited from the target/base config (draft drafts for the base).""" + config = _make_dflash_exporter(base_rope_theta=5000000.0)._export_config() + assert config["rope_theta"] == 5000000.0 diff --git a/tools/launcher/common/specdec/dflash_online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh index 60c2b9901e1..efc38be3b17 100644 --- a/tools/launcher/common/specdec/dflash_online_training.sh +++ b/tools/launcher/common/specdec/dflash_online_training.sh @@ -42,6 +42,13 @@ pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirement pip install huggingface-hub>=1.2.1 export PATH=$PATH:/workspace/.local/bin +# Some trust_remote_code MoE models pin an older transformers (e.g. MiniMax-M2.7 +# needs 4.57.x; its modeling code is incompatible with the 5.x the recipe defaults +# pull in). Override here when the model needs it. +if [ -n "${OVERRIDE_TRANSFORMERS:-}" ]; then + pip install "transformers==${OVERRIDE_TRANSFORMERS}" +fi + ################################################################################################### trap 'error_handler $0 $LINENO' ERR @@ -99,9 +106,18 @@ fi export TOKENIZERS_PARALLELISM=False +# ACCELERATE_CONFIG selects an explicit accelerate config file (e.g. an FSDP2 YAML). +# Required for trust_remote_code MoE models that need FSDP2 via accelerate config +# rather than transformers-native ParallelismConfig (MiniMax-M2.7 on transformers 4.57.x). 
+ACCEL_CONFIG_ARGS="" +if [ -n "${ACCELERATE_CONFIG:-}" ] && [ -f "${ACCELERATE_CONFIG}" ]; then + ACCEL_CONFIG_ARGS="--config_file ${ACCELERATE_CONFIG}" + echo "Using accelerate config: ${ACCELERATE_CONFIG}" +fi + set -x start_time=$(date +%s) -accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@" +accelerate launch $ACCEL_CONFIG_ARGS --mixed_precision "${MIXED_PRECISION:-bf16}" $MULTI_NODE_ARGS $MAIN_PY "$@" echo "Training time: $(( $(date +%s) - start_time )) seconds" set +x diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 7f12c368efe..17d6cb44138 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -316,8 +316,17 @@ def build_slurm_executor( retries=0, packager=packager, srun_args=slurm_config.srun_args, + # Copy into a fresh dict so the requeue mutation below doesn't leak back into + # the shared slurm_config.additional_parameters. + additional_parameters=dict(getattr(slurm_config, "additional_parameters", None) or {}), **optional_kwargs, ) + if getattr(slurm_config, "requeue", False): + executor.additional_parameters["requeue"] = True + # The nemo-run sbatch wrapper only calls `scontrol requeue` when + # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default) + # disables this, so bump it when requeue is requested. 
+ executor.retries = max(executor.retries, 3) return executor diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml new file mode 100644 index 00000000000..3310d1e5c7e --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml @@ -0,0 +1,11 @@ +compute_environment: LOCAL_MACHINE +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule + fsdp_sharding_strategy: HYBRID_SHARD + fsdp_use_orig_params: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_cpu_ram_efficient_loading: true diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja new file mode 100644 index 00000000000..178dc002069 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja @@ -0,0 +1,162 @@ +{# MiniMax-M2.7 chat template with {% generation %} tags for answer_only_loss training. + Adapted from https://huggingface.co/MiniMaxAI/MiniMax-M2.7/blob/main/chat_template.jinja + with {% generation %} / {% endgeneration %} wrapping assistant content. 
+#} +{%- set toolcall_begin_token = '' -%} +{%- set toolcall_end_token = '' -%} +{#- Tool Rendering Functions ============================================== -#} +{%- macro render_tool_namespace(namespace_name, tool_list) -%} +{%- for tool in tool_list -%} +{{ tool.function | tojson(ensure_ascii=False) }} +{% endfor -%} +{%- endmacro -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{ content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{#- System Message Construction ============================================ -#} +{%- macro build_system_message(system_message) -%} + {%- if system_message and system_message.content -%} + {{- visible_text(system_message.content) }} + {%- else -%} + {%- if model_identity is not defined -%} + {%- set model_identity = "You are a helpful assistant. Your name is MiniMax-M2.7 and is built by MiniMax." 
-%} + {%- endif -%} + {{- model_identity }} + {%- endif -%} + + {#- Handle current_date -#} + {%- if system_message and system_message.current_date -%} + {{- '\n' ~ 'Current date: ' + system_message.current_date }} + {%- endif -%} + {#- Handle current_location -#} + {%- if system_message and system_message.current_location -%} + {{- '\n' ~ 'Current location: ' + system_message.current_location }} + {%- endif -%} +{%- endmacro -%} +{#- Main Template Logic ================================================= -#} +{#- Extract system message (only first message if it's system) -#} +{%- set system_message = none -%} +{%- set conversation_messages = messages -%} +{%- if messages and messages[0].role == "system" -%} + {%- set system_message = messages[0] -%} + {%- set conversation_messages = messages[1:] -%} +{%- endif -%} +{#- Get the last user message turn, for interleaved thinking -#} +{%- set ns = namespace(last_user_index=-1) %} +{% for m in conversation_messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{#- Render system message -#} +{{- ']~!b[' ~ ']~b]system' ~ '\n' }} +{{- build_system_message(system_message) }} +{#- Render tools if available -#} +{%- if tools -%} + {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }} + {{- '\n' ~ '' ~ '\n' }} + {{- render_tool_namespace("functions", tools) }} + {{- '' ~ '\n\n' }} +{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }} +{{- '\n' ~ toolcall_begin_token }} + +param-value-1 +param-value-2 +... 
+ +{{- '\n' ~ toolcall_end_token }} +{%- endif -%} +{{- '[e~[\n' }} + +{#- Render messages -#} +{%- set last_tool_call = namespace(name=none) -%} +{%- for message in conversation_messages -%} + {%- if message.role == 'assistant' -%} + {{- ']~b]ai' ~ '\n' }} + {%- generation -%} + {%- set reasoning_content = '' %} + {%- set content = visible_text(message.content) %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].strip('\n').split('')[-1].strip('\n') %} + {%- set content = content.split('')[-1].strip('\n') %} + {%- endif %} + {%- endif %} + {%- if reasoning_content and loop.index0 > ns.last_user_index -%} + {{- '' ~ '\n' ~ reasoning_content ~ '\n' ~ '' ~ '\n\n' }} + {%- endif -%} + {%- if content -%} + {{- content }} + {%- endif -%} + {%- if message.tool_calls -%} + {{- '\n' ~ toolcall_begin_token ~ '\n' }} + + {%- for tool_call in message.tool_calls -%} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' }} + {% set _args = tool_call.arguments %} + {%- for k, v in _args.items() %} + {{- '' }} + {{- v | tojson(ensure_ascii=False) if v is not string else v }} + {{- '' }} + {% endfor %} + {{- '' ~ '\n' }} + {%- endfor -%} + + {{- toolcall_end_token}} + {%- set last_tool_call.name = message.tool_calls[-1].name -%} + {%- else -%} + {%- set last_tool_call.name = none -%} + {%- endif -%} + {%- endgeneration -%} + {{- '[e~[' ~ '\n' }} + + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none -%} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif -%} + {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%} + {{- ']~b]tool' }} + {%- endif -%} + {%- if message.content is string -%} + {{- '\n' }} + {{- message.content }} + {{- '' }} + {%- else -%} + {%- for tr in 
message.content -%} + {{- '\n' }} + {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }} + {{- '\n' }} + {%- endfor -%} + {%- endif -%} + {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%} + {{- '[e~[\n' -}} + {%- endif -%} + + {%- elif message.role == 'user' -%} + {{- ']~b]user' ~ '\n' }} + {{- visible_text(message.content) }} + {{- '[e~[' ~ '\n' }} + {%- endif -%} +{%- endfor -%} + +{#- Generation prompt -#} +{%- if add_generation_prompt -%} +{{- ']~b]ai' ~ '\n' ~ '' ~ '\n' }} +{%- endif -%} diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml new file mode 100644 index 00000000000..dc3dba4a65a --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml @@ -0,0 +1,97 @@ +# DFlash offline speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 2-step pipeline (compare with hf_online_dflash.yaml, which streams the base model +# forward at training time instead): +# task_0: Dump base-model hidden states once via vLLM extract_hidden_states. +# task_1: Train the DFlash draft on the dump (FakeBaseModel — loads only lm_head + +# embed_tokens, not the full 229B base). +# +# We use the vLLM dump (compute_hidden_states_vllm.py) rather than the HF dump because +# the 229B MoE is impractical to forward on a single GPU; vLLM shards it with TP. The +# dump disables prefix caching so every token's hidden state is emitted and the dumped +# sequence lengths line up with input_ids/loss_mask. 
+# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_offline_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_offline +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Dump base-model hidden states via vLLM extract_hidden_states (TP=4). + task_0: + script: common/eagle3/dump_offline_data_vllm.sh + args: + - --input-data /hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - --output-dir /scratchspace/dflash_minimax_m2.7_hidden_states + # Must match the draft model's num_hidden_layers (recipe default: 5). + - --aux-layers dflash + - --answer-only-loss + - --chat-template examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - --max-seq-len 4096 + - --tp 4 + environment: + - HF_MODEL_CKPT: <> + - TRUST_REMOTE_CODE: "1" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: Train DFlash offline on the dumped hidden states. FakeBaseModel avoids + # loading the full 229B — only lm_head + embed_tokens are read from the checkpoint. 
+ task_1: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - model.use_fake_base_for_offline=true + - data.mode=offline + - data.offline_data_path=/scratchspace/dflash_minimax_m2.7_hidden_states + - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - training.output_dir=/scratchspace/dflash_minimax_m2.7_offline + - training.num_train_epochs=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + # Mask token id (an existing reserved row in MiniMax-M2.7's embedding). + - dflash.dflash_mask_token_id=200054 + # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48). + - dflash.dflash_export_rope_scaling.type=yarn + - dflash.dflash_export_rope_scaling.factor=48.0 + - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 + - dflash.dflash_export_rope_scaling.beta_fast=1.0 + - dflash.dflash_export_rope_scaling.beta_slow=1.0 + - dflash.dflash_export_rope_scaling.mscale=1.0 + - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 + environment: + - NUM_NODES: "8" + # Offline training uses a lightweight FakeBaseModel, so plain DDP suffices (no + # ACCELERATE_CONFIG / FSDP2 patches). OVERRIDE_TRANSFORMERS pins 4.57.1 for the + # MiniMax config load. 
+ - OVERRIDE_TRANSFORMERS: "4.57.1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml new file mode 100644 index 00000000000..92c450ee550 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml @@ -0,0 +1,108 @@ +# DFlash online speculative decoding training for MiniMax-M2.7 (229B MoE). +# +# 3-step pipeline: +# task_0: Online DFlash training (8 nodes x 8 GPU, FSDP2 HYBRID_SHARD) +# task_1: vLLM smoke test with the exported DFlash draft +# task_2: HF AR (acceptance-length) evaluation on MT-Bench (1 GPU) +# +# MiniMax-M2.7 specifics: +# - trust_remote_code model whose code requires transformers 4.57.x, so FSDP2 is +# configured via an accelerate config (accelerate_fsdp2_hybrid.yaml) rather than +# transformers-native ParallelismConfig. Hence OVERRIDE_TRANSFORMERS + ACCELERATE_CONFIG +# + dp_shard_size=1 (keeps main.py from building a ParallelismConfig) + PATCH_FSDP2_BUFFERS_TF457. +# - 229B in FP8 needs cpu_ram_efficient_loading (set in the accelerate config) and +# MIXED_PRECISION=no (the recipe / checkpoint already carry the right dtypes). +# - The DFlash draft is Qwen3-architecture (5 layers); the target/draft architectures +# are independent by design. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/hf_online_dflash.yaml --yes + +job_name: MiniMax-M2.7-DFlash_online +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: Online DFlash training. 
+ task_0: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.trust_remote_code=true + - data.data_path=/hf-local/modelopt/MiniMax-M2.7-synthetic-data-clean-v2 + - data.chat_template=examples/MiniMax/MiniMax-M2.7-DFlash/chat_template_train.jinja + - training.output_dir=/scratchspace/dflash_minimax_m2.7 + - training.num_train_epochs=10 + - training.per_device_train_batch_size=2 + - training.learning_rate=1.2e-3 + - training.warmup_steps=100 + - training.training_seq_len=4096 + - training.logging_steps=100 + - training.save_steps=400 + - training.disable_tqdm=true + # dp_shard_size=1 stops main.py from creating a ParallelismConfig; the + # accelerate config below owns FSDP2 sharding. + - training.dp_shard_size=1 + - training.answer_only_loss=true + - training.ddp_timeout=3600 + - training.bf16=false + - dflash.dflash_self_logit_distillation=true + - dflash.dflash_block_size=8 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=4.0 + - dflash.dflash_architecture_config.num_hidden_layers=5 + # Mask token id. The draft reuses the target's embed_tokens, so this must be an id + # that exists in the target embedding; 200054 is a reserved row in MiniMax-M2.7. + - dflash.dflash_mask_token_id=200054 + # YaRN rope_scaling injected at export for long context (factor = 196608/4096 = 48). 
+ - dflash.dflash_export_rope_scaling.type=yarn + - dflash.dflash_export_rope_scaling.factor=48.0 + - dflash.dflash_export_rope_scaling.original_max_position_embeddings=4096 + - dflash.dflash_export_rope_scaling.beta_fast=1.0 + - dflash.dflash_export_rope_scaling.beta_slow=1.0 + - dflash.dflash_export_rope_scaling.mscale=1.0 + - dflash.dflash_export_rope_scaling.mscale_all_dim=1.0 + environment: + - NUM_NODES: "8" + - OVERRIDE_TRANSFORMERS: "4.57.1" + - ACCELERATE_CONFIG: examples/MiniMax/MiniMax-M2.7-DFlash/accelerate_fsdp2_hybrid.yaml + - PATCH_FSDP2_BUFFERS_TF457: "1" + - MIXED_PRECISION: "no" + slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: vLLM smoke test with the exported DFlash draft. + task_1: + script: common/specdec/vllm_smoke_test.sh + environment: + - HF_MODEL_CKPT: <> + - DRAFT_CKPT_DIR: /scratchspace/dflash_minimax_m2.7 + - SPEC_METHOD: "dflash" + - NUM_SPEC_TOKENS: "7" + - TP_SIZE: "4" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 3: HF AR (acceptance-length) evaluation on MT-Bench. + task_2: + script: common/specdec/ar_eval_mtbench.sh + args: + - --ckpt_dir /scratchspace/dflash_minimax_m2.7 + - --osl 512 + - --steps 7 + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 diff --git a/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml new file mode 100644 index 00000000000..ad538805fb6 --- /dev/null +++ b/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml @@ -0,0 +1,90 @@ +# SPEED-Bench run for MiniMax-M2.7 with the trained DFlash draft, via vLLM. 
+# +# Two-task pipeline: +# task_0 qualitative split (nvidia/SPEED-Bench-Internal/qualitative, 80 prompts) +# task_1 throughput_32k split (nvidia/SPEED-Bench-Internal/throughput_32k, long context) +# +# Both use --speculative_algorithm DFLASH with the exported draft. Acceptance length +# (AL) is the primary metric; the throughput_32k split checks that AL holds at long +# context (if it drops, add YaRN/rope_scaling to the draft export). +# +# Results are written to /scratchspace/minimax_m2.7_dflash_vllm/<split>/. The +# pensieve-intern specdec_bench workflow's wrap_up stage publishes these to +# s3://team-specdec-workgroup/results/minimax_m2.7_dflash_vllm/<split>/. +# +# Set --draft_model_dir to the trained, exported draft checkpoint (an exported-checkpoint-* +# dir containing model.safetensors), e.g. the output of hf_online_dflash.yaml. +# +# Usage: +# Edit the two --draft_model_dir args to point at your exported draft checkpoint, then: +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml --yes +# +# NOTE: the placeholder below is `exported-checkpoint-final` — that dir is written by the +# training launcher's post-run export (dflash_online_training.sh) for the final model. The +# per-save DFlashFSDP2ShardedSDExportCallback instead writes `exported-checkpoint-<step>` +# dirs; to bench a specific intermediate step, point at one of those. + +job_name: MiniMax-M2.7-DFlash_specdec_bench +pipeline: + global_vars: + hf_model: /hf-local/MiniMaxAI/MiniMax-M2.7 + + # Step 1: qualitative split — quality / acceptance-length numbers. 
+ task_0: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/qualitative + - --engine VLLM + - --speculative_algorithm DFLASH + - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --block_size 8 + - --tp_size 4 + - --ep_size 4 + - --concurrency 32 + - --output_length 4096 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/qualitative + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly + + # Step 2: throughput_32k split — long-context throughput / AL. + # --num_requests 80 caps the 1,536-sample split to fit the time limit; --max_seq_len + # 40960 = 32K input + 4K output + 4K headroom so vLLM doesn't auto-cap below 36K. + task_1: + script: common/specdec_bench/run.sh + args: + - --dataset speed + - --dataset_path /hf-local/nvidia/SPEED-Bench-Internal/throughput_32k + - --engine VLLM + - --speculative_algorithm DFLASH + - --draft_model_dir /scratchspace/dflash_minimax_m2.7/exported-checkpoint-final + - --block_size 8 + - --tp_size 4 + - --ep_size 4 + - --concurrency 8 + - --num_requests 80 + - --output_length 4096 + - --max_seq_len 40960 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/minimax_m2.7_dflash_vllm/throughput_32k + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: vllm/vllm-openai:nightly diff --git a/tools/launcher/tests/test_slurm_executor.py b/tools/launcher/tests/test_slurm_executor.py index 900616136e3..6baec571e3b 100644 --- a/tools/launcher/tests/test_slurm_executor.py +++ b/tools/launcher/tests/test_slurm_executor.py @@ -34,6 +34,7 @@ def test_scratch_and_modelopt_mounts(self, mock_tunnel, 
mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="test-host", port=22, account="test_account", @@ -77,6 +78,7 @@ def test_scratch_path_uses_experiment_title(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="host", port=22, account="acct", @@ -112,6 +114,7 @@ def test_tunnel_created_with_correct_params(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="login.cluster.com", port=30022, account="acct", @@ -150,6 +153,7 @@ def test_executor_params(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="h", port=22, account="my_acct", @@ -195,6 +199,7 @@ def test_none_container_mounts_handled(self, mock_tunnel, mock_executor): mock_tunnel.return_value = MagicMock() slurm_config = MagicMock( + requeue=False, host="h", port=22, account="a", @@ -222,3 +227,42 @@ def test_none_container_mounts_handled(self, mock_tunnel, mock_executor): # Should not crash; mounts should still include scratch + modelopt + title mounts = mock_executor.call_args[1]["container_mounts"] assert len(mounts) >= 3 + + @patch("core.run.SlurmExecutor") + @patch("core.run.SSHTunnel") + def test_requeue_sets_param_and_bumps_retries(self, mock_tunnel, mock_executor): + mock_tunnel.return_value = MagicMock() + executor = mock_executor.return_value + executor.retries = 0 + executor.additional_parameters = {} + + slurm_config = MagicMock( + requeue=True, + host="h", + port=22, + account="a", + partition="b", + container="c", + modelopt_install_path="/m", + container_mounts=[], + srun_args=[], + nodes=1, + ntasks_per_node=1, + gpus_per_node=1, + array=None, + ) + + build_slurm_executor( + user="u", + identity=None, + slurm_config=slurm_config, + experiment_id="e", + job_dir="/j", + task_name="t", + packager=MagicMock(), + ) + + # requeue=True flags 
the additional parameter and bumps retries above 0 so + # nemo-run's sbatch wrapper actually issues `scontrol requeue` on preemption. + assert executor.additional_parameters["requeue"] is True + assert executor.retries == 3