-
Notifications
You must be signed in to change notification settings - Fork 444
DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export #1621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c23d002
4a83b91
a26022e
b43c7f9
967a2c0
ca89c84
7a21536
207f257
f63a387
b9baede
79634e9
b54c862
e00aaba
9dbd693
1316760
45f3a1e
035bad7
645cd5b
426fadf
bc97213
c7fec79
4df2d60
e82ef56
e826afb
5846657
4912379
01778c8
786b718
ff7dae6
8281422
a8c67e8
69ad872
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,10 +25,19 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| import atexit | ||
| import os | ||
| import shutil | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
| from common import add_aux_layers_args, resolve_aux_layers | ||
| from common import ( | ||
| add_answer_only_loss_args, | ||
| add_aux_layers_args, | ||
| load_chat_template, | ||
| tokenize_with_loss_mask, | ||
| verify_generation_tags, | ||
| ) | ||
| from datasets import load_dataset | ||
| from tqdm import tqdm | ||
| from transformers import AutoConfig, AutoTokenizer | ||
|
|
@@ -38,6 +47,48 @@ | |
| ) | ||
|
|
||
|
|
||
| def _resolve_aux_layers_standalone( | ||
| aux_layers: str, num_hidden_layers: int, num_draft: int = 5 | ||
| ) -> list[int]: | ||
| """Resolve aux-layer ids without importing modelopt. | ||
|
|
||
| This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the | ||
| 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which | ||
| pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM | ||
| container does not have, so the import fails. Resolve the 'dflash' preset inline | ||
| (mirroring ``modeling_dflash.build_target_layer_ids`` for ``num_draft`` draft layers) | ||
| and accept an explicit comma-separated int list. ``num_draft`` MUST match the recipe's | ||
| ``dflash.dflash_architecture_config.num_hidden_layers`` (pass --num-draft-layers) or the | ||
| dumped aux layers silently mis-align with what the draft consumes at training time. | ||
| Keep in sync with modelopt. | ||
|
|
||
| TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy | ||
| ``modelopt.torch`` import chain so it can be reused directly in a vLLM container. | ||
| """ | ||
| spec = aux_layers.strip().lower() | ||
| if spec == "dflash": | ||
| if num_draft == 1: | ||
| return [num_hidden_layers // 2] | ||
| start = min(1, num_hidden_layers - 1) | ||
| end = max(start, num_hidden_layers - 3) | ||
| span = end - start | ||
| return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) | ||
|
Comment on lines
+69
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [IMPORTANT Algorithm] The standalone resolver hard-codes
Two fixes either of which would close the gap:
The TODO already acknowledges this is fragile; the comment-coupling to a hard-coded |
||
| ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) | ||
| # Match the shared helper's contract: ids must be valid layer indices. | ||
| out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers] | ||
| if out_of_range: | ||
| raise ValueError( | ||
| f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers}) " | ||
| f"for a {num_hidden_layers}-layer model." | ||
| ) | ||
| if not ids: | ||
| raise ValueError( | ||
| f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " | ||
| "'dflash' preset or an explicit comma-separated layer-id list are supported." | ||
| ) | ||
| return ids | ||
|
|
||
|
|
||
| def parse_args() -> argparse.Namespace: | ||
| parser = argparse.ArgumentParser( | ||
| description="""Collect hidden states from conversations using vLLM's native extractor.""" | ||
|
|
@@ -63,6 +114,14 @@ def parse_args() -> argparse.Namespace: | |
| "--debug-max-num-conversations", type=int, default=None, help="Limit conversations." | ||
| ) | ||
| add_aux_layers_args(parser) | ||
| parser.add_argument( | ||
| "--num-draft-layers", | ||
| type=int, | ||
| default=5, | ||
| help="DFlash draft depth, for resolving the 'dflash' --aux-layers preset. MUST match " | ||
| "the recipe's dflash.dflash_architecture_config.num_hidden_layers (default: 5).", | ||
| ) | ||
| add_answer_only_loss_args(parser) | ||
|
|
||
| return parser.parse_args() | ||
|
|
||
|
|
@@ -112,7 +171,9 @@ def keep_conversation(entry): | |
| num_hidden_layers = getattr(config, "num_hidden_layers", None) | ||
| if num_hidden_layers is None: | ||
| raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}") | ||
| aux_layer_ids = resolve_aux_layers(args, num_hidden_layers) | ||
| aux_layer_ids = _resolve_aux_layers_standalone( | ||
| args.aux_layers, num_hidden_layers, num_draft=args.num_draft_layers | ||
| ) | ||
| # The trailing entry is the final output hidden state; the rest are aux layers. | ||
| extract_layer_ids = [*aux_layer_ids, num_hidden_layers] | ||
| print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)") | ||
|
|
@@ -121,34 +182,37 @@ def keep_conversation(entry): | |
| tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) | ||
| if tokenizer.pad_token is None: | ||
| tokenizer.pad_token = tokenizer.eos_token | ||
| override_template = load_chat_template(args.chat_template) | ||
| if override_template is not None: | ||
| tokenizer.chat_template = override_template | ||
| if tokenizer.chat_template is not None: | ||
| tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") | ||
| if args.answer_only_loss: | ||
| verify_generation_tags(tokenizer.chat_template) | ||
|
|
||
| # Prepare prompts for vLLM | ||
| prompts = [] | ||
| conversation_ids = [] | ||
| loss_masks = [] | ||
| num_skipped_too_long = 0 | ||
| num_invalid = 0 | ||
|
|
||
| for entry in dataset: | ||
| conversation_id = entry.get("conversation_id", entry.get("uuid")) | ||
| conversations = entry["conversations"] | ||
| # Accept either the "conversations" or OpenAI-style "messages" key (the | ||
| # MiniMax synthetic data uses "messages"). | ||
| conversations = entry.get("conversations") or entry.get("messages") | ||
| if not conversations or not isinstance(conversations, list): | ||
| num_invalid += 1 | ||
| continue | ||
|
|
||
| tokenized = tokenizer.apply_chat_template( | ||
| conversations, return_tensors="pt", add_generation_prompt=False | ||
| # One apply_chat_template call yields aligned input_ids + loss_mask. With | ||
| # --answer-only-loss the mask comes from the template's {% generation %} tags; | ||
| # otherwise it is all-ones. Same tokens are sent to vLLM, so the dumped hidden | ||
| # states line up with this loss_mask 1:1 (prefix caching is disabled below). | ||
| input_ids, loss_mask = tokenize_with_loss_mask( | ||
| tokenizer, conversations, args.answer_only_loss | ||
| ) | ||
| # transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids | ||
| if hasattr(tokenized, "input_ids"): | ||
| input_ids = tokenized.input_ids | ||
| elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized: | ||
| input_ids = tokenized["input_ids"] | ||
| else: | ||
| input_ids = tokenized | ||
| if not hasattr(input_ids, "shape"): | ||
| input_ids = torch.tensor(input_ids) | ||
| input_ids = input_ids.squeeze(0) | ||
| num_tokens = input_ids.shape[0] | ||
| if num_tokens <= 10 or num_tokens > args.max_seq_len: | ||
|
|
@@ -157,6 +221,7 @@ def keep_conversation(entry): | |
|
|
||
| prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist())) | ||
| conversation_ids.append(conversation_id) | ||
| loss_masks.append(loss_mask) | ||
|
|
||
| print( | ||
| f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)" | ||
|
|
@@ -168,15 +233,28 @@ def keep_conversation(entry): | |
|
|
||
| # Initialize vLLM with the native hidden-state extractor. | ||
| tp = args.tp if args.tp is not None else torch.cuda.device_count() | ||
| storage_path = output_dir / ".vllm_hidden_states" | ||
| # Stage the connector's intermediate safetensors on local tmpfs, not the (lustre) | ||
| # output dir: the producer writes one file per request and the client reads it back | ||
| # immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so | ||
| # parallel shards don't collide. Overridable via DFLASH_HS_STAGING_DIR for containers | ||
| # where /dev/shm is unmapped or undersized; cleaned up on exit so a crash doesn't strand | ||
| # RAM-backed files until the node reboots. | ||
| staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm") | ||
| storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}" | ||
| storage_path.mkdir(parents=True, exist_ok=True) | ||
| atexit.register(shutil.rmtree, storage_path, ignore_errors=True) | ||
|
|
||
| llm = LLM( | ||
| model=args.model, | ||
| tensor_parallel_size=tp, | ||
| max_model_len=args.max_seq_len, | ||
| trust_remote_code=args.trust_remote_code, | ||
| enable_chunked_prefill=False, # required by extract_hidden_states | ||
| # With prefix caching on, vLLM serves shared prefixes from cache in block-sized | ||
| # chunks and the hidden-state connector only emits the freshly-computed suffix, so | ||
| # the dumped hidden_states come out short by N*block_size vs the full input_ids / | ||
| # loss_mask. Disabling it forces a full prefill so every token's state is dumped. | ||
| enable_prefix_caching=False, | ||
| speculative_config={ | ||
| "method": "extract_hidden_states", | ||
| "num_speculative_tokens": 1, | ||
|
|
@@ -189,18 +267,22 @@ def keep_conversation(entry): | |
| kv_role="kv_producer", | ||
| kv_connector_extra_config={ | ||
| "shared_storage_path": str(storage_path), | ||
| "use_synchronization_lock": False, # batch generation, no concurrent readers | ||
| # The client reads each request's safetensors right after generation; the | ||
| # lock makes the producer signal completion so the reader doesn't race the | ||
| # writer (without it the reader looks for a .lock the producer never wrote). | ||
| "use_synchronization_lock": True, | ||
| }, | ||
| ), | ||
| ) | ||
|
|
||
| # max_tokens=1: we only need a single forward pass over the prompt tokens. | ||
| outputs = llm.generate(prompts, SamplingParams(max_tokens=1)) | ||
|
|
||
| # Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the | ||
| # vLLM path does not compute). | ||
| # Save in the same format as compute_hidden_states_hf.py, including loss_mask. | ||
| num_success = 0 | ||
| for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"): | ||
| for conv_id, loss_mask, output in tqdm( | ||
| zip(conversation_ids, loss_masks, outputs), total=len(outputs), desc="Saving" | ||
| ): | ||
| hidden_states_path = output.kv_transfer_params.get("hidden_states_path") | ||
| if hidden_states_path is None: | ||
| print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping") | ||
|
|
@@ -220,13 +302,24 @@ def keep_conversation(entry): | |
| else: | ||
| aux_hidden_states = torch.empty(0) | ||
|
|
||
| # loss_mask is sliced to the dumped length below; a shorter loss_mask would slice | ||
| # to itself and silently misalign with the hidden states, so guard explicitly. | ||
| n_hs = output_hidden_states.shape[0] | ||
| if loss_mask.shape[0] < n_hs: | ||
| print( | ||
| f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden " | ||
| f"states ({n_hs}); skipping to avoid misalignment" | ||
| ) | ||
| continue | ||
|
|
||
| output_file = output_dir / f"{conv_id}.pt" | ||
| with open(output_file, "wb") as f: | ||
| torch.save( | ||
| { | ||
| "input_ids": token_ids.cpu(), | ||
| "hidden_states": output_hidden_states, | ||
| "aux_hidden_states": aux_hidden_states, | ||
| "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(), | ||
| "conversation_id": conv_id, | ||
| }, | ||
| f, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,6 +230,110 @@ def on_step_begin(self, args, state, control, **kwargs): | |
| return control | ||
|
|
||
|
|
||
| class DFlashFSDP2ShardedSDExportCallback(TrainerCallback): | ||
| """Export the DFlash draft module after each checkpoint save, for FSDP2 sharded runs. | ||
|
|
||
| Applicable range: this is needed only under FSDP2 ``SHARDED_STATE_DICT``, where the | ||
| checkpoint holds distributed shards (``pytorch_model_fsdp_0/``) and no consolidated | ||
| ``model.safetensors`` — so the post-training ``export_hf_checkpoint.py`` pass can't read | ||
| it. It gathers just the small draft submodule and writes the deployable export. | ||
|
|
||
| Gating is the caller's responsibility: ``main.py`` adds this callback only when the | ||
| accelerator's FSDP state dict type is ``SHARDED_STATE_DICT`` (full-state-dict runs — | ||
| DDP, single-device, FSDP2 FULL_STATE_DICT — use the launcher's post-run export instead). | ||
| """ | ||
|
|
||
| def on_save(self, args, state, control, **kwargs): | ||
| """Export DFlash draft module weights + config after checkpoint save.""" | ||
| import json | ||
| import os | ||
|
|
||
| from safetensors.torch import save_file | ||
|
|
||
| model = kwargs["model"] | ||
| if not hasattr(model, "dflash_module"): | ||
| return control | ||
|
|
||
| step = state.global_step | ||
| export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}") | ||
|
|
||
| # All ranks participate in the state_dict gather (FSDP2 collective op). Only the | ||
| # dflash_module submodule is gathered (~328 MB for MiniMax-M2.7), not the 229B base. | ||
| try: | ||
| from torch.distributed.checkpoint.state_dict import ( | ||
| StateDictOptions, | ||
| get_model_state_dict, | ||
| ) | ||
|
|
||
| options = StateDictOptions(full_state_dict=True, cpu_offload=True) | ||
| try: | ||
| raw_sd = get_model_state_dict( | ||
| model, submodules={model.dflash_module}, options=options | ||
| ) | ||
| except TypeError: | ||
| # Older PyTorch without the submodules= parameter: this gathers the FULL | ||
| # model (the entire base, e.g. ~229B for MiniMax-M2.7), defeating the | ||
| # submodule-only design and risking OOM. Warn loudly — upgrade PyTorch. | ||
| print_rank_0( | ||
| "WARNING: DFlash export: get_model_state_dict lacks submodules= on this " | ||
| "PyTorch — gathering the FULL base model (slow / may OOM). Upgrade PyTorch " | ||
| "for the submodule-only gather." | ||
| ) | ||
| raw_sd = get_model_state_dict(model, options=options) | ||
| except ImportError: | ||
| # Non-distributed / single-GPU fallback | ||
| raw_sd = model.state_dict() | ||
|
|
||
| # Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary | ||
| # buffers) for the common full-model key layout. Some PyTorch versions return the | ||
| # submodule gather with keys already stripped of the prefix — handle that directly. | ||
| exporter = model.get_exporter() | ||
| drafter_sd = exporter._extract_state_dict(raw_sd) | ||
| if not drafter_sd: | ||
| # Fallback for the already-stripped-key layout: a denylist heuristic rather than | ||
| # the prefix-based extractor. Warn, since a key-naming change upstream could let | ||
| # a malformed draft ship and only fail at vLLM load time. | ||
| print_rank_0( | ||
| "WARNING: DFlash export: prefix-based extraction found no dflash_module keys; " | ||
| "falling back to a denylist heuristic on already-stripped keys. Verify the " | ||
| "exported draft loads in vLLM." | ||
| ) | ||
| drafter_sd = { | ||
| k: v | ||
| for k, v in raw_sd.items() | ||
|
Comment on lines
+269
to
+303
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [IMPORTANT Performance/ModeState] The silent full-model fallback when The whole point of this callback (per the docstring and PR description) is to gather only the ~328 MB DFlash submodule under FSDP2 SHARDED_STATE_DICT, without materializing the 229B base. But this fallback path: except TypeError:
# Older PyTorch without submodules parameter — gather full model
raw_sd = get_model_state_dict(model, options=options)silently calls Suggestion: either (a) raise a clear error pointing the user at a minimum supported PyTorch version, or (b) at minimum |
||
| if "rotary_emb" not in k | ||
| and not any(p in k for p in ("model.", "lm_head.", "embed_tokens.")) | ||
| } | ||
| del raw_sd | ||
| # Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the | ||
| # single-GPU fallback may return CUDA tensors). | ||
| drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()} | ||
|
Comment on lines
+292
to
+310
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [IMPORTANT Algorithm/Export] The fallback key heuristic doesn't match what
But the heuristic here: drafter_sd = {
k: v
for k, v in raw_sd.items()
if "rotary_emb" not in k
and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}filters out anything containing Recommendation: instead of the substring heuristic, detect the pre-stripped case explicitly and either (a) re-prefix the keys with |
||
|
|
||
| if not drafter_sd: | ||
| print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export") | ||
| return control | ||
|
|
||
| # Only rank 0 writes files | ||
| if is_master(): | ||
| try: | ||
| os.makedirs(export_dir, exist_ok=True) | ||
| save_file(drafter_sd, os.path.join(export_dir, "model.safetensors")) | ||
|
|
||
| config = exporter._export_config() | ||
| with open(os.path.join(export_dir, "config.json"), "w") as f: | ||
| json.dump(config, f, indent=2) | ||
|
|
||
| total_mb = sum(v.nbytes for v in drafter_sd.values()) / 1024 / 1024 | ||
| print_rank_0( | ||
| f"Exported DFlash draft ({len(drafter_sd)} tensors, {total_mb:.0f}MB) " | ||
| f"to {export_dir}" | ||
| ) | ||
| except Exception as e: | ||
| print_rank_0(f"Warning: DFlash export failed at step {step}: {e}") | ||
|
|
||
| return control | ||
|
|
||
|
|
||
| class EagleTrainingPlot(TrainerCallback): | ||
| """Callback that plot training acc and AR during training.""" | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.