From 3ba4b07ecf2c0185cb1e689d013dfe4157f2f239 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jun 2026 18:25:30 +0000 Subject: [PATCH 1/2] gemma: run attention qk-product and softmax in fp32 to fix step-1 NaN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gemma4-31b/26b use head_dim=256 with no attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan at step 1 (fails on both Linen and NNX). Setting float32_qk_product and float32_logits runs the qk product and softmax inputs in fp32 — semantically identical, just stable. Applied to gemma4-31b, gemma4-26b, and gemma3-27b (same softcap-less attention). --- src/maxtext/configs/models/gemma3-27b.yml | 4 ++++ src/maxtext/configs/models/gemma4-26b.yml | 4 ++++ src/maxtext/configs/models/gemma4-31b.yml | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/src/maxtext/configs/models/gemma3-27b.yml b/src/maxtext/configs/models/gemma3-27b.yml index 5d3b70a3a9..587f0283eb 100644 --- a/src/maxtext/configs/models/gemma3-27b.yml +++ b/src/maxtext/configs/models/gemma3-27b.yml @@ -28,6 +28,10 @@ logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true +# Run the qk product and softmax inputs in fp32 to avoid bf16 attention-logit +# overflow (no attn_logits_soft_cap on gemma3/4). +float32_qk_product: true +float32_logits: true local_rope_max_timescale: 10_000 rope_max_timescale: 1_000_000 rope_linear_scaling_factor: 8.0 diff --git a/src/maxtext/configs/models/gemma4-26b.yml b/src/maxtext/configs/models/gemma4-26b.yml index d32149bccb..c92c58c578 100644 --- a/src/maxtext/configs/models/gemma4-26b.yml +++ b/src/maxtext/configs/models/gemma4-26b.yml @@ -39,6 +39,10 @@ global_rope_proportion: 0.25 local_rope_proportion: 1.0 v_norm_with_scale: false final_logits_soft_cap: 30.0 +# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no +# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan. +float32_qk_product: true +float32_logits: true # MoE configuration num_experts: 128 diff --git a/src/maxtext/configs/models/gemma4-31b.yml b/src/maxtext/configs/models/gemma4-31b.yml index 9dec302bc1..323956fae7 100644 --- a/src/maxtext/configs/models/gemma4-31b.yml +++ b/src/maxtext/configs/models/gemma4-31b.yml @@ -34,6 +34,10 @@ use_post_ffw_norm: true sliding_window_size: 1024 share_kv_projections: true v_norm_with_scale: false +# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no +# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan. +float32_qk_product: true +float32_logits: true # RoPE scaling local_rope_max_timescale: 10000 From ff174b75c8f38c76b3a2e616e419bc330b2675f6 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jun 2026 18:25:31 +0000 Subject: [PATCH 2/2] DeepSeek: cast sparse-indexer logits to fp32 to avoid bf16 NaN The indexer qk product is computed at matmul_precision (bf16) and relu'd in bf16 while weights_proj is fp32; large bf16 logits can overflow to inf and propagate to a NaN loss. Cast the indexer logits to fp32 before relu/aggregation, matching how the main attention runs softmax in fp32. --- src/maxtext/layers/attention_mla.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index df7fd16ea2..cc43fd190a 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -346,6 +346,9 @@ def __call__( # QK product: relu(q @ k.T), [b, t, s, h] # Similar to MQA, each key is shared by h query head logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision) + # Cast to fp32 before relu/aggregation: in bf16 the qk product can overflow to + # inf and propagate to a NaN loss (the main attention softmax is fp32 too). + logits = logits.astype(jnp.float32) logits = jax.nn.relu(logits) # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h] weights = self.weights_proj(inputs_q)