diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/config.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/config.py index e8c0eb4..6763a53 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/config.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import random from dataclasses import dataclass, field from enum import Enum, StrEnum @@ -589,5 +590,16 @@ def apply_jitter(self, delay: float) -> float: # Full jitter: random(0, delay) return random.random() * delay # noqa: S311 + def finalize_delay(self, base_delay: float) -> int: + """Apply jitter, round up, and clamp to a minimum of 1 second. + + Args: + base_delay: The base delay value before jitter is applied + + Returns: + The final delay in whole seconds, at least 1 + """ + return max(1, math.ceil(self.apply_jitter(base_delay))) + # endregion Jitter diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py index 8a8b3db..6d0723d 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py @@ -2,7 +2,6 @@ from __future__ import annotations -import math import re from dataclasses import dataclass, field from typing import TYPE_CHECKING, Generic, TypeVar @@ -71,43 +70,83 @@ def max_delay_seconds(self) -> int: return self.max_delay.to_seconds() +@dataclass +class LinearRetryStrategyConfig: + max_attempts: int = 6 + initial_delay: Duration = field(default_factory=lambda: Duration.from_seconds(1)) + increment: Duration = field(default_factory=lambda: Duration.from_seconds(1)) + max_delay: Duration = field(default_factory=lambda: Duration.from_minutes(5)) + jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL) + retryable_errors: list[str | re.Pattern] | None = None + retryable_error_types: list[type[Exception]] | None = None + + @property + def initial_delay_seconds(self) -> int: + """Get initial delay in seconds.""" + return self.initial_delay.to_seconds() + + @property + def increment_seconds(self) -> int: + """Get increment in seconds.""" + return self.increment.to_seconds() + + @property + def max_delay_seconds(self) -> int: + """Get max delay in seconds.""" + return self.max_delay.to_seconds() + + +def _resolve_retryable_errors( + retryable_errors: list[str | re.Pattern] | None, + retryable_error_types: list[type[Exception]] | None, +) -> tuple[list[str | re.Pattern], list[type[Exception]]]: + """Resolve the error filters, applying the match-all default only when neither is set.""" + should_use_default_errors: bool = ( + retryable_errors is None and retryable_error_types is None + ) + resolved_errors: list[str | re.Pattern] = ( + retryable_errors + if retryable_errors is not None + else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else []) + ) + resolved_error_types: list[type[Exception]] = retryable_error_types or [] + return resolved_errors, resolved_error_types + + +def _is_error_retryable( + error: Exception, + retryable_errors: list[str | re.Pattern], + retryable_error_types: list[type[Exception]], +) -> bool: + """Return True when the error matches one of the message patterns or types.""" + is_retryable_error_message: bool = any( + pattern.search(str(error)) + if isinstance(pattern, re.Pattern) + else pattern in str(error) + for pattern in retryable_errors + ) + is_retryable_error_type: bool = any( + isinstance(error, error_type) for error_type in retryable_error_types + ) + return is_retryable_error_message or is_retryable_error_type + + def create_retry_strategy( config: RetryStrategyConfig | None = None, ) -> Callable[[Exception, int], RetryDecision]: if config is None: config = RetryStrategyConfig() - # Apply default retryableErrors only if user didn't specify either filter - should_use_default_errors: bool = ( - config.retryable_errors is None and config.retryable_error_types is None - ) - - retryable_errors: list[str | re.Pattern] = ( - config.retryable_errors - if config.retryable_errors is not None - else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else []) + retryable_errors, retryable_error_types = _resolve_retryable_errors( + config.retryable_errors, config.retryable_error_types ) - retryable_error_types: list[type[Exception]] = config.retryable_error_types or [] def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: # Check if we've exceeded max attempts if attempts_made >= config.max_attempts: return RetryDecision.no_retry() - # Check if error is retryable based on error message - is_retryable_error_message: bool = any( - pattern.search(str(error)) - if isinstance(pattern, re.Pattern) - else pattern in str(error) - for pattern in retryable_errors - ) - - # Check if error is retryable based on error type - is_retryable_error_type: bool = any( - isinstance(error, error_type) for error_type in retryable_error_types - ) - - if not is_retryable_error_message and not is_retryable_error_type: + if not _is_error_retryable(error, retryable_errors, retryable_error_types): return RetryDecision.no_retry() # Calculate delay with exponential backoff @@ -115,10 +154,7 @@ def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)), config.max_delay_seconds, ) - # Apply jitter to get final delay - delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay) - # Round up and ensure minimum of 1 second - final_delay: int = max(1, math.ceil(delay_with_jitter)) + final_delay: int = config.jitter_strategy.finalize_delay(base_delay) return RetryDecision.retry(Duration(seconds=final_delay)) @@ -126,29 +162,37 @@ def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: def create_linear_retry_strategy( - max_attempts: int = 6, - initial_delay: Duration | None = None, - increment: Duration | None = None, + config: LinearRetryStrategyConfig | None = None, ) -> Callable[[Exception, int], RetryDecision]: - """Linearly increasing delay between retries: initial + increment * (attempts_made - 1). + """Linearly increasing delay between retries. - Mirrors the JS SDK's ``createLinearRetryStrategy``. With the defaults this - yields delays of 1s, 2s, 3s, 4s, 5s. No jitter is applied and there is no - upper cap on the delay; callers who need either can build their own - strategy via ``create_retry_strategy``. + The base delay is ``initial_delay + increment * (attempts_made - 1)``, + capped at ``max_delay``, with jitter and error filtering applied the same + way as :func:`create_retry_strategy`. Mirrors the JS SDK's + ``createLinearRetryStrategy``. """ - initial: Duration = ( - initial_delay if initial_delay is not None else Duration.from_seconds(1) + if config is None: + config = LinearRetryStrategyConfig() + + retryable_errors, retryable_error_types = _resolve_retryable_errors( + config.retryable_errors, config.retryable_error_types ) - step: Duration = increment if increment is not None else Duration.from_seconds(1) - def linear_retry_strategy(_error: Exception, attempts_made: int) -> RetryDecision: - if attempts_made >= max_attempts: + def linear_retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: + if attempts_made >= config.max_attempts: + return RetryDecision.no_retry() + + if not _is_error_retryable(error, retryable_errors, retryable_error_types): return RetryDecision.no_retry() - delay_seconds: int = initial.to_seconds() + step.to_seconds() * ( - attempts_made - 1 + + base_delay: float = min( + config.initial_delay_seconds + + config.increment_seconds * (attempts_made - 1), + config.max_delay_seconds, ) - return RetryDecision.retry(Duration(seconds=delay_seconds)) + final_delay: int = config.jitter_strategy.finalize_delay(base_delay) + + return RetryDecision.retry(Duration(seconds=final_delay)) return linear_retry_strategy @@ -212,9 +256,12 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]: def linear(cls) -> Callable[[Exception, int], RetryDecision]: """Linearly increasing delay between retries: 1s, 2s, 3s, 4s, 5s.""" return create_linear_retry_strategy( - max_attempts=6, - initial_delay=Duration.from_seconds(1), - increment=Duration.from_seconds(1), + LinearRetryStrategyConfig( + max_attempts=6, + initial_delay=Duration.from_seconds(1), + increment=Duration.from_seconds(1), + jitter_strategy=JitterStrategy.NONE, + ) ) @classmethod diff --git a/packages/aws-durable-execution-sdk-python/tests/retries_test.py b/packages/aws-durable-execution-sdk-python/tests/retries_test.py index 30b3115..b82d3ed 100644 --- a/packages/aws-durable-execution-sdk-python/tests/retries_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/retries_test.py @@ -8,6 +8,7 @@ from aws_durable_execution_sdk_python.config import Duration from aws_durable_execution_sdk_python.retries import ( JitterStrategy, + LinearRetryStrategyConfig, RetryDecision, RetryPresets, RetryStrategyConfig, @@ -580,8 +581,10 @@ def test_mixed_error_types_and_patterns(): # region create_linear_retry_strategy -def test_linear_retry_strategy_uses_additive_formula(): - """Default config yields delays of 1s, 2s, 3s, 4s, 5s with no jitter.""" +@patch("random.random") +def test_linear_retry_strategy_uses_additive_formula(mock_random): + """Default config yields additive delays of 1s, 2s, 3s, 4s, 5s.""" + mock_random.return_value = 1.0 # FULL jitter at the upper bound keeps the base strategy = create_linear_retry_strategy() delays = [ @@ -593,7 +596,7 @@ def test_linear_retry_strategy_uses_additive_formula(): def test_linear_retry_strategy_stops_at_max_attempts(): """No retry once attempts_made reaches max_attempts.""" - strategy = create_linear_retry_strategy(max_attempts=3) + strategy = create_linear_retry_strategy(LinearRetryStrategyConfig(max_attempts=3)) assert strategy(Exception("e"), 1).should_retry is True assert strategy(Exception("e"), 2).should_retry is True @@ -603,9 +606,12 @@ def test_linear_retry_strategy_stops_at_max_attempts(): def test_linear_retry_strategy_respects_custom_initial_and_increment(): """Custom initial_delay and increment shift the additive sequence.""" strategy = create_linear_retry_strategy( - max_attempts=10, - initial_delay=Duration.from_seconds(2), - increment=Duration.from_seconds(3), + LinearRetryStrategyConfig( + max_attempts=10, + initial_delay=Duration.from_seconds(2), + increment=Duration.from_seconds(3), + jitter_strategy=JitterStrategy.NONE, + ) ) delays = [ @@ -616,6 +622,61 @@ def test_linear_retry_strategy_respects_custom_initial_and_increment(): assert delays == [2, 5, 8, 11] +def test_linear_retry_strategy_caps_at_max_delay(): + """The additive delay is capped at max_delay before jitter.""" + strategy = create_linear_retry_strategy( + LinearRetryStrategyConfig( + max_attempts=10, + initial_delay=Duration.from_seconds(10), + increment=Duration.from_seconds(10), + max_delay=Duration.from_seconds(25), + jitter_strategy=JitterStrategy.NONE, + ) + ) + + # 10, 20, then capped at 25 for the third attempt (would be 30). + delays = [ + strategy(Exception("e"), attempt).delay_seconds for attempt in range(1, 4) + ] + assert delays == [10, 20, 25] + + +@patch("random.random") +def test_linear_retry_strategy_applies_jitter(mock_random): + """FULL jitter scales the additive base delay by random().""" + mock_random.return_value = 0.5 + strategy = create_linear_retry_strategy( + LinearRetryStrategyConfig( + initial_delay=Duration.from_seconds(4), + increment=Duration.from_seconds(4), + jitter_strategy=JitterStrategy.FULL, + ) + ) + + # base = 4 + 4*1 = 8, full jitter = 0.5 * 8 = 4 + assert strategy(Exception("e"), 2).delay_seconds == 4 + + +def test_linear_retry_strategy_filters_by_error_message(): + """Only errors matching retryable_errors are retried.""" + strategy = create_linear_retry_strategy( + LinearRetryStrategyConfig(retryable_errors=["timeout"]) + ) + + assert strategy(Exception("connection timeout"), 1).should_retry is True + assert strategy(Exception("permission denied"), 1).should_retry is False + + +def test_linear_retry_strategy_filters_by_error_type(): + """Only errors matching retryable_error_types are retried.""" + strategy = create_linear_retry_strategy( + LinearRetryStrategyConfig(retryable_error_types=[ValueError]) + ) + + assert strategy(ValueError("bad"), 1).should_retry is True + assert strategy(KeyError("missing"), 1).should_retry is False + + # endregion