Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import math
import random
from dataclasses import dataclass, field
from enum import Enum, StrEnum
Expand Down Expand Up @@ -589,5 +590,16 @@ def apply_jitter(self, delay: float) -> float:
# Full jitter: random(0, delay)
return random.random() * delay # noqa: S311

def finalize_delay(self, base_delay: float) -> int:
    """Turn a base delay into the whole-second delay that is actually used.

    The base delay is passed through ``apply_jitter``, the result is rounded
    up to the next whole number, and anything below 1 is raised to 1.

    Args:
        base_delay: The delay value before jitter is applied

    Returns:
        The jittered delay in whole seconds, never less than 1
    """
    jittered: float = self.apply_jitter(base_delay)
    whole_seconds: int = math.ceil(jittered)
    # Enforce the 1-second floor.
    return whole_seconds if whole_seconds > 1 else 1


# endregion Jitter
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import math
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, TypeVar
Expand Down Expand Up @@ -71,84 +70,129 @@ def max_delay_seconds(self) -> int:
return self.max_delay.to_seconds()


@dataclass
class LinearRetryStrategyConfig:
    """Settings for a retry strategy whose delay grows by a fixed increment."""

    # Attempt count at which the strategy stops retrying.
    max_attempts: int = 6
    # Delay used for the first retry.
    initial_delay: Duration = field(default_factory=lambda: Duration.from_seconds(1))
    # Amount added to the delay for each further attempt.
    increment: Duration = field(default_factory=lambda: Duration.from_seconds(1))
    # Upper bound applied to the delay before jitter.
    max_delay: Duration = field(default_factory=lambda: Duration.from_minutes(5))
    # Jitter applied to the capped delay.
    jitter_strategy: JitterStrategy = JitterStrategy.FULL
    # Message filters (substrings or compiled patterns); None means "not set".
    retryable_errors: list[str | re.Pattern] | None = None
    # Exception classes that are retryable; None means "not set".
    retryable_error_types: list[type[Exception]] | None = None

    @property
    def initial_delay_seconds(self) -> int:
        """Return ``initial_delay`` converted to seconds."""
        return self.initial_delay.to_seconds()

    @property
    def increment_seconds(self) -> int:
        """Return ``increment`` converted to seconds."""
        return self.increment.to_seconds()

    @property
    def max_delay_seconds(self) -> int:
        """Return ``max_delay`` converted to seconds."""
        return self.max_delay.to_seconds()


def _resolve_retryable_errors(
retryable_errors: list[str | re.Pattern] | None,
retryable_error_types: list[type[Exception]] | None,
) -> tuple[list[str | re.Pattern], list[type[Exception]]]:
"""Resolve the error filters, applying the match-all default only when neither is set."""
should_use_default_errors: bool = (
retryable_errors is None and retryable_error_types is None
)
resolved_errors: list[str | re.Pattern] = (
retryable_errors
if retryable_errors is not None
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
)
resolved_error_types: list[type[Exception]] = retryable_error_types or []
return resolved_errors, resolved_error_types


def _is_error_retryable(
error: Exception,
retryable_errors: list[str | re.Pattern],
retryable_error_types: list[type[Exception]],
) -> bool:
"""Return True when the error matches one of the message patterns or types."""
is_retryable_error_message: bool = any(
pattern.search(str(error))
if isinstance(pattern, re.Pattern)
else pattern in str(error)
for pattern in retryable_errors
)
is_retryable_error_type: bool = any(
isinstance(error, error_type) for error_type in retryable_error_types
)
return is_retryable_error_message or is_retryable_error_type


def create_retry_strategy(
config: RetryStrategyConfig | None = None,
) -> Callable[[Exception, int], RetryDecision]:
if config is None:
config = RetryStrategyConfig()

# Apply default retryableErrors only if user didn't specify either filter
should_use_default_errors: bool = (
config.retryable_errors is None and config.retryable_error_types is None
)

retryable_errors: list[str | re.Pattern] = (
config.retryable_errors
if config.retryable_errors is not None
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
retryable_errors, retryable_error_types = _resolve_retryable_errors(
config.retryable_errors, config.retryable_error_types
)
retryable_error_types: list[type[Exception]] = config.retryable_error_types or []

def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
# Check if we've exceeded max attempts
if attempts_made >= config.max_attempts:
return RetryDecision.no_retry()

# Check if error is retryable based on error message
is_retryable_error_message: bool = any(
pattern.search(str(error))
if isinstance(pattern, re.Pattern)
else pattern in str(error)
for pattern in retryable_errors
)

# Check if error is retryable based on error type
is_retryable_error_type: bool = any(
isinstance(error, error_type) for error_type in retryable_error_types
)

if not is_retryable_error_message and not is_retryable_error_type:
if not _is_error_retryable(error, retryable_errors, retryable_error_types):
return RetryDecision.no_retry()

# Calculate delay with exponential backoff
base_delay: float = min(
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
config.max_delay_seconds,
)
# Apply jitter to get final delay
delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay)
# Round up and ensure minimum of 1 second
final_delay: int = max(1, math.ceil(delay_with_jitter))
final_delay: int = config.jitter_strategy.finalize_delay(base_delay)

return RetryDecision.retry(Duration(seconds=final_delay))

return retry_strategy


def create_linear_retry_strategy(
max_attempts: int = 6,
initial_delay: Duration | None = None,
increment: Duration | None = None,
config: LinearRetryStrategyConfig | None = None,
) -> Callable[[Exception, int], RetryDecision]:
"""Linearly increasing delay between retries: initial + increment * (attempts_made - 1).
"""Linearly increasing delay between retries.

Mirrors the JS SDK's ``createLinearRetryStrategy``. With the defaults this
yields delays of 1s, 2s, 3s, 4s, 5s. No jitter is applied and there is no
upper cap on the delay; callers who need either can build their own
strategy via ``create_retry_strategy``.
The base delay is ``initial_delay + increment * (attempts_made - 1)``,
capped at ``max_delay``, with jitter and error filtering applied the same
way as :func:`create_retry_strategy`. Mirrors the JS SDK's
``createLinearRetryStrategy``.
"""
initial: Duration = (
initial_delay if initial_delay is not None else Duration.from_seconds(1)
if config is None:
config = LinearRetryStrategyConfig()

retryable_errors, retryable_error_types = _resolve_retryable_errors(
config.retryable_errors, config.retryable_error_types
)
step: Duration = increment if increment is not None else Duration.from_seconds(1)

def linear_retry_strategy(_error: Exception, attempts_made: int) -> RetryDecision:
if attempts_made >= max_attempts:
def linear_retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
if attempts_made >= config.max_attempts:
return RetryDecision.no_retry()

if not _is_error_retryable(error, retryable_errors, retryable_error_types):
return RetryDecision.no_retry()
delay_seconds: int = initial.to_seconds() + step.to_seconds() * (
attempts_made - 1

base_delay: float = min(
config.initial_delay_seconds
+ config.increment_seconds * (attempts_made - 1),
config.max_delay_seconds,
)
return RetryDecision.retry(Duration(seconds=delay_seconds))
final_delay: int = config.jitter_strategy.finalize_delay(base_delay)

return RetryDecision.retry(Duration(seconds=final_delay))

return linear_retry_strategy

Expand Down Expand Up @@ -212,9 +256,12 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]:
def linear(cls) -> Callable[[Exception, int], RetryDecision]:
"""Linearly increasing delay between retries: 1s, 2s, 3s, 4s, 5s."""
return create_linear_retry_strategy(
max_attempts=6,
initial_delay=Duration.from_seconds(1),
increment=Duration.from_seconds(1),
LinearRetryStrategyConfig(
max_attempts=6,
initial_delay=Duration.from_seconds(1),
increment=Duration.from_seconds(1),
jitter_strategy=JitterStrategy.NONE,
)
)

@classmethod
Expand Down
73 changes: 67 additions & 6 deletions packages/aws-durable-execution-sdk-python/tests/retries_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aws_durable_execution_sdk_python.config import Duration
from aws_durable_execution_sdk_python.retries import (
JitterStrategy,
LinearRetryStrategyConfig,
RetryDecision,
RetryPresets,
RetryStrategyConfig,
Expand Down Expand Up @@ -580,8 +581,10 @@ def test_mixed_error_types_and_patterns():
# region create_linear_retry_strategy


def test_linear_retry_strategy_uses_additive_formula():
"""Default config yields delays of 1s, 2s, 3s, 4s, 5s with no jitter."""
@patch("random.random")
def test_linear_retry_strategy_uses_additive_formula(mock_random):
"""Default config yields additive delays of 1s, 2s, 3s, 4s, 5s."""
mock_random.return_value = 1.0 # FULL jitter at the upper bound keeps the base
strategy = create_linear_retry_strategy()

delays = [
Expand All @@ -593,7 +596,7 @@ def test_linear_retry_strategy_uses_additive_formula():

def test_linear_retry_strategy_stops_at_max_attempts():
"""No retry once attempts_made reaches max_attempts."""
strategy = create_linear_retry_strategy(max_attempts=3)
strategy = create_linear_retry_strategy(LinearRetryStrategyConfig(max_attempts=3))

assert strategy(Exception("e"), 1).should_retry is True
assert strategy(Exception("e"), 2).should_retry is True
Expand All @@ -603,9 +606,12 @@ def test_linear_retry_strategy_stops_at_max_attempts():
def test_linear_retry_strategy_respects_custom_initial_and_increment():
"""Custom initial_delay and increment shift the additive sequence."""
strategy = create_linear_retry_strategy(
max_attempts=10,
initial_delay=Duration.from_seconds(2),
increment=Duration.from_seconds(3),
LinearRetryStrategyConfig(
max_attempts=10,
initial_delay=Duration.from_seconds(2),
increment=Duration.from_seconds(3),
jitter_strategy=JitterStrategy.NONE,
)
)

delays = [
Expand All @@ -616,6 +622,61 @@ def test_linear_retry_strategy_respects_custom_initial_and_increment():
assert delays == [2, 5, 8, 11]


def test_linear_retry_strategy_caps_at_max_delay():
    """max_delay bounds the additive delay before jitter is applied."""
    config = LinearRetryStrategyConfig(
        max_attempts=10,
        initial_delay=Duration.from_seconds(10),
        increment=Duration.from_seconds(10),
        max_delay=Duration.from_seconds(25),
        jitter_strategy=JitterStrategy.NONE,
    )
    strategy = create_linear_retry_strategy(config)

    # Attempts 1 and 2 give 10 and 20; attempt 3 would be 30 but is capped at 25.
    observed = [
        strategy(Exception("e"), attempt).delay_seconds for attempt in (1, 2, 3)
    ]
    assert observed == [10, 20, 25]


@patch("random.random")
def test_linear_retry_strategy_applies_jitter(mock_random):
    """With FULL jitter the additive base delay is multiplied by random()."""
    mock_random.return_value = 0.5
    config = LinearRetryStrategyConfig(
        initial_delay=Duration.from_seconds(4),
        increment=Duration.from_seconds(4),
        jitter_strategy=JitterStrategy.FULL,
    )
    strategy = create_linear_retry_strategy(config)

    # Second attempt: base = 4 + 4 * 1 = 8; random() = 0.5 scales it to 4.
    decision = strategy(Exception("e"), 2)
    assert decision.delay_seconds == 4


def test_linear_retry_strategy_filters_by_error_message():
    """An error is retried only when its message matches retryable_errors."""
    config = LinearRetryStrategyConfig(retryable_errors=["timeout"])
    strategy = create_linear_retry_strategy(config)

    matching = strategy(Exception("connection timeout"), 1)
    non_matching = strategy(Exception("permission denied"), 1)

    assert matching.should_retry is True
    assert non_matching.should_retry is False


def test_linear_retry_strategy_filters_by_error_type():
    """An error is retried only when its type is in retryable_error_types."""
    config = LinearRetryStrategyConfig(retryable_error_types=[ValueError])
    strategy = create_linear_retry_strategy(config)

    listed_type = strategy(ValueError("bad"), 1)
    unlisted_type = strategy(KeyError("missing"), 1)

    assert listed_type.should_retry is True
    assert unlisted_type.should_retry is False


# endregion


Expand Down
Loading