[Feat]: Support DPace #1724
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Walkthrough
Adds D-PACE (Dynamic Position-Aware Cross-Entropy) as a new DFlash training loss objective. Two new ...
Changes
D-PACE Loss Objective for DFlash
Codecov Report
@@ Coverage Diff @@
## main #1724 +/- ##
==========================================
- Coverage 77.12% 76.55% -0.58%
==========================================
Files 511 511
Lines 56247 56267 +20
==========================================
- Hits 43381 43073 -308
- Misses 12866 13194 +328
/claude review
    with torch.no_grad():
        conf_ce = F.cross_entropy(
            logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction="none"
        ).view(bsz, n_blocks, block_size)
        confidences = torch.exp(-conf_ce[..., 1:].float())
    dpace = torch.ones_like(weight_mask)
    dpace[..., 1:] = _dpace_position_weights(confidences, self.dflash_dpace_alpha)
    weight_mask = weight_mask * dpace
elif self.dflash_loss_decay_factor > 0:
    k = torch.arange(block_size, device=device).view(1, 1, -1)
    decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor)
    weight_mask = weight_mask * decay
[SUGGESTION] When base_logits is None (the non-KD path), the per-token cross-entropy is computed twice — once here under no_grad to derive confidences, and again at line 421 to compute the actual loss. Since the second computation is exactly the per-token CE you already have, you could reuse it (compute once with grad enabled, take .detach().exp() for the confidences). The PR description already acknowledges the ~2.3% overhead — eliminating this duplication would close most of that gap. The KD path correctly remains separate because its actual loss is KL, not CE.
Why it matters: small but free win on training throughput; CE is one of the more expensive ops in the inner training loop because of the vocab-size matmul.
How to apply: hoist a single loss_per_token = F.cross_entropy(...) computation, derive confidences = torch.exp(-loss_per_token.detach()).view(bsz, n_blocks, block_size)[..., 1:].float(), then later use the same loss_per_token in the loss reduction. Keep the no-grad CE only for the KD branch.
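The single-pass variant suggested above might look roughly like the following. This is a minimal sketch, not the PR's code: the helper name `ce_and_confidences` and the toy tensor sizes are assumptions made for illustration, and it covers only the non-KD path.

```python
import torch
import torch.nn.functional as F


def ce_and_confidences(logits, target_ids, bsz, n_blocks, block_size):
    """Compute per-token CE once and reuse it for the loss and the confidences.

    Hypothetical helper for illustration; not a function from the PR.
    """
    # Single CE pass with grad enabled: this tensor feeds the loss reduction.
    loss_per_token = F.cross_entropy(
        logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction="none"
    ).view(bsz, n_blocks, block_size)
    # Confidences q_i = exp(-CE_i) for positions 1.. (position 0 is the anchor).
    # detach() keeps the derived weights out of the autograd graph.
    confidences = torch.exp(-loss_per_token.detach()[..., 1:].float())
    return loss_per_token, confidences


# Toy shapes, chosen only to exercise the sketch.
torch.manual_seed(0)
bsz, n_blocks, block_size, vocab = 2, 3, 4, 11
logits = torch.randn(bsz, n_blocks, block_size, vocab, requires_grad=True)
target_ids = torch.randint(0, vocab, (bsz, n_blocks, block_size))
loss_per_token, confidences = ce_and_confidences(
    logits, target_ids, bsz, n_blocks, block_size
)
```

Here `loss_per_token` still carries gradients for the later reduction, while `confidences` does not, which is the property the separate `no_grad` pass was providing.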
    Returns:
        Detached weights with the same shape and dtype as ``confidences``.
    """
    if not 0.0 <= alpha <= 1.0:
[SUGGESTION] Docstring says alpha is in (0, 1] but the validation accepts [0, 1] (closed at 0). The user-facing path through DFlashModel.modify() correctly rejects alpha=0, but a direct caller of _dpace_position_weights with alpha=0 would silently get all-zero weights (the cumulative product collapses on the first position) instead of an error. Tighten the check to 0.0 < alpha <= 1.0 to match the docstring, or relax the docstring to [0, 1].
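For reference, the tightened check could be as small as the sketch below. The function name `check_dpace_alpha` is hypothetical; in the PR the check sits inline in `_dpace_position_weights`.

```python
def check_dpace_alpha(alpha: float) -> float:
    """Reject alpha outside the half-open interval (0, 1].

    Hypothetical standalone helper, shown only to illustrate the strict bound.
    """
    # Strict lower bound: alpha == 0.0 now raises instead of passing through.
    # NaN also fails the chained comparison and is rejected.
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must be in (0, 1], got {alpha}")
    return alpha
```

With this form the function's own validation agrees with its docstring, independent of the check in `DFlashModel.modify()`.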
raise ValueError(
    f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got "
    f"{self.dflash_dpace_alpha}"
)
[SUGGESTION] Consider warning (or rejecting) when dflash_loss_objective == "dpace" and dflash_loss_decay_factor != 0.0 (i.e. the user has explicitly set both). The default recipe modelopt_recipes/general/speculative_decoding/dflash.yaml already sets dflash_loss_decay_factor: 4.0, so a user who only flips dflash_loss_objective: dpace won't realize their non-default decay value is silently ignored (the doc notes the mutual exclusion, but the runtime is silent). A logger.warning(...) here would surface the misconfiguration without blocking the run.
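One possible shape for such a warning, as a sketch only: the function name `warn_if_decay_ignored` and its return value are assumptions for illustration, and the real check would presumably live next to the existing validation.

```python
import logging

logger = logging.getLogger(__name__)


def warn_if_decay_ignored(loss_objective: str, loss_decay_factor: float) -> bool:
    """Log a warning when D-PACE is selected and a decay factor is also set.

    Hypothetical helper; returns True when the warning fired so callers
    (and tests) can observe the outcome.
    """
    if loss_objective == "dpace" and loss_decay_factor != 0.0:
        logger.warning(
            "dflash_loss_objective='dpace' ignores dflash_loss_decay_factor=%s; "
            "the two settings are mutually exclusive.",
            loss_decay_factor,
        )
        return True
    return False
```

Warning instead of raising keeps the default recipe usable when a user only flips the objective, which is the scenario described above.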
Claude review passed — no blocking issues found.
Summary (CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3)
The D-PACE implementation is correct and well-scoped:
- Algorithm matches paper Eq.7-8: smoothing `q~_i = (1-α)q_i + α`, prefix-product, suffix-sum (reverse-cumsum-reverse), all under `no_grad` and explicitly detached.
- Opt-in (default `dflash_loss_objective='decay'` preserves prior behavior), so there is no backward-compat concern.
- Validation lives in `DFlashModel.modify()` so bad configs fail at convert-time, not deep in the training loop.
- Tests cover formula correctness, detachment, monotonicity, smoothing floor, error paths, and `mtsp.convert` wiring.
- Position-0 (anchor) is correctly excluded from D-PACE weights via `[..., 1:]` slicing.
- No mode-state schema or export-path changes; this is purely a training-loss feature.
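The smoothing, prefix-product and suffix-sum steps named in the first bullet can be sketched as follows. This is an illustrative reconstruction from the formulas quoted in this PR, not the body of `_dpace_position_weights`; the name `dpace_weights_sketch` is hypothetical, and any normalization or dtype handling the real function performs is not shown.

```python
import torch


def dpace_weights_sketch(confidences: torch.Tensor, alpha: float) -> torch.Tensor:
    """w_j = sum_{m>=j} prod_{i<=m} q~_i, with q~_i = (1 - alpha) * q_i + alpha."""
    with torch.no_grad():
        # Eq. 7: smooth each confidence toward 1.
        smoothed = (1.0 - alpha) * confidences + alpha
        # Prefix products along the block dimension: prod_{i<=m} q~_i.
        prefix = torch.cumprod(smoothed, dim=-1)
        # Eq. 8: suffix sum, computed as reverse -> cumsum -> reverse.
        reversed_sums = torch.cumsum(torch.flip(prefix, dims=[-1]), dim=-1)
        weights = torch.flip(reversed_sums, dims=[-1])
    return weights.detach()


# Worked example: three positions, each with confidence 0.5, alpha = 0.5.
q = torch.tensor([[0.5, 0.5, 0.5]])
w = dpace_weights_sketch(q, alpha=0.5)
# smoothed = 0.75 each; prefix = [0.75, 0.5625, 0.421875]
# w        = [1.734375, 0.984375, 0.421875]
```

Because every prefix product is non-negative, the suffix sums are non-increasing in position, which matches the monotonicity property the tests are said to cover.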
Inline suggestions (non-blocking):
- CE on the predicted positions is computed twice in the non-KD path (once for confidences, once for the loss). Reusing one computation would close most of the documented ~2.3% overhead.
- `_dpace_position_weights` accepts `alpha=0` (silently zero weights) while its docstring claims `(0, 1]`; tighten the runtime check to match.
- When `dflash_loss_objective='dpace'` and `dflash_loss_decay_factor` is non-default, the decay value is silently ignored. A `logger.warning` would surface the misconfiguration since the default recipe already sets `decay_factor: 4.0`.
LGTM.
What does this PR do?
Type of change: New feature
Adds the D-PACE (Dynamic Position-Aware Cross-Entropy) loss objective for DFlash speculative-decoding training (arXiv:2605.18810). It replaces the static exponential position decay with per-position CE weights derived from the draft's own confidence `q_i = exp(-CE_i)`: smoothed `q̃_i = (1-α)q_i + α` (Eq.7) and weighted by the suffix-sum of prefix products `w_j = Σ_{m≥j} ∏_{i≤m} q̃_i` (Eq.8), which directly targets expected accepted block length and shifts signal toward whichever positions currently limit acceptance.
Selected via `dflash_loss_objective: dpace` (default `decay` keeps current behavior); smoothing via `dflash_dpace_alpha` (default 0.5). Weights are detached from the gradient, so the change is training-only with ~2.3% overhead and no architecture or inference change. Mutually exclusive with `dflash_loss_decay_factor`.
Usage
Testing
CPU unit tests in `tests/unit/torch/speculative/plugins/test_hf_dflash.py`: weights match the paper closed form, are detached and non-increasing, the α smoothing floor keeps later weights non-zero, and convert wires/validates the new fields (rejects bad objective and degenerate α). Training validated on Qwen3-8B (curve below).
Before your PR is "Ready for review"
`dflash_loss_objective=decay` is unchanged)
`CONTRIBUTING.md`: N/A (no new dependency)
Additional Information
Reference: D-PACE, arXiv:2605.18810. See `examples/speculative_decoding/doc/dflash.md` for the math and tuning notes.