Merged
32 commits
c23d002
fix: auto-add mask token for DFlash when tokenizer lacks one
yeyu-nvidia Apr 17, 2026
4a83b91
feat: add requeue support to build_slurm_executor
yeyu-nvidia Apr 17, 2026
a26022e
feat: FSDP2 efficient loading + requeue retries fix
yeyu-nvidia Apr 20, 2026
b43c7f9
fix: restore FSDP2 buffer patch and remove buffer-to-CUDA block for c…
yeyu-nvidia Jun 1, 2026
967a2c0
fix: use broadcast dtype instead of local param dtype in FSDP2 buffer…
yeyu-nvidia Jun 1, 2026
ca89c84
feat: restore DFlashExportCallback for per-checkpoint draft export
yeyu-nvidia Jun 3, 2026
7a21536
feat: MiniMax-M2.7-DFlash launcher example (online + offline + specde…
yeyu-nvidia Jun 9, 2026
207f257
feat: loss_mask + chat-template support in compute_hidden_states_vllm.py
yeyu-nvidia Jun 9, 2026
f63a387
fix: tolerate missing dist metadata in modelopt.__version__
yeyu-nvidia Jun 9, 2026
b9baede
fix: standalone aux-layer resolution in vLLM dump + --block_size in D…
yeyu-nvidia Jun 9, 2026
79634e9
feat: DFlash export injects long-context RoPE (target rope_theta + YaRN)
yeyu-nvidia Jun 10, 2026
b54c862
fix: vLLM hidden-state dump accepts 'messages' key (not just 'convers…
yeyu-nvidia Jun 10, 2026
e00aaba
fix: vLLM hidden-state dump stages on /dev/shm with sync lock
yeyu-nvidia Jun 10, 2026
9dbd693
review: address CodeRabbit nitpicks + revert modelopt version guard
yeyu-nvidia Jun 10, 2026
1316760
style: ruff format fsdp2_buffer_patch.py (license header + line wrapp…
yeyu-nvidia Jun 10, 2026
45f3a1e
review: converge DFlash export RoPE to a config field; gate export ca…
yeyu-nvidia Jun 10, 2026
035bad7
review: avoid embedding resize for DFlash mask token; reuse an existi…
yeyu-nvidia Jun 10, 2026
645cd5b
review: document applicability/scope of fsdp2_buffer_patch
yeyu-nvidia Jun 10, 2026
426fadf
test: cover DFlash export rope-scaling field + mask-token resolution
yeyu-nvidia Jun 10, 2026
bc97213
review: rename PATCH_FSDP2_BUFFERS -> PATCH_FSDP2_BUFFERS_TF457; trim…
yeyu-nvidia Jun 10, 2026
c7fec79
fix: CI failures — drop bogus num_error, ruff nits, remove scratch YAMLs
yeyu-nvidia Jun 10, 2026
4df2d60
fix: launcher tests for the new requeue path in build_slurm_executor
yeyu-nvidia Jun 11, 2026
e82ef56
review: revert mask-token resize/helper; pin id in recipe (per @hguo-nv)
yeyu-nvidia Jun 11, 2026
e826afb
Merge remote-tracking branch 'origin/main' into yeyu/dflash-auto-mask…
yeyu-nvidia Jun 11, 2026
5846657
review: clarify rope_theta rationale (KV injection); trim rope_scalin…
yeyu-nvidia Jun 12, 2026
4912379
review: gate DFlashExportCallback on SHARDED_STATE_DICT; reuse export…
yeyu-nvidia Jun 12, 2026
01778c8
review: gate export callback by fsdp_state_dict_type, rename it, move…
yeyu-nvidia Jun 13, 2026
786b718
fix: enforce draft rope_theta/rope_type from base model (not setdefault)
yeyu-nvidia Jun 15, 2026
ff7dae6
review: extract _is_hf_format_checkpoint helper for the resume-format…
yeyu-nvidia Jun 15, 2026
8281422
review(claude-bot): fix clip_grad_norm deadlock + silent fallbacks
yeyu-nvidia Jun 15, 2026
a8c67e8
review(claude-bot): parameterize draft depth + harden /dev/shm stagin…
yeyu-nvidia Jun 15, 2026
69ad872
fix: mypy — pass shutil.rmtree args via atexit.register instead of a …
yeyu-nvidia Jun 15, 2026
@@ -25,10 +25,19 @@
"""

import argparse
import atexit
import os
import shutil
from pathlib import Path

import torch
from common import add_aux_layers_args, resolve_aux_layers
from common import (
add_answer_only_loss_args,
add_aux_layers_args,
load_chat_template,
tokenize_with_loss_mask,
verify_generation_tags,
)
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
@@ -38,6 +47,48 @@
)


def _resolve_aux_layers_standalone(
aux_layers: str, num_hidden_layers: int, num_draft: int = 5
) -> list[int]:
"""Resolve aux-layer ids without importing modelopt.

This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline
(mirroring ``modeling_dflash.build_target_layer_ids`` for ``num_draft`` draft layers)
and accept an explicit comma-separated int list. ``num_draft`` MUST match the recipe's
``dflash.dflash_architecture_config.num_hidden_layers`` (pass --num-draft-layers) or the
dumped aux layers silently mis-align with what the draft consumes at training time.
Keep in sync with modelopt.

TODO: drop this once ``common.resolve_aux_layers`` is decoupled from the heavy
``modelopt.torch`` import chain so it can be reused directly in a vLLM container.
"""
spec = aux_layers.strip().lower()
if spec == "dflash":
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
Comment on lines +69 to +75


[IMPORTANT Algorithm] The standalone resolver hard-codes num_draft=5 while the recipe value is configurable — a mismatch silently mis-aligns dumped aux layers.

build_target_layer_ids(num_target_layers, num_draft_layers) (modeling_dflash.py:58-69) is parameterized by num_draft_layers. The MiniMax recipes here pin dflash.dflash_architecture_config.num_hidden_layers=5, so today the constant matches — but if anyone runs offline DFlash with a different draft depth (4, 6, 8 etc.) and forgets to switch from --aux-layers dflash to an explicit comma-list, the dump silently captures a different set of target layers than the trained draft will actually consume at training time, since target_layer_ids is also derived from num_draft_layers via the same function in hf_dflash.py:169. The training run will load mis-aligned aux features and learn garbage; the failure mode is "AR mysteriously regresses" rather than a loud crash.

Two fixes, either of which would close the gap:

  1. Take the draft depth as a CLI flag (--dflash-num-draft 5) and pass it through.
  2. At least raise a loud error / log if the recipe-default 5 does not match the model's actual num_hidden_layers divisibility expectations, and document the constant-coupling explicitly in the docstring.

The TODO already acknowledges this is fragile; the comment-coupling to a hard-coded 5 is the actual breakage vector. Suggest making the value a required CLI arg and removing the default.
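
To make the failure mode concrete, here is a minimal sketch that restates the 'dflash' preset arithmetic from the hunk above as a standalone function and compares two draft depths. The function name `dflash_preset_layer_ids` and the 48-layer target are illustrative assumptions, not names or values taken from the PR.

```python
def dflash_preset_layer_ids(num_hidden_layers: int, num_draft: int) -> list[int]:
    """Spread `num_draft` aux-layer ids evenly over the target's layers.

    Restates the arithmetic of the standalone resolver in this diff; the
    function name is hypothetical.
    """
    if num_draft == 1:
        return [num_hidden_layers // 2]
    start = min(1, num_hidden_layers - 1)
    end = max(start, num_hidden_layers - 3)
    span = end - start
    return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})


# Hypothetical 48-layer target: the dumped layer set depends on the draft depth.
ids_depth_5 = dflash_preset_layer_ids(48, 5)
ids_depth_4 = dflash_preset_layer_ids(48, 4)
print(ids_depth_5)  # [1, 12, 23, 34, 45]
print(ids_depth_4)  # [1, 16, 30, 45]
```

Only the first and last ids coincide, so a dump made at depth 5 and consumed by a depth-4 draft would feed features from the wrong layers without any shape error.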

ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
# Match the shared helper's contract: ids must be valid layer indices.
out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers]
if out_of_range:
raise ValueError(
f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers}) "
f"for a {num_hidden_layers}-layer model."
)
if not ids:
raise ValueError(
f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
"'dflash' preset or an explicit comma-separated layer-id list are supported."
)
return ids
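
One detail worth checking in the parsing path above: a non-numeric preset such as 'eagle' reaches `int(t)` first, so Python raises its own `invalid literal for int()` error before the friendlier "only the 'dflash' preset ..." message can fire. The sketch below shows one possible way to route such input to the intended message. `parse_explicit_layer_ids` is a hypothetical helper and is not part of the PR.

```python
def parse_explicit_layer_ids(aux_layers: str, num_hidden_layers: int) -> list[int]:
    """Parse a comma-separated layer-id list (hypothetical helper)."""
    tokens = [t.strip() for t in aux_layers.split(",") if t.strip()]
    # Empty input and non-numeric presets both get the intended message.
    # Note: negative ids are not decimal strings, so they land here too.
    if not tokens or not all(t.isdecimal() for t in tokens):
        raise ValueError(
            f"--aux-layers={aux_layers!r}: only the 'dflash' preset or an explicit "
            "comma-separated layer-id list are supported."
        )
    ids = sorted({int(t) for t in tokens})
    out_of_range = [i for i in ids if not 0 <= i < num_hidden_layers]
    if out_of_range:
        raise ValueError(
            f"--aux-layers ids {out_of_range} out of range [0, {num_hidden_layers})"
        )
    return ids
```

Validating the tokens before converting them keeps the two error cases (unsupported preset, out-of-range id) distinct.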


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Collect hidden states from conversations using vLLM's native extractor."""
@@ -63,6 +114,14 @@ def parse_args() -> argparse.Namespace:
"--debug-max-num-conversations", type=int, default=None, help="Limit conversations."
)
add_aux_layers_args(parser)
parser.add_argument(
"--num-draft-layers",
type=int,
default=5,
help="DFlash draft depth, for resolving the 'dflash' --aux-layers preset. MUST match "
"the recipe's dflash.dflash_architecture_config.num_hidden_layers (default: 5).",
)
add_answer_only_loss_args(parser)

return parser.parse_args()

@@ -112,7 +171,9 @@ def keep_conversation(entry):
num_hidden_layers = getattr(config, "num_hidden_layers", None)
if num_hidden_layers is None:
raise ValueError(f"model config has no 'num_hidden_layers' attribute: {config}")
aux_layer_ids = resolve_aux_layers(args, num_hidden_layers)
aux_layer_ids = _resolve_aux_layers_standalone(
args.aux_layers, num_hidden_layers, num_draft=args.num_draft_layers
)
# The trailing entry is the final output hidden state; the rest are aux layers.
extract_layer_ids = [*aux_layer_ids, num_hidden_layers]
print(f"Extracting hidden states from layers {extract_layer_ids} (last = final output)")
@@ -121,34 +182,37 @@ def keep_conversation(entry):
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
override_template = load_chat_template(args.chat_template)
if override_template is not None:
tokenizer.chat_template = override_template
if tokenizer.chat_template is not None:
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
if args.answer_only_loss:
verify_generation_tags(tokenizer.chat_template)

# Prepare prompts for vLLM
prompts = []
conversation_ids = []
loss_masks = []
num_skipped_too_long = 0
num_invalid = 0

for entry in dataset:
conversation_id = entry.get("conversation_id", entry.get("uuid"))
conversations = entry["conversations"]
# Accept either the "conversations" or OpenAI-style "messages" key (the
# MiniMax synthetic data uses "messages").
conversations = entry.get("conversations") or entry.get("messages")
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

tokenized = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_prompt=False
# One apply_chat_template call yields aligned input_ids + loss_mask. With
# --answer-only-loss the mask comes from the template's {% generation %} tags;
# otherwise it is all-ones. Same tokens are sent to vLLM, so the dumped hidden
# states line up with this loss_mask 1:1 (prefix caching is disabled below).
input_ids, loss_mask = tokenize_with_loss_mask(
tokenizer, conversations, args.answer_only_loss
)
# transformers 5.x: BatchEncoding may not inherit from dict; use .input_ids
if hasattr(tokenized, "input_ids"):
input_ids = tokenized.input_ids
elif hasattr(tokenized, "__getitem__") and "input_ids" in tokenized:
input_ids = tokenized["input_ids"]
else:
input_ids = tokenized
if not hasattr(input_ids, "shape"):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.squeeze(0)
num_tokens = input_ids.shape[0]
if num_tokens <= 10 or num_tokens > args.max_seq_len:
@@ -157,6 +221,7 @@ def keep_conversation(entry):

prompts.append(TokensPrompt(prompt_token_ids=input_ids.tolist()))
conversation_ids.append(conversation_id)
loss_masks.append(loss_mask)

print(
f"Prepared {len(prompts)} prompts ({num_skipped_too_long} skipped too long, {num_invalid} invalid)"
@@ -168,15 +233,28 @@ def keep_conversation(entry):

# Initialize vLLM with the native hidden-state extractor.
tp = args.tp if args.tp is not None else torch.cuda.device_count()
storage_path = output_dir / ".vllm_hidden_states"
# Stage the connector's intermediate safetensors on local tmpfs, not the (lustre)
# output dir: the producer writes one file per request and the client reads it back
# immediately, so a fast local path avoids cross-node FS latency. Per-DP-rank dir so
# parallel shards don't collide. Overridable via DFLASH_HS_STAGING_DIR for containers
# where /dev/shm is unmapped or undersized; cleaned up on exit so a crash doesn't strand
# RAM-backed files until the node reboots.
staging_root = os.environ.get("DFLASH_HS_STAGING_DIR", "/dev/shm")
storage_path = Path(staging_root) / f"vllm_hidden_states_dp{args.dp_rank}"
storage_path.mkdir(parents=True, exist_ok=True)
atexit.register(shutil.rmtree, storage_path, ignore_errors=True)

llm = LLM(
model=args.model,
tensor_parallel_size=tp,
max_model_len=args.max_seq_len,
trust_remote_code=args.trust_remote_code,
enable_chunked_prefill=False, # required by extract_hidden_states
# With prefix caching on, vLLM serves shared prefixes from cache in block-sized
# chunks and the hidden-state connector only emits the freshly-computed suffix, so
# the dumped hidden_states come out short by N*block_size vs the full input_ids /
# loss_mask. Disabling it forces a full prefill so every token's state is dumped.
enable_prefix_caching=False,
speculative_config={
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
@@ -189,18 +267,22 @@ def keep_conversation(entry):
kv_role="kv_producer",
kv_connector_extra_config={
"shared_storage_path": str(storage_path),
"use_synchronization_lock": False, # batch generation, no concurrent readers
# The client reads each request's safetensors right after generation; the
# lock makes the producer signal completion so the reader doesn't race the
# writer (without it the reader looks for a .lock the producer never wrote).
"use_synchronization_lock": True,
},
),
)

# max_tokens=1: we only need a single forward pass over the prompt tokens.
outputs = llm.generate(prompts, SamplingParams(max_tokens=1))

# Save in the same format as compute_hidden_states_hf.py (sans loss_mask, which the
# vLLM path does not compute).
# Save in the same format as compute_hidden_states_hf.py, including loss_mask.
num_success = 0
for conv_id, output in tqdm(zip(conversation_ids, outputs), total=len(outputs), desc="Saving"):
for conv_id, loss_mask, output in tqdm(
zip(conversation_ids, loss_masks, outputs), total=len(outputs), desc="Saving"
):
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
if hidden_states_path is None:
print(f"WARNING: no hidden_states_path for conversation {conv_id}; skipping")
@@ -220,13 +302,24 @@ def keep_conversation(entry):
else:
aux_hidden_states = torch.empty(0)

# loss_mask is sliced to the dumped length below; a shorter loss_mask would slice
# to itself and silently misalign with the hidden states, so guard explicitly.
n_hs = output_hidden_states.shape[0]
if loss_mask.shape[0] < n_hs:
print(
f"WARNING: {conv_id}: loss_mask ({loss_mask.shape[0]}) shorter than hidden "
f"states ({n_hs}); skipping to avoid misalignment"
)
continue

output_file = output_dir / f"{conv_id}.pt"
with open(output_file, "wb") as f:
torch.save(
{
"input_ids": token_ids.cpu(),
"hidden_states": output_hidden_states,
"aux_hidden_states": aux_hidden_states,
"loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(),
"conversation_id": conv_id,
},
f,
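
The explicit length guard before the save matters because slicing never raises on an over-long stop index. A minimal sketch, with plain lists standing in for tensors (a torch tensor sliced along its first dimension behaves the same way); `align_loss_mask` is a hypothetical helper name, not a function from the PR:

```python
from typing import Optional


def align_loss_mask(loss_mask: list[int], num_hidden_states: int) -> Optional[list[int]]:
    """Trim the mask to the dumped length, or return None if it is too short."""
    if len(loss_mask) < num_hidden_states:
        # Slicing would hand back the short mask unchanged, hiding the mismatch.
        return None
    return loss_mask[:num_hidden_states]


short_mask = [1, 1]
# Without the guard, the slice succeeds and returns only 2 entries for 3 states.
print(short_mask[:3])                      # [1, 1]
print(align_loss_mask(short_mask, 3))      # None
print(align_loss_mask([0, 1, 1, 1], 3))    # [0, 1, 1]
```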
7 changes: 6 additions & 1 deletion examples/speculative_decoding/doc/dflash.md
@@ -163,10 +163,15 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) |
| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) |
| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) |
| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
| `training.answer_only_loss` | false | Mask loss on non-assistant tokens |

> **Note on `dflash_mask_token_id`:** the draft reuses the target's `embed_tokens`, so the
> mask id must be a token that exists in the target embedding. Pin it to an existing
> reserved token id — e.g. MiniMax-M2.7 uses `200054` (embedding is 200064 rows; tokens
> 0..200053 are real, 200054+ reserved).
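
The constraint in this note amounts to a range check. The sketch below uses the MiniMax-M2.7 numbers quoted above; the function name and its parameters are hypothetical, and in practice the row count would be read from the target model's embedding shape rather than passed in by hand.

```python
def validate_mask_token_id(
    mask_token_id: int, embedding_rows: int, first_reserved_id: int
) -> None:
    """Check that the mask id exists in the embedding and is not a real token."""
    if not 0 <= mask_token_id < embedding_rows:
        raise ValueError(
            f"mask_token_id {mask_token_id} is outside the embedding "
            f"(valid rows: 0..{embedding_rows - 1})"
        )
    if mask_token_id < first_reserved_id:
        raise ValueError(
            f"mask_token_id {mask_token_id} collides with a real token "
            f"(reserved ids start at {first_reserved_id})"
        )


# MiniMax-M2.7 values from the note: 200064 rows, reserved ids from 200054.
validate_mask_token_id(200054, embedding_rows=200064, first_reserved_id=200054)
```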

> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the
> tokenizer's chat template must include `{% generation %}` / `{% endgeneration %}` tags
> around assistant content. HuggingFace uses these tags to produce `assistant_masks` via
104 changes: 104 additions & 0 deletions examples/speculative_decoding/eagle_utils.py
@@ -230,6 +230,110 @@ def on_step_begin(self, args, state, control, **kwargs):
return control


class DFlashFSDP2ShardedSDExportCallback(TrainerCallback):
"""Export the DFlash draft module after each checkpoint save, for FSDP2 sharded runs.

Applicable range: this is needed only under FSDP2 ``SHARDED_STATE_DICT``, where the
checkpoint holds distributed shards (``pytorch_model_fsdp_0/``) and no consolidated
``model.safetensors`` — so the post-training ``export_hf_checkpoint.py`` pass can't read
it. It gathers just the small draft submodule and writes the deployable export.

Gating is the caller's responsibility: ``main.py`` adds this callback only when the
accelerator's FSDP state dict type is ``SHARDED_STATE_DICT`` (full-state-dict runs —
DDP, single-device, FSDP2 FULL_STATE_DICT — use the launcher's post-run export instead).
"""

def on_save(self, args, state, control, **kwargs):
"""Export DFlash draft module weights + config after checkpoint save."""
import json
import os

from safetensors.torch import save_file

model = kwargs["model"]
if not hasattr(model, "dflash_module"):
return control

step = state.global_step
export_dir = os.path.join(args.output_dir, f"exported-checkpoint-{step}")

# All ranks participate in the state_dict gather (FSDP2 collective op). Only the
# dflash_module submodule is gathered (~328 MB for MiniMax-M2.7), not the 229B base.
try:
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)

options = StateDictOptions(full_state_dict=True, cpu_offload=True)
try:
raw_sd = get_model_state_dict(
model, submodules={model.dflash_module}, options=options
)
except TypeError:
# Older PyTorch without the submodules= parameter: this gathers the FULL
# model (the entire base, e.g. ~229B for MiniMax-M2.7), defeating the
# submodule-only design and risking OOM. Warn loudly — upgrade PyTorch.
print_rank_0(
"WARNING: DFlash export: get_model_state_dict lacks submodules= on this "
"PyTorch — gathering the FULL base model (slow / may OOM). Upgrade PyTorch "
"for the submodule-only gather."
)
raw_sd = get_model_state_dict(model, options=options)
except ImportError:
# Non-distributed / single-GPU fallback
raw_sd = model.state_dict()

# Reuse the exporter's extraction (strips the dflash_module prefix, drops rotary
# buffers) for the common full-model key layout. Some PyTorch versions return the
# submodule gather with keys already stripped of the prefix — handle that directly.
exporter = model.get_exporter()
drafter_sd = exporter._extract_state_dict(raw_sd)
if not drafter_sd:
# Fallback for the already-stripped-key layout: a denylist heuristic rather than
# the prefix-based extractor. Warn, since a key-naming change upstream could let
# a malformed draft ship and only fail at vLLM load time.
print_rank_0(
"WARNING: DFlash export: prefix-based extraction found no dflash_module keys; "
"falling back to a denylist heuristic on already-stripped keys. Verify the "
"exported draft loads in vLLM."
)
drafter_sd = {
k: v
for k, v in raw_sd.items()
Comment on lines +269 to +303


[IMPORTANT Performance/ModeState] The silent full-model fallback when submodules= is not supported is dangerous for the very target this callback exists to support.

The whole point of this callback (per the docstring and PR description) is to gather only the ~328 MB DFlash submodule under FSDP2 SHARDED_STATE_DICT, without materializing the 229B base. But this fallback path:

except TypeError:
    # Older PyTorch without submodules parameter — gather full model
    raw_sd = get_model_state_dict(model, options=options)

silently calls get_model_state_dict(model, options=options) (no submodule filter, full_state_dict=True, cpu_offload=True), which gathers the entire 229B base on every save. That is many minutes per checkpoint of all-gather plus ~230 GB of CPU memory pressure per node — exactly what the submodule gather was added to avoid. On older PyTorch it wouldn't fail loudly; the user would just see saves get mysteriously expensive (or OOM) without understanding why.

Suggestion: either (a) raise a clear error pointing the user at a minimum supported PyTorch version, or (b) at minimum print_rank_0 a loud warning naming the model size and that this path will gather the full base model. The current silent fallback masks a tail-latency / OOM failure mode in the only scenario that matters.
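
Option (a) could look roughly like the sketch below, which probes the gather function's signature up front instead of catching `TypeError` after the call. This is an illustration under stated assumptions: `require_submodules_support` is a hypothetical helper, and `gather_new` / `gather_old` are stand-ins for newer and older `get_model_state_dict` signatures, not the real PyTorch functions.

```python
import inspect
from typing import Callable


def require_submodules_support(gather_fn: Callable) -> None:
    """Fail fast if `gather_fn` cannot restrict the gather to a submodule."""
    if "submodules" not in inspect.signature(gather_fn).parameters:
        raise RuntimeError(
            f"{gather_fn.__name__} has no 'submodules' parameter; refusing to "
            "gather the full base model on every checkpoint save. Upgrade PyTorch."
        )


# Stand-ins for the two signatures (illustrative only, they gather nothing).
def gather_new(model, *, submodules=None, options=None):
    return {}


def gather_old(model, *, options=None):
    return {}


require_submodules_support(gather_new)  # passes silently
try:
    require_submodules_support(gather_old)
except RuntimeError as err:
    print(err)
```

Checking the signature before the call also avoids treating an unrelated `TypeError` raised inside the gather as "parameter not supported".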

if "rotary_emb" not in k
and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}
del raw_sd
# Coerce to CPU for save_file (the distributed gather uses cpu_offload, but the
# single-GPU fallback may return CUDA tensors).
drafter_sd = {k: (v.cpu() if v.device.type != "cpu" else v) for k, v in drafter_sd.items()}
Comment on lines +292 to +310


[IMPORTANT Algorithm/Export] The fallback key heuristic doesn't match what DFlashExporter._extract_state_dict actually returns — it will produce an empty export silently.

DFlashExporter._extract_state_dict (modelopt/torch/export/plugins/hf_spec_export.py:322-332) keeps only entries that contain "dflash_module." and strips that prefix. If get_model_state_dict(..., submodules={model.dflash_module}, full_state_dict=True) returns keys already pre-stripped (without the dflash_module. prefix), then exporter._extract_state_dict(raw_sd) returns {} — that's why this fallback exists.

But the heuristic here:

drafter_sd = {
    k: v
    for k, v in raw_sd.items()
    if "rotary_emb" not in k
    and not any(p in k for p in ("model.", "lm_head.", "embed_tokens."))
}

filters out anything containing "model.". Layer/decoder weights named like layers.0.self_attn.q_proj.weight survive (no model. substring), which is fine, but the filter would also drop any key that legitimately contains "model." if a future DFlash module nests differently. More importantly, the substring tests "lm_head." and "embed_tokens." will never match in the pre-stripped DFlash submodule case (the DFlash draft has neither), so they are dead checks there. The exporter's _check_valid_sd is also bypassed in this branch, so a malformed state dict (e.g. weights from the wrong module after a refactor) would write a broken model.safetensors and only fail at vLLM load time.

Recommendation: instead of the substring heuristic, detect the pre-stripped case explicitly and either (a) re-prefix the keys with "dflash_module." and rerun exporter._extract_state_dict (so the same validation path runs), or (b) at minimum, assert the resulting key set looks sane (e.g. contains fc.weight, norm.weight, layers.0.self_attn.q_proj.weight) before writing. Right now a key-layout regression would pass through silently.
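
A rough sketch of how options (a) and (b) might combine. `extract_draft_keys` is a simplified stand-in for the exporter's extractor as this comment describes it (keep keys containing the prefix, strip it); it is not the real `_extract_state_dict`, and it omits the rotary-buffer filtering. `extract_with_reprefix` and `EXPECTED_KEYS` are likewise hypothetical names, with the key list taken from the examples above.

```python
PREFIX = "dflash_module."
EXPECTED_KEYS = ("fc.weight", "norm.weight", "layers.0.self_attn.q_proj.weight")


def extract_draft_keys(raw_sd: dict) -> dict:
    """Stand-in extractor: keep keys containing PREFIX and strip it."""
    return {k.split(PREFIX, 1)[1]: v for k, v in raw_sd.items() if PREFIX in k}


def extract_with_reprefix(raw_sd: dict) -> dict:
    """Run the prefix-based extraction, re-prefixing pre-stripped keys if needed."""
    drafter_sd = extract_draft_keys(raw_sd)
    if not drafter_sd:
        # Pre-stripped layout: restore the prefix so the same path handles it.
        drafter_sd = extract_draft_keys({PREFIX + k: v for k, v in raw_sd.items()})
    missing = [k for k in EXPECTED_KEYS if k not in drafter_sd]
    if missing:
        raise ValueError(f"draft state dict is missing expected keys: {missing}")
    return drafter_sd


# Pre-stripped keys, as some PyTorch versions reportedly return them.
stripped = {key: idx for idx, key in enumerate(EXPECTED_KEYS)}
print(sorted(extract_with_reprefix(stripped)))
```

With this shape, a key-layout regression raises before anything is written instead of surfacing later at vLLM load time.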


if not drafter_sd:
print_rank_0(f"Warning: No dflash_module weights found at step {step}, skipping export")
return control

# Only rank 0 writes files
if is_master():
try:
os.makedirs(export_dir, exist_ok=True)
save_file(drafter_sd, os.path.join(export_dir, "model.safetensors"))

config = exporter._export_config()
with open(os.path.join(export_dir, "config.json"), "w") as f:
json.dump(config, f, indent=2)

total_mb = sum(v.nbytes for v in drafter_sd.values()) / 1024 / 1024
print_rank_0(
f"Exported DFlash draft ({len(drafter_sd)} tensors, {total_mb:.0f}MB) "
f"to {export_dir}"
)
except Exception as e:
print_rank_0(f"Warning: DFlash export failed at step {step}: {e}")

return control


class EagleTrainingPlot(TrainerCallback):
"""Callback that plot training acc and AR during training."""
