From 5913cc55fe9b1d7f6a822beb464f861852d94eba Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 9 Jun 2026 22:22:48 +0000 Subject: [PATCH] Fix misleading preemption logs in checkpointing - Raise JaxRuntimeError directly instead of masking it as StopTraining('Job is preempted.') when elasticity is disabled. This prevents hiding the true cause of crashes in logs and allows proper crash handling. - Reword StopTraining message when reached_preemption is true to 'Job received termination signal (SIGTERM).' as SIGTERM is sent for various termination events (scaling, updates), not just preemption. Fixes: b/516962538 --- src/maxtext/common/checkpointing.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 8efd48a065..ba58600c07 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -973,15 +973,11 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step except elastic_utils.manager.ScaleUpSignalError as e: if config.elastic_enabled: max_logging.log(f"Elastic event detected, letting exception bubble up: {e}") - raise - else: - raise exceptions.StopTraining("Job is preempted.") from e + raise except jax.errors.JaxRuntimeError as e: if config.elastic_enabled: max_logging.log(f"Elastic event detected, letting exception bubble up: {e}") - raise - else: - raise exceptions.StopTraining("Job is preempted.") from e + raise except Exception as e: raise exceptions.StopTraining(f"Checkpointing failed. {str(e)}") from e @@ -991,7 +987,7 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Raise exception upon preemption if checkpoint_manager.reached_preemption(actual_step): - raise exceptions.StopTraining("Job is preempted.") + raise exceptions.StopTraining("Job received termination signal (SIGTERM).") def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=None, force=False):