[Feat]: Domino support#1710
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 WalkthroughWalkthroughAdds a "Domino" speculative decoding training variant on top of DFlash. New components include a ChangesDomino Speculative Decoding Training
Sequence Diagram(s)sequenceDiagram
participant Trainer
participant DominoLambdaCallback
participant HFDominoModel
participant DraftBackbone
participant DominoModule
Trainer->>DominoLambdaCallback: on_step_begin(global_step)
DominoLambdaCallback->>HFDominoModel: _lambda_base = compute_lambda_base(global_step, total_steps, ...)
Trainer->>HFDominoModel: forward(input_ids, labels, loss_mask)
HFDominoModel->>HFDominoModel: anchor sampling + loss_mask build
HFDominoModel->>DraftBackbone: run backbone (no-grad base hidden states)
DraftBackbone-->>HFDominoModel: hidden_states, base_logits
HFDominoModel->>DominoModule: _apply_domino_head(hidden_states, base_logits, anchors)
DominoModule-->>HFDominoModel: corrected final_logits
HFDominoModel->>HFDominoModel: _compute_domino_loss(final_logits, base_logits, _lambda_base)
HFDominoModel-->>Trainer: ModelOutput(loss, base_loss, final_loss, base_accuracy, lambda_base)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1710 +/- ##
==========================================
- Coverage 77.12% 77.00% -0.13%
==========================================
Files 511 513 +2
Lines 56236 56614 +378
==========================================
+ Hits 43374 43596 +222
- Misses 12862 13018 +156
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
/claude review |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Line 285: The import of DominoLambdaCallback from
modelopt.torch.speculative.plugins.hf_domino at line 285 is currently inside a
function, which violates the repository's import-placement guidelines. Move this
import to the module scope at the top of the file with other imports, unless
there is a specific reason for the in-function placement (such as circular
dependency, optional dependency, or performance concerns). If such a reason
exists, keep the in-function import but add a brief comment above it explaining
the concrete justification for the non-standard placement.
In `@modelopt/torch/speculative/config.py`:
- Around line 135-150: Add schema bounds validation to the
dflash_lambda_base_start and dflash_lambda_base_decay_ratio ModeloptField
definitions to enforce that these normalized weight/fraction fields accept only
values in the valid range (0 to 1). This will cause invalid configuration values
to be rejected at config load time rather than being silently masked downstream,
following the coding guideline to validate external input at the interface
boundary.
In `@modelopt/torch/speculative/dflash/conversion.py`:
- Around line 44-49: The registry selection logic currently silently defaults to
DFlashDMRegistry for any unknown projector_type value, which can hide typos and
route users incorrectly. Replace the conditional expression with explicit
validation that checks if the projector_type is one of the supported values
("domino" or the default). Raise an appropriate error (e.g., ValueError) if the
projector_type is unsupported, ensuring invalid input is rejected at the
interface boundary rather than silently falling back to a default registry.
In `@modelopt/torch/speculative/plugins/modeling_domino.py`:
- Around line 58-59: Add validation for the pure_draft_prefix_len attribute
immediately after it is read from config in the initialization block. The
validation should check that pure_draft_prefix_len is non-negative and strictly
less than the block_size to ensure suffix correction works properly. If the
validation fails, raise a clear ValueError with a descriptive message indicating
the valid range requirement. This validation should occur at the config
interface boundary during module initialization, right after line 58 where
pure_draft_prefix_len is assigned from the config getattr call.
In `@tests/unit/torch/speculative/plugins/test_hf_domino.py`:
- Line 99: Move the in-function import of DFlashModule from line 99 and the
corresponding import at line 184 to the top-level of the test module (at the
beginning of the file with other imports). These imports do not have explicit
circular dependency, optional dependency, or heavy-import justifications that
would warrant keeping them in-function, so they should follow test conventions
by being at module-level to catch import errors during test collection rather
than execution.
In `@tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml`:
- Around line 72-75: The environment section in the YAML configuration file for
this new Qwen3-8B model config is missing two required environment variables. In
the environment list at lines 72-75 (which currently contains only
MAX_FINAL_LOSS and MIN_FINAL_ACC), add the two required launcher environment
variables MLM_MODEL_CFG and QUANT_CFG as additional list items, following the
same format as the existing environment variables. These must be explicitly set
according to the launcher coding guidelines for new model configurations.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: afea7d4d-38ed-4205-b10b-8d8ee062ad26
📒 Files selected for processing (12)
CHANGELOG.rstexamples/speculative_decoding/main.pymodelopt/torch/export/plugins/hf_spec_export.pymodelopt/torch/speculative/config.pymodelopt/torch/speculative/dflash/conversion.pymodelopt/torch/speculative/plugins/__init__.pymodelopt/torch/speculative/plugins/hf_dflash.pymodelopt/torch/speculative/plugins/hf_domino.pymodelopt/torch/speculative/plugins/modeling_domino.pymodelopt_recipes/general/speculative_decoding/domino.yamltests/unit/torch/speculative/plugins/test_hf_domino.pytools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml
There was a problem hiding this comment.
Claude review
Reviewed the Domino training-only addition end-to-end. The algorithm is correct: GRU teacher-forcing on input_ids[anchor..anchor+bs-1] with shift_label=True predicts anchor+k+1 from anchor..anchor+k (no leakage), the suffix slice + base-logit add matches the SpecForge formulation, and the dual loss / λ-curriculum are wired correctly. The DominoDMRegistry split keeps HFDominoModel from shadowing HFDFlashModel, and config / state-dict round-tripping looks safe (new fields have defaults; old saved configs without projector_type keep routing to the DFlash registry).
Findings
- CRITICAL: 0
- IMPORTANT: 1
- Silent eval bypass: in non-training mode
forwarddelegates toHFDFlashModel.forward, sopseudo_speculative_generate/ AR validation never applies the trained Domino head. Acknowledged in the PR description, but anestimate_ar/ar_validate_stepsuser would silently get backbone-only acceptance numbers. Suggest alogger.warning_oncehere and a short note indomino.yaml.
- Silent eval bypass: in non-training mode
- SUGGESTION: 3
DominoExporter._export_configusesgetattr(draft_config, "emb_dim")(no default) — fails at export time after a long train if the user'sdflash_architecture_configomits a head field. Prefer validating inHFDominoModel.modify.DominoLambdaCallbackfalls back tototal_steps=1whenstate.max_stepsis unset, which silently flips λ_base to 0 from step 1. Add a one-shot warning when the fallback is taken.- With
lambda_base == 1.0the head params drop out of the autograd graph; the recipe correctly setsddp_find_unused_parameters: truebut the dependency is invisible — worth a one-line note next to that flag indomino.yaml.
Overall risk
Low. Training-only path, opt-in via projector_type=domino, no behavior change for existing DFlash users. The findings are about ergonomics / fail-loud-vs-fail-silent rather than algorithm correctness.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
The results look good. The the final checkpoint the same as DFlash? Or require additional vLLM support? |
The final checkpoint contains new components other than typical DFlash. vLLM/TRTLLM is not yet ready in their main branches. We can add serve/specdecbench support once it's available. |
What does this PR do?
Type of change: New feature
Adds Domino speculative decoding: the parallel DFlash draft backbone plus a lightweight GRU causal correction head. The backbone produces base logits for a full draft block in one forward; a GRU over the block's teacher-forced tokens produces a causal state that is fused with the backbone hidden state and projected to a vocab-sized logit correction on the block suffix — injecting the intra-block causal dependency the parallel backbone lacks. Trained with a dual loss
(1-λ)*final + λ*base, whereλ_basedecays linearly 1→0 (curriculum: learn the parallel backbone first, then the correction).Reuses the DFlash mode/config/recipe; selected via
dflash_architecture_config.projector_type=dominoand routed to its own registry soHFDominoModeldoes not shadowHFDFlashModel. Exports in the z-lab/SpecForge drafter format (prefix_gru.*/embed_proj.*).Usage
# Online training (recipe: projector_type=domino) uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_domino.yaml --yesTesting
CPU unit tests in
tests/unit/torch/speculative/plugins/test_hf_domino.pycover conversion routing, the training forward (dual loss + grads), the λ schedule, and the export format. Online Qwen3-8B training validated end-to-end (loss curve below).Before your PR is "Ready for review"
projector_type=domino; DFlash path unchanged)CONTRIBUTING.md: N/A (no new dependency)Additional Information
Reference: SpecForge PR #571 (z-lab); drafter format
huggingface.co/Huang2020/Qwen3-8B-Domino-b16.Summary by CodeRabbit
Release Notes
New Features
Documentation & Configuration
Tests