From 77d2ff01416eba21dc73c14367ace3e83d1aa584 Mon Sep 17 00:00:00 2001 From: yaythomas Date: Sat, 13 Jun 2026 08:21:31 +0000 Subject: [PATCH 1/2] fix: concurrency drain timer resumes outside lock A map or parallel stays in the current invocation while at least one branch is still running. When one iteration keeps it alive and the other branches wait on short timers (a wait, a wait_for_condition poll, or a step retry backoff), those branches resume in-process as their timers come due. Before this commit, a single timer thread runs those resumes one at a time and holds its lock across a blocking checkpoint that refreshes state. Every resume costs one network round trip under the lock, and every branch trying to register its next wait queues behind it. When many timers come due together the timer thread falls behind, the invocation reaches its function timeout, and the backend reinvokes, so a map that should finish in seconds runs for minutes across several timed-out invocations. Holding the lock across the submit also allowed a latent self-deadlock, where a branch that finished inline reacquired the same lock on the timer thread through its done-callback. This commit holds the lock only long enough to take all due timers off the queue and mark them pending, then releases it and runs one shared refresh for the whole wave before handing the branches back to the worker pool. One round trip now serves the whole wave instead of one per resume, and new waits no longer queue behind a network call. The take and the mark stay atomic so a branch never looks parked while it is about to resume, which would otherwise suspend the whole operation by mistake. If the refresh fails, which happens only when the checkpoint subsystem has already failed and is terminal, the timer thread records that one error and re-raises it from execute() so the platform retries from the last checkpoint. Closes #473 --- .../concurrency/executor.py | 63 +++++++++++++--- .../tests/concurrency_test.py | 72 ++++++++++++++++++- 2 files changed, 123 insertions(+), 12 deletions(-) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py index 61bdbb0d..0cdf40e8 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -59,7 +59,7 @@ class TimerScheduler: """Manage timed suspend tasks with a background timer thread.""" def __init__( - self, resubmit_callback: Callable[[ExecutableWithState], None] + self, resubmit_callback: Callable[[list[ExecutableWithState]], None] ) -> None: self.resubmit_callback = resubmit_callback self._pending_resumes: list[tuple[float, int, ExecutableWithState]] = [] @@ -114,18 +114,31 @@ def _timer_loop(self) -> None: current_time = time.time() if current_time >= next_resume_time: - # Time to resume + # Drain every due resume under the lock, transitioning each to + # PENDING atomically with the pop. Keeping pop+reset_to_pending + # together is required: should_execution_suspend reads branch + # status without this lock, so an item that is removed from the + # heap but still SUSPENDED_WITH_TIMEOUT could trigger a spurious + # parent suspend. + ready: list[ExecutableWithState] = [] with self._lock: - # no branch cover because hard to test reliably - this is a double-safety check if heap mutated - # since the first peek on next_resume_time further up - if ( # pragma: no branch + while ( self._pending_resumes and self._pending_resumes[0][0] <= current_time ): _, _, exe_state = heapq.heappop(self._pending_resumes) if exe_state.can_resume: exe_state.reset_to_pending() - self.resubmit_callback(exe_state) + ready.append(exe_state) + # Resubmit outside the lock. Only the heap pop and the PENDING + # transition need the lock. The checkpoint refresh is a blocking + # network call and the submit hands work to the pool, so running + # them off the lock keeps timed resumes from serializing behind + # the network round trip and keeps the timer thread from + # re-entering this non-reentrant lock when a submitted future + # completes inline and its done-callback calls schedule_resume. + if ready: + self.resubmit_callback(ready) else: # Wait until next resume time wait_time = min(next_resume_time - current_time, 0.1) @@ -169,6 +182,7 @@ def __init__( # Event-driven state tracking for when the executor is done self._completion_event = threading.Event() self._suspend_exception: SuspendExecution | None = None + self._resume_error: Exception | None = None # ExecutionCounters will keep track of completion criteria and on-going counters min_successful = self.completion_config.min_successful or len(self.executables) @@ -222,11 +236,32 @@ def execute( ] self._completion_event.clear() self._suspend_exception = None - - def resubmitter(executable_with_state: ExecutableWithState) -> None: - """Resubmit a timed suspended task.""" - execution_state.create_checkpoint() - submit_task(executable_with_state) + self._resume_error = None + + def resubmitter(ready: list[ExecutableWithState]) -> None: + """Resubmit a wave of timed-suspended tasks. + + One checkpoint refresh serves the whole due wave: the fetch returns + all operations, so every resumed branch reads fresh state. The + refresh only raises when the background checkpoint subsystem has + failed, which is terminal for the whole execution, so record the + error and wake the parent to re-raise it. Catching here keeps the + single timer thread alive so a failure does not strand the other + pending resumes. + """ + try: + execution_state.create_checkpoint() + except Exception as exc: # noqa: BLE001 + # resubmitter runs only on the single timer thread, so this + # check-then-set needs no lock. First error wins: keep the + # earliest failure if several waves fail before execute() reads + # it (they are the same terminal checkpoint failure anyway). + if self._resume_error is None: # pragma: no branch + self._resume_error = exc + self._completion_event.set() + return + for executable_with_state in ready: + submit_task(executable_with_state) thread_executor = ThreadPoolExecutor(max_workers=max_workers) try: @@ -259,6 +294,12 @@ def on_done(future: Future) -> None: for future in futures: future.cancel() + # A timed resume failed to refresh state (terminal checkpoint + # subsystem failure). Re-raise so the invocation fails and the + # backend retries from the last durable checkpoint. + if self._resume_error is not None: + raise self._resume_error + # Suspend execution if everything done and at least one of the tasks raised a suspend exception. if self._suspend_exception: raise self._suspend_exception diff --git a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py index ef7d7f57..1bfa6318 100644 --- a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py @@ -1370,6 +1370,74 @@ def execute_item(self, child_context, executable): executor.execute(execution_state, executor_context) +def test_concurrent_executor_resume_checkpoint_failure_propagates(): + """A resume-time checkpoint refresh failure propagates out of execute(). + + Regression guard: the timer resubmit does a blocking checkpoint refresh. + That refresh only raises when the checkpoint subsystem has failed, which + is terminal. execute() must re-raise it (so the invocation fails and the + backend retries from the last durable checkpoint) rather than leave the + wave PENDING forever - the completion wait has no timeout, so a stranded + PENDING branch would hang the whole map. + """ + + class TestExecutor(ConcurrentExecutor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.calls: dict[int, int] = {} + self.long_runner_release = threading.Event() + + def execute_item(self, child_context, executable): + task_id = executable.index + self.calls[task_id] = self.calls.get(task_id, 0) + 1 + if task_id == 0: + # Long-runner keeps the map alive so task 1 resumes in-process. + self.long_runner_release.wait(timeout=5) + return "result_A" + # Task 1 suspends with a past timestamp -> immediate in-process resume. + msg = "resume-me" + raise TimedSuspendExecution(msg, time.time() - 1) + + executables = [Executable(0, lambda: "task_A"), Executable(1, lambda: "task_B")] + completion_config = CompletionConfig( + min_successful=2, + tolerated_failure_count=None, + tolerated_failure_percentage=None, + ) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + execution_state = Mock() + + def checkpoint(*args, **kwargs): + # The resume refresh calls create_checkpoint() with no arguments. + # Fail that call; leave the branches' own checkpoints as no-ops. + if not args and not kwargs: + msg = "resume refresh failed" + raise RuntimeError(msg) + + execution_state.create_checkpoint = Mock(side_effect=checkpoint) + + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa: SLF001 + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context + + # Must re-raise (not hang): the resume failure surfaces as the original error. + with pytest.raises(RuntimeError, match="resume refresh failed"): + executor.execute(execution_state, executor_context) + executor.long_runner_release.set() + + def test_concurrent_executor_with_timed_resubmit_while_other_task_running(): """Test timed resubmission while other tasks are still running.""" @@ -3200,7 +3268,9 @@ def test_timer_scheduler_fifo_ordering_with_same_timestamp(): items synchronously, so callback order is deterministic. """ results = [] - resubmit_callback = Mock(side_effect=lambda exe: results.append(exe.index)) + resubmit_callback = Mock( + side_effect=lambda batch: results.extend(exe.index for exe in batch) + ) with TimerScheduler(resubmit_callback) as scheduler: # Use a past timestamp so they trigger immediately From a4b0b46de53f8e60522e416e791be805d6162117 Mon Sep 17 00:00:00 2001 From: yaythomas Date: Tue, 16 Jun 2026 00:23:18 +0000 Subject: [PATCH 2/2] fix: lock operations access in ExecutionState - Make operations private (_operations); add a read-only snapshot property to preserve the public attribute - Read operations under _operations_lock in track_replay and get_execution_operation, closing a dictionary-changed-size race against the concurrent checkpoint update path - Add regression test for concurrent track_replay and update --- .../aws_durable_execution_sdk_python/state.py | 47 ++++++---- .../tests/state_test.py | 88 ++++++++++++++++++- 2 files changed, 116 insertions(+), 19 deletions(-) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 7fcfadcc..c24bd96d 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -250,7 +250,7 @@ def __init__( ): self.durable_execution_arn: str = durable_execution_arn self._current_checkpoint_token: str = initial_checkpoint_token - self.operations: MutableMapping[str, Operation] = operations + self._operations: dict[str, Operation] = dict(operations) self._service_client: DurableServiceClient = service_client self._plugin_executor: PluginExecutor = plugin_executor self._ordered_checkpoint_lock: OrderedLock = OrderedLock() @@ -279,6 +279,16 @@ def __init__( self._replay_status_lock: Lock = Lock() self._visited_operations: set[str] = set() + @property + def operations(self) -> dict[str, Operation]: + """Return a point-in-time snapshot copy of the operations map. + + The returned dict is a copy, so mutating it does not affect execution + state and iterating it is safe against concurrent updates. + """ + with self._operations_lock: + return dict(self._operations) + def fetch_paginated_operations( self, initial_operations: list[Operation], @@ -324,7 +334,7 @@ def fetch_paginated_operations( # Always store whatever operations we successfully fetched if all_operations: with self._operations_lock: - self.operations.update( + self._operations.update( {op.operation_id: op for op in all_operations} ) return all_operations @@ -341,7 +351,8 @@ def get_input_payload(self) -> str | None: def get_execution_operation(self) -> Operation | None: # invocation id is id of execution operation invocation_id = self.durable_execution_arn.split("/")[-1] - candidate = self.operations.get(invocation_id) + with self._operations_lock: + candidate = self._operations.get(invocation_id) if not candidate: # Due to payload size limitations we may have an empty operations list. # This will only happen when loading the initial page of results and is @@ -370,19 +381,21 @@ def track_replay(self, operation_id: str) -> None: with self._replay_status_lock: if self._replay_status == ReplayStatus.REPLAY: self._visited_operations.add(operation_id) - completed_ops = { - op_id - for op_id, op in self.operations.items() - if op.operation_type != OperationType.EXECUTION - and op.status - in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - OperationStatus.TIMED_OUT, + # Lock order: _replay_status_lock then _operations_lock. + with self._operations_lock: + completed_ops = { + op_id + for op_id, op in self._operations.items() + if op.operation_type != OperationType.EXECUTION + and op.status + in { + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + OperationStatus.TIMED_OUT, + } } - } if completed_ops.issubset(self._visited_operations): logger.debug( "Transitioning from REPLAY to NEW status at operation %s", @@ -404,7 +417,7 @@ def mark_replaying_if_prior_operations_exist(self) -> None: with self._operations_lock: has_prior_operations: bool = any( op.operation_type is not OperationType.EXECUTION - for op in self.operations.values() + for op in self._operations.values() ) if has_prior_operations: @@ -431,7 +444,7 @@ def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult: """ # checking status are deliberately under a lighter non-serialized lock with self._operations_lock: - if checkpoint := self.operations.get(checkpoint_id): + if checkpoint := self._operations.get(checkpoint_id): return CheckpointedResult.create_from_operation(checkpoint) return CHECKPOINT_NOT_FOUND diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index 5e7e7fb8..f026d357 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -1397,7 +1397,7 @@ def test_concurrent_access_to_operations_dictionary(): operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) - state.operations["op1"] = operation + state._operations["op1"] = operation results = [] errors = [] @@ -1422,7 +1422,7 @@ def writer_thread(): status=OperationStatus.SUCCEEDED, ) with state._operations_lock: - state.operations[f"op{i}"] = new_op + state._operations[f"op{i}"] = new_op time.sleep(0.001) except Exception as e: errors.append(e) @@ -4260,3 +4260,87 @@ def test_plugin_executor_not_called_for_pending_operations(): # endregion Plugin Executor Integration Tests + + +def _make_execution_state_for_operations( + mock_lambda_client, *, replay_status=ReplayStatus.NEW, operations=None +): + return ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations=operations or {}, + service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), + replay_status=replay_status, + ) + + +def test_operations_property_returns_snapshot_copy(): + """The operations property exposes a copy; mutating it must not affect state.""" + mock_lambda_client = Mock(spec=LambdaClient) + op = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + state = _make_execution_state_for_operations( + mock_lambda_client, operations={"op1": op} + ) + + snapshot = state.operations + assert snapshot == {"op1": op} + + snapshot["op2"] = op # mutating the returned copy must not leak into state + assert "op2" not in state.operations + assert len(state.operations) == 1 + + +def test_track_replay_iteration_safe_under_concurrent_update(): + """track_replay must not raise when operations are updated concurrently. + + A worker thread iterates operations inside track_replay while the checkpoint + path updates the same map. Without consistent locking this raises + "dictionary changed size during iteration". + """ + mock_lambda_client = Mock(spec=LambdaClient) + state = _make_execution_state_for_operations( + mock_lambda_client, replay_status=ReplayStatus.REPLAY + ) + # Seed completed operations so track_replay keeps iterating (stays REPLAY). + for i in range(50): + state._operations[f"seed{i}"] = Operation( + operation_id=f"seed{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + errors: list[Exception] = [] + stop = threading.Event() + + def writer(): + i = 0 + while not stop.is_set(): + with state._operations_lock: + state._operations[f"w{i}"] = Operation( + operation_id=f"w{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + i += 1 + + def reader(): + try: + for _ in range(2000): + state.track_replay(operation_id="probe") + except Exception as e: # noqa: BLE001 + errors.append(e) + + writer_t = threading.Thread(target=writer, daemon=True) + reader_t = threading.Thread(target=reader, daemon=True) + writer_t.start() + reader_t.start() + reader_t.join(timeout=30) + stop.set() + writer_t.join(timeout=5) + + assert not errors, f"track_replay raced with concurrent update: {errors}"