From 759e57fd479560ba840da82330815456b3e81dae Mon Sep 17 00:00:00 2001
From: Alex Wang
Date: Wed, 17 Jun 2026 13:52:39 -0700
Subject: [PATCH] fix: race condition for id generator

---
 .../deterministic_id_generator.py             | 15 +++-
 .../tests/test_deterministic_id_generator.py  | 77 +++++++++++++++++++
 2 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py
index 7b5c2f50..4360d04a 100644
--- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py
+++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import contextvars
 import hashlib
 import os
 import re
@@ -12,6 +13,12 @@
 
 HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$")
 
+# Scoping the pending span ID to the execution context ensures concurrent
+# operations cannot consume each other's deterministic span ID.
+_next_span_id: contextvars.ContextVar[int | None] = contextvars.ContextVar(
+    "next_span_id", default=None
+)
+
 
 def _parse_xray_root_trace_id(trace_header: str | None) -> str | None:
     """Parse the Root trace ID from an X-Ray trace header string.
@@ -83,7 +90,6 @@ class DeterministicIdGenerator(RandomIdGenerator):
     """
 
     def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None:
-        self._next_span_id: int | None = None
         self._execution_trace_id: int | None = None
         self._fallback_id_generator = fallback_id_generator or RandomIdGenerator()
 
@@ -92,7 +98,7 @@ def set_next_span_id(self, span_id: int | None) -> None:
 
         After one span is created, it resets to random.
         """
-        self._next_span_id = span_id
+        _next_span_id.set(span_id)
 
     def set_trace_id(
         self, execution_arn: str, start_timestamp: datetime | None
@@ -113,5 +119,8 @@ def generate_trace_id(self) -> int:
 
     def generate_span_id(self) -> int:
         """Generate a 64-bit span ID."""
-        span_id, self._next_span_id = self._next_span_id, None
+        span_id = _next_span_id.get()
+        if span_id is not None:
+            # Consume once: later spans in this context fall back to random.
+            _next_span_id.set(None)
         return span_id or self._fallback_id_generator.generate_span_id()
diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py
index 8cb9dc7c..6e725173 100644
--- a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py
+++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import asyncio
+import threading
 from datetime import UTC, datetime
 
 from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator
@@ -203,3 +205,78 @@ def test_deterministic_id_generator_prefers_next_span_id_over_fallback():
     assert generator.generate_span_id() == deterministic_span_id
     # Subsequent calls fall back to the provided generator.
     assert generator.generate_span_id() == int("b" * 16, 16)
+
+
+def test_pending_span_id_is_isolated_across_threads():
+    """Verify a span ID set in one thread is not consumed by another thread.
+
+    The pending span ID is stored in a ContextVar, so each worker thread has
+    its own value. Without this isolation a concurrent operation could steal
+    another operation's deterministic span ID, producing the wrong span ID.
+    """
+    random_span_id = int("f" * 16, 16)
+    fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=random_span_id)
+    generator = DeterministicIdGenerator(fallback_id_generator=fallback)
+
+    # The main thread sets a deterministic span ID but never consumes it.
+    main_deterministic_span_id = int("1" * 16, 16)
+    generator.set_next_span_id(main_deterministic_span_id)
+
+    barrier = threading.Barrier(2)
+    results: dict[str, int] = {}
+
+    def worker(name: str, span_id: int) -> None:
+        # A new thread starts with an empty context (unless Python 3.14+ runs
+        # with thread_inherit_context), so it must not see main's pending ID.
+        barrier.wait()
+        results[f"{name}-before-set"] = generator.generate_span_id()
+        generator.set_next_span_id(span_id)
+        results[f"{name}-after-set"] = generator.generate_span_id()
+
+    worker_a_span_id = int("2" * 16, 16)
+    worker_b_span_id = int("3" * 16, 16)
+    thread_a = threading.Thread(target=worker, args=("a", worker_a_span_id))
+    thread_b = threading.Thread(target=worker, args=("b", worker_b_span_id))
+    thread_a.start()
+    thread_b.start()
+    thread_a.join()
+    thread_b.join()
+
+    # Workers never observed the main thread's value; they fell back to random.
+    assert results["a-before-set"] == random_span_id
+    assert results["b-before-set"] == random_span_id
+    # Each worker consumed only its own deterministic span ID.
+    assert results["a-after-set"] == worker_a_span_id
+    assert results["b-after-set"] == worker_b_span_id
+    # The main thread's pending span ID was untouched by the workers.
+    assert generator.generate_span_id() == main_deterministic_span_id
+
+
+def test_pending_span_id_is_isolated_across_async_tasks():
+    """Verify a span ID set in one async task is not consumed by another.
+
+    Each asyncio task runs with its own copied context, so the pending span ID
+    stays scoped to the task that set it even across await boundaries on the
+    same thread.
+    """
+    fallback_span_id = int("e" * 16, 16)
+    fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=fallback_span_id)
+    generator = DeterministicIdGenerator(fallback_id_generator=fallback)
+
+    task_a_span_id = int("4" * 16, 16)
+    task_b_span_id = int("5" * 16, 16)
+
+    async def task(span_id: int) -> int:
+        generator.set_next_span_id(span_id)
+        # Yield control so the other task interleaves between set and consume.
+        await asyncio.sleep(0)
+        return generator.generate_span_id()
+
+    async def main() -> tuple[int, int]:
+        return await asyncio.gather(task(task_a_span_id), task(task_b_span_id))
+
+    result_a, result_b = asyncio.run(main())
+
+    # Despite interleaving, each task consumed only its own deterministic ID.
+    assert result_a == task_a_span_id
+    assert result_b == task_b_span_id