diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py index 61bdbb0d..0cdf40e8 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -59,7 +59,7 @@ class TimerScheduler: """Manage timed suspend tasks with a background timer thread.""" def __init__( - self, resubmit_callback: Callable[[ExecutableWithState], None] + self, resubmit_callback: Callable[[list[ExecutableWithState]], None] ) -> None: self.resubmit_callback = resubmit_callback self._pending_resumes: list[tuple[float, int, ExecutableWithState]] = [] @@ -114,18 +114,31 @@ def _timer_loop(self) -> None: current_time = time.time() if current_time >= next_resume_time: - # Time to resume + # Drain every due resume under the lock, transitioning each to + # PENDING atomically with the pop. Keeping pop+reset_to_pending + # together is required: should_execution_suspend reads branch + # status without this lock, so an item that is removed from the + # heap but still SUSPENDED_WITH_TIMEOUT could trigger a spurious + # parent suspend. + ready: list[ExecutableWithState] = [] with self._lock: - # no branch cover because hard to test reliably - this is a double-safety check if heap mutated - # since the first peek on next_resume_time further up - if ( # pragma: no branch + while ( self._pending_resumes and self._pending_resumes[0][0] <= current_time ): _, _, exe_state = heapq.heappop(self._pending_resumes) if exe_state.can_resume: exe_state.reset_to_pending() - self.resubmit_callback(exe_state) + ready.append(exe_state) + # Resubmit outside the lock. Only the heap pop and the PENDING + # transition need the lock. The checkpoint refresh is a blocking + # network call and the submit hands work to the pool, so running + # them off the lock keeps timed resumes from serializing behind + # the network round trip and keeps the timer thread from + # re-entering this non-reentrant lock when a submitted future + # completes inline and its done-callback calls schedule_resume. + if ready: + self.resubmit_callback(ready) else: # Wait until next resume time wait_time = min(next_resume_time - current_time, 0.1) @@ -169,6 +182,7 @@ def __init__( # Event-driven state tracking for when the executor is done self._completion_event = threading.Event() self._suspend_exception: SuspendExecution | None = None + self._resume_error: Exception | None = None # ExecutionCounters will keep track of completion criteria and on-going counters min_successful = self.completion_config.min_successful or len(self.executables) @@ -222,11 +236,32 @@ def execute( ] self._completion_event.clear() self._suspend_exception = None - - def resubmitter(executable_with_state: ExecutableWithState) -> None: - """Resubmit a timed suspended task.""" - execution_state.create_checkpoint() - submit_task(executable_with_state) + self._resume_error = None + + def resubmitter(ready: list[ExecutableWithState]) -> None: + """Resubmit a wave of timed-suspended tasks. + + One checkpoint refresh serves the whole due wave: the fetch returns + all operations, so every resumed branch reads fresh state. The + refresh only raises when the background checkpoint subsystem has + failed, which is terminal for the whole execution, so record the + error and wake the parent to re-raise it. Catching here keeps the + single timer thread alive so a failure does not strand the other + pending resumes. + """ + try: + execution_state.create_checkpoint() + except Exception as exc: # noqa: BLE001 + # resubmitter runs only on the single timer thread, so this + # check-then-set needs no lock. First error wins: keep the + # earliest failure if several waves fail before execute() reads + # it (they are the same terminal checkpoint failure anyway). + if self._resume_error is None: # pragma: no branch + self._resume_error = exc + self._completion_event.set() + return + for executable_with_state in ready: + submit_task(executable_with_state) thread_executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -259,6 +294,12 @@ def on_done(future: Future) -> None: for future in futures: future.cancel() + # A timed resume failed to refresh state (terminal checkpoint + # subsystem failure). Re-raise so the invocation fails and the + # backend retries from the last durable checkpoint. + if self._resume_error is not None: + raise self._resume_error + # Suspend execution if everything done and at least one of the tasks raised a suspend exception. if self._suspend_exception: raise self._suspend_exception diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 7fcfadcc..c24bd96d 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -250,7 +250,7 @@ def __init__( ): self.durable_execution_arn: str = durable_execution_arn self._current_checkpoint_token: str = initial_checkpoint_token - self.operations: MutableMapping[str, Operation] = operations + self._operations: dict[str, Operation] = dict(operations) self._service_client: DurableServiceClient = service_client self._plugin_executor: PluginExecutor = plugin_executor self._ordered_checkpoint_lock: OrderedLock = OrderedLock() @@ -279,6 +279,16 @@ def __init__( self._replay_status_lock: Lock = Lock() self._visited_operations: set[str] = set() + @property + def operations(self) -> dict[str, Operation]: + """Return a point-in-time snapshot copy of the operations map. + + The returned dict is a copy, so mutating it does not affect execution + state and iterating it is safe against concurrent updates. + """ + with self._operations_lock: + return dict(self._operations) + def fetch_paginated_operations( self, initial_operations: list[Operation], @@ -324,7 +334,7 @@ def fetch_paginated_operations( # Always store whatever operations we successfully fetched if all_operations: with self._operations_lock: - self.operations.update( + self._operations.update( {op.operation_id: op for op in all_operations} ) return all_operations @@ -341,7 +351,8 @@ def get_input_payload(self) -> str | None: def get_execution_operation(self) -> Operation | None: # invocation id is id of execution operation invocation_id = self.durable_execution_arn.split("/")[-1] - candidate = self.operations.get(invocation_id) + with self._operations_lock: + candidate = self._operations.get(invocation_id) if not candidate: # Due to payload size limitations we may have an empty operations list. # This will only happen when loading the initial page of results and is @@ -370,19 +381,21 @@ def track_replay(self, operation_id: str) -> None: with self._replay_status_lock: if self._replay_status == ReplayStatus.REPLAY: self._visited_operations.add(operation_id) - completed_ops = { - op_id - for op_id, op in self.operations.items() - if op.operation_type != OperationType.EXECUTION - and op.status - in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - OperationStatus.TIMED_OUT, + # Lock order: _replay_status_lock then _operations_lock. + with self._operations_lock: + completed_ops = { + op_id + for op_id, op in self._operations.items() + if op.operation_type != OperationType.EXECUTION + and op.status + in { + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + OperationStatus.TIMED_OUT, + } } - } if completed_ops.issubset(self._visited_operations): logger.debug( "Transitioning from REPLAY to NEW status at operation %s", @@ -404,7 +417,7 @@ def mark_replaying_if_prior_operations_exist(self) -> None: with self._operations_lock: has_prior_operations: bool = any( op.operation_type is not OperationType.EXECUTION - for op in self.operations.values() + for op in self._operations.values() ) if has_prior_operations: @@ -431,7 +444,7 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult: """ # checking status are deliberately under a lighter non-serialized lock with self._operations_lock: - if checkpoint := self.operations.get(checkpoint_id): + if checkpoint := self._operations.get(checkpoint_id): return CheckpointedResult.create_from_operation(checkpoint) return CHECKPOINT_NOT_FOUND diff --git a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py index ef7d7f57..1bfa6318 100644 --- a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py @@ -1370,6 +1370,74 @@ def execute_item(self, child_context, executable): executor.execute(execution_state, executor_context) +def test_concurrent_executor_resume_checkpoint_failure_propagates(): + """A resume-time checkpoint refresh failure propagates out of execute(). + + Regression guard: the timer resubmit does a blocking checkpoint refresh. + That refresh only raises when the checkpoint subsystem has failed, which + is terminal. execute() must re-raise it (so the invocation fails and the + backend retries from the last durable checkpoint) rather than leave the + wave PENDING forever - the completion wait has no timeout, so a stranded + PENDING branch would hang the whole map. + """ + + class TestExecutor(ConcurrentExecutor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.calls: dict[int, int] = {} + self.long_runner_release = threading.Event() + + def execute_item(self, child_context, executable): + task_id = executable.index + self.calls[task_id] = self.calls.get(task_id, 0) + 1 + if task_id == 0: + # Long-runner keeps the map alive so task 1 resumes in-process. + self.long_runner_release.wait(timeout=5) + return "result_A" + # Task 1 suspends with a past timestamp -> immediate in-process resume. + msg = "resume-me" + raise TimedSuspendExecution(msg, time.time() - 1) + + executables = [Executable(0, lambda: "task_A"), Executable(1, lambda: "task_B")] + completion_config = CompletionConfig( + min_successful=2, + tolerated_failure_count=None, + tolerated_failure_percentage=None, + ) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + + def checkpoint(*args, **kwargs): + # The resume refresh calls create_checkpoint() with no arguments. + # Fail that call; leave the branches' own checkpoints as no-ops. + if not args and not kwargs: + msg = "resume refresh failed" + raise RuntimeError(msg) + + execution_state.create_checkpoint = Mock(side_effect=checkpoint) + + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa: SLF001 + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context + + # Must re-raise (not hang): the resume failure surfaces as the original error. + with pytest.raises(RuntimeError, match="resume refresh failed"): + executor.execute(execution_state, executor_context) + executor.long_runner_release.set() + + def test_concurrent_executor_with_timed_resubmit_while_other_task_running(): """Test timed resubmission while other tasks are still running.""" @@ -3200,7 +3268,9 @@ def test_timer_scheduler_fifo_ordering_with_same_timestamp(): items synchronously, so callback order is deterministic. """ results = [] - resubmit_callback = Mock(side_effect=lambda exe: results.append(exe.index)) + resubmit_callback = Mock( + side_effect=lambda batch: results.extend(exe.index for exe in batch) + ) with TimerScheduler(resubmit_callback) as scheduler: # Use a past timestamp so they trigger immediately diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index 5e7e7fb8..f026d357 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -1397,7 +1397,7 @@ def test_concurrent_access_to_operations_dictionary(): operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) - state.operations["op1"] = operation + state._operations["op1"] = operation results = [] errors = [] @@ -1422,7 +1422,7 @@ def writer_thread(): status=OperationStatus.SUCCEEDED, ) with state._operations_lock: - state.operations[f"op{i}"] = new_op + state._operations[f"op{i}"] = new_op time.sleep(0.001) except Exception as e: errors.append(e) @@ -4260,3 +4260,87 @@ def test_plugin_executor_not_called_for_pending_operations(): # endregion Plugin Executor Integration Tests + + +def _make_execution_state_for_operations( + mock_lambda_client, *, replay_status=ReplayStatus.NEW, operations=None +): + return ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations=operations or {}, + service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), + replay_status=replay_status, + ) + + +def test_operations_property_returns_snapshot_copy(): + """The operations property exposes a copy; mutating it must not affect state.""" + mock_lambda_client = Mock(spec=LambdaClient) + op = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + state = _make_execution_state_for_operations( + mock_lambda_client, operations={"op1": op} + ) + + snapshot = state.operations + assert snapshot == {"op1": op} + + snapshot["op2"] = op # mutating the returned copy must not leak into state + assert "op2" not in state.operations + assert len(state.operations) == 1 + + +def test_track_replay_iteration_safe_under_concurrent_update(): + """track_replay must not raise when operations are updated concurrently. + + A worker thread iterates operations inside track_replay while the checkpoint + path updates the same map. Without consistent locking this raises + "dictionary changed size during iteration". + """ + mock_lambda_client = Mock(spec=LambdaClient) + state = _make_execution_state_for_operations( + mock_lambda_client, replay_status=ReplayStatus.REPLAY + ) + # Seed completed operations so track_replay keeps iterating (stays REPLAY). + for i in range(50): + state._operations[f"seed{i}"] = Operation( + operation_id=f"seed{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + errors: list[Exception] = [] + stop = threading.Event() + + def writer(): + i = 0 + while not stop.is_set(): + with state._operations_lock: + state._operations[f"w{i}"] = Operation( + operation_id=f"w{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + i += 1 + + def reader(): + try: + for _ in range(2000): + state.track_replay(operation_id="probe") + except Exception as e: # noqa: BLE001 + errors.append(e) + + writer_t = threading.Thread(target=writer, daemon=True) + reader_t = threading.Thread(target=reader, daemon=True) + writer_t.start() + reader_t.start() + reader_t.join(timeout=30) + stop.set() + writer_t.join(timeout=5) + + assert not errors, f"track_replay raced with concurrent update: {errors}"