diff --git a/src/maxtext/configs/models/gemma3-27b.yml b/src/maxtext/configs/models/gemma3-27b.yml index 5d3b70a3a9..587f0283eb 100644 --- a/src/maxtext/configs/models/gemma3-27b.yml +++ b/src/maxtext/configs/models/gemma3-27b.yml @@ -28,6 +28,10 @@ logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true +# Run the qk product and softmax inputs in fp32 to avoid bf16 attention-logit +# overflow (no attn_logits_soft_cap on gemma3/4). +float32_qk_product: true +float32_logits: true local_rope_max_timescale: 10_000 rope_max_timescale: 1_000_000 rope_linear_scaling_factor: 8.0 diff --git a/src/maxtext/configs/models/gemma4-26b.yml b/src/maxtext/configs/models/gemma4-26b.yml index d32149bccb..c92c58c578 100644 --- a/src/maxtext/configs/models/gemma4-26b.yml +++ b/src/maxtext/configs/models/gemma4-26b.yml @@ -39,6 +39,10 @@ global_rope_proportion: 0.25 local_rope_proportion: 1.0 v_norm_with_scale: false final_logits_soft_cap: 30.0 +# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no +# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan. +float32_qk_product: true +float32_logits: true # MoE configuration num_experts: 128 diff --git a/src/maxtext/configs/models/gemma4-31b.yml b/src/maxtext/configs/models/gemma4-31b.yml index 9dec302bc1..323956fae7 100644 --- a/src/maxtext/configs/models/gemma4-31b.yml +++ b/src/maxtext/configs/models/gemma4-31b.yml @@ -34,6 +34,10 @@ use_post_ffw_norm: true sliding_window_size: 1024 share_kv_projections: true v_norm_with_scale: false +# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no +# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan. 
+float32_qk_product: true +float32_logits: true # RoPE scaling local_rope_max_timescale: 10000 diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index df7fd16ea2..cc43fd190a 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -346,6 +346,9 @@ def __call__( # QK product: relu(q @ k.T), [b, t, s, h] # Similar to MQA, each key is shared by h query head logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision) + # Cast to fp32 before relu/aggregation: in bf16 the qk product can overflow to + # inf and propagate to a NaN loss (the main attention softmax is fp32 too). NOTE(review): this cast runs after the einsum, so it cannot undo an overflow produced inside the einsum itself; confirm whether the einsum also needs preferred_element_type=jnp.float32. + logits = logits.astype(jnp.float32) logits = jax.nn.relu(logits) # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h] weights = self.weights_proj(inputs_q)