Fix step-1 NaN: run gemma attention + DeepSeek sparse-indexer logits in fp32#4136
Draft
ecnal-cienet wants to merge 2 commits into
Draft
Fix step-1 NaN: run gemma attention + DeepSeek sparse-indexer logits in fp32#4136ecnal-cienet wants to merge 2 commits into
ecnal-cienet wants to merge 2 commits into
Conversation
gemma4-31b/26b use head_dim=256 with no attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan at step 1 (fails on both Linen and NNX). Setting float32_qk_product and float32_logits runs the qk product and softmax inputs in fp32 — semantically identical, just stable. Applied to gemma4-31b, gemma4-26b, and gemma3-27b (same softcap-less attention).
The indexer qk product is computed at matmul_precision (bf16) and relu'd in bf16 while weights_proj is fp32; large bf16 logits can overflow to inf and propagate to a NaN loss. Cast the indexer logits to fp32 before relu/aggregation, matching how the main attention runs softmax in fp32.
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Two pre-existing numerical-stability bugs produce an early-step NaN loss from bf16 attention-logit overflow. Both reproduce on the Linen path too — so they are not NNX-specific — and are fixed here as one coherent "fp32 for stability" change. (Surfaced by the NNX-defaults e2e matrix: gemma4-31b
09_pdb_lt_1/13_scan_layers_false, and deepseek3-671b30_indexer_sparse.) Two independent commits.1. gemma4 / gemma3 attention (
configs/models/gemma4-31b.yml,gemma4-26b.yml,gemma3-27b.yml).gemma4 uses
head_dim=256with noattn_logits_soft_cap, andfloat32_qk_product/float32_logitsdefault tofalse— so the qk product and softmax inputs run in bf16 and can overflow to inf/nan once the step-0 update perturbs the weights off a clean init. step 0 is finite (loss ~12.98), step 1 isnan→ "Aborting training due to NaN loss". gemma2-27b avoids this because it setsattn_logits_soft_cap: 50.0; Gemma3/4 dropped the softcap in favor of qk-norm.Fix: set
float32_qk_product: true+float32_logits: truein the gemma model ymls. This is semantically identical to the model — it just runs the qk product and softmax inputs in fp32 (the gates already exist inlayers/attention_op.py). Chosen over re-introducingattn_logits_soft_cap, which would tanh-compress the logits and change the attention distribution (not faithful to the gemma4 architecture). Applied to gemma4-31b, gemma4-26b, and gemma3-27b — they share the same softcap-less attention; 26b/27b are preventive (only 31b is currently in the failing matrix).2. DeepSeek sparse indexer (
layers/attention_mla.py).The DeepSeek-V3.2 sparse indexer computes its qk product at
matmul_precision(bf16) andrelus it in bf16, whileweights_projis already fp32 — so large bf16 logits can overflow to inf and propagate to a NaN loss. Fix: cast the indexer logits to fp32 before relu/aggregation, matching how the main attention runs its softmax in fp32. Only affects theuse_indexer=Truepath. This is the leading cause for the30_indexer_sparseNaN; the qk-overflow class matches the gemma case, but the final NaN-clearing on the 671B model is pending the e2e run (see Tests) — the commit is independent and can be dropped if it doesn't clear it.Tests
float32_qk_product/float32_logitsresolvetruefor gemma4-31b/26b and gemma3-27b. The NaN itself only reproduces at scale (31B / V6e-32) — re-rungemma4-31b 09_pdb_lt_1and13_scan_layers_falseand confirm step-1 loss is finite (wasnan).model_name=deepseek3-tiny use_indexer=True indexer_sparse_training=True indexer_topk=4 attention=dot_product megablox=False, sized so the indexer qk path actually runs) trains 2 steps with finite loss (12.339 → 12.295), no crash. Full NaN-clearing ondeepseek3-671b 30_indexer_sparseneeds V6e-32.bash lint.shclean.Stats
attention_mla.py).Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.