Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/configs/models/gemma3-27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ logits_via_embedding: true
sliding_window_size: 1024
use_post_attn_norm: true
use_post_ffw_norm: true
# Run the qk product and softmax inputs in fp32 to avoid bf16 attention-logit
# overflow (neither gemma3 nor gemma4 uses attn_logits_soft_cap).
float32_qk_product: true
float32_logits: true
local_rope_max_timescale: 10_000
rope_max_timescale: 1_000_000
rope_linear_scaling_factor: 8.0
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/models/gemma4-26b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ global_rope_proportion: 0.25
local_rope_proportion: 1.0
v_norm_with_scale: false
final_logits_soft_cap: 30.0
# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no
# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan.
float32_qk_product: true
float32_logits: true

# MoE configuration
num_experts: 128
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/models/gemma4-31b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ use_post_ffw_norm: true
sliding_window_size: 1024
share_kv_projections: true
v_norm_with_scale: false
# Run the qk product and softmax inputs in fp32. gemma4 has head_dim=256 and no
# attn_logits_soft_cap, so bf16 attention logits can overflow to inf/nan.
float32_qk_product: true
float32_logits: true

# RoPE scaling
local_rope_max_timescale: 10000
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ def __call__(
# QK product: relu(q @ k.T), [b, t, s, h]
# Similar to MQA, each key is shared by all h query heads
logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision)
# Cast to fp32 before relu/aggregation: in bf16 the qk product can overflow to
# inf and propagate to a NaN loss (the main attention softmax is fp32 too).
logits = logits.astype(jnp.float32)
logits = jax.nn.relu(logits)
# Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
weights = self.weights_proj(inputs_q)
Expand Down
Loading