class ElasticRetryLimit:
    """A retry callback that limits the number of attempts.

    Instances are callables of the form ``(attempt, error) -> bool`` and
    return True while another attempt is still allowed. When ``attempt`` is
    the 1-indexed number of the attempt that just failed, a limit of
    ``max_attempts`` permits at most ``max_attempts`` attempts in total:
    the initial attempt plus ``max_attempts - 1`` retries.

    Attributes:
        max_attempts: The positive attempt limit, stored as a plain ``int``.
    """

    def __init__(self, max_attempts: int):
        """Creates a policy that allows at most `max_attempts` attempts.

        Args:
            max_attempts: The attempt limit. Must be an integer (or an
                integer-like object implementing ``__index__``) greater
                than zero.

        Raises:
            TypeError: If `max_attempts` is not integer-like, e.g. a float,
                a string or None.
            ValueError: If `max_attempts` is zero or negative.
        """
        # Function-scope import keeps this block self-contained.
        import operator

        # Fail fast on non-integers. Without this, a float such as 2.5 passes
        # the positivity check and silently behaves like a limit of 3, while
        # None or "3" fail with an unrelated comparison error.
        max_attempts = operator.index(max_attempts)
        if max_attempts <= 0:
            raise ValueError("max_attempts must be positive.")
        self.max_attempts = max_attempts

    def __call__(self, attempt: int, error: Exception) -> bool:
        """Returns whether another attempt may follow attempt `attempt`.

        Args:
            attempt: The number of the attempt that just failed.
            error: The exception raised by that attempt. Ignored: the
                decision depends only on the attempt count.

        Returns:
            True if `attempt` is strictly below the configured limit.
        """
        del error  # Unused
        return attempt < self.max_attempts

    def __repr__(self) -> str:
        """Returns a constructor-style representation for log output."""
        return f"{type(self).__name__}(max_attempts={self.max_attempts!r})"
@@ -248,17 +267,23 @@ def elastic_retry( else minimum_slice_count ) - if max_retries <= 0: - raise ValueError("max_retries must be positive.") + if max_retries is not None and retry_policy is not None: + raise ValueError("Cannot specify both max_retries and retry_policy.") + + if retry_policy is None: + if max_retries is None: + retry_policy = lambda attempt, error: True + else: + if max_retries <= 0: + raise ValueError("max_retries must be positive.") + retry_policy = ElasticRetryLimit(max_retries) def decorator(func: _F) -> _F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - def attempt_execution(retry_index: int) -> Any: - _logger.info( - "Elastic attempt %d out of %d", retry_index + 1, max_retries - ) + def attempt_execution(attempt: int) -> Any: + _logger.info("Elastic attempt %d", attempt) self.active_slice_indices = elastic.wait_for_slices( slice_count=target_slice_count, slice_to_devices=self.slice_to_devices, @@ -289,34 +314,49 @@ def attempt_execution(retry_index: int) -> Any: if monitor_thread is not None: monitor_thread.join() - for retry_index in range(max_retries): + attempt = 1 + while True: try: - return attempt_execution(retry_index) - except ScaleUpSignalError: - _logger.info("Scale up requested. Retrying.") + return attempt_execution(attempt) + except ScaleUpSignalError as error: + _logger.info("Scale up requested.") _elastic_event_cleanup() if on_elastic_event_callback is not None: on_elastic_event_callback() + + if not retry_policy(attempt, error): + _logger.info( + "Retry policy rejected retry after ScaleUpSignalError." + ) + raise ElasticRuntimeError( + f"Elastic attempt {attempt} failed." + ) from error + + _logger.info("Retrying.") except jax.errors.JaxRuntimeError as error: if not elastic.is_error_due_to_slice_down(error): raise if self.new_slice_event.is_set(): - _logger.info( - "Slice down event and new slice available detected. Retrying." 
- ) + _logger.info("Slice down event and new slice available detected.") else: - _logger.info("Slice down event detected. Retrying.") + _logger.info("Slice down event detected.") _elastic_event_cleanup() if on_elastic_event_callback is not None: on_elastic_event_callback() - else: - raise ElasticRuntimeError( - f"Elastic attempt {max_retries} out of {max_retries} failed." - ) + + if not retry_policy(attempt, error): + _logger.info("Retry policy rejected retry after JaxRuntimeError.") + raise ElasticRuntimeError( + f"Elastic attempt {attempt} failed." + ) from error + + _logger.info("Retrying.") + + attempt += 1 return wrapper