diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 8efd48a065..ba58600c07 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -973,15 +973,11 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step except elastic_utils.manager.ScaleUpSignalError as e: if config.elastic_enabled: max_logging.log(f"Elastic event detected, letting exception bubble up: {e}") - raise - else: - raise exceptions.StopTraining("Job is preempted.") from e + raise except jax.errors.JaxRuntimeError as e: if config.elastic_enabled: max_logging.log(f"Elastic event detected, letting exception bubble up: {e}") - raise - else: - raise exceptions.StopTraining("Job is preempted.") from e + raise except Exception as e: raise exceptions.StopTraining(f"Checkpointing failed. {str(e)}") from e @@ -991,7 +987,7 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Raise exception upon preemption if checkpoint_manager.reached_preemption(actual_step): - raise exceptions.StopTraining("Job is preempted.") + raise exceptions.StopTraining("Job received termination signal (SIGTERM).") def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=None, force=False):