Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import contextvars
import hashlib
import os
import re
Expand All @@ -12,6 +13,12 @@

HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$")

# Scoping the pending span ID to the execution context ensures concurrent
# operations cannot consume each other's deterministic span ID.
_next_span_id: contextvars.ContextVar[int | None] = contextvars.ContextVar(
"next_span_id", default=None
)


def _parse_xray_root_trace_id(trace_header: str | None) -> str | None:
"""Parse the Root trace ID from an X-Ray trace header string.
Expand Down Expand Up @@ -83,7 +90,6 @@ class DeterministicIdGenerator(RandomIdGenerator):
"""

def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None:
self._next_span_id: int | None = None
self._execution_trace_id: int | None = None
self._fallback_id_generator = fallback_id_generator or RandomIdGenerator()

Expand All @@ -92,7 +98,7 @@ def set_next_span_id(self, span_id: int | None) -> None:

After one span is created, it resets to random.
"""
self._next_span_id = span_id
_next_span_id.set(span_id)

def set_trace_id(
self, execution_arn: str, start_timestamp: datetime | None
Expand All @@ -113,5 +119,8 @@ def generate_trace_id(self) -> int:

def generate_span_id(self) -> int:
"""Generate a 64-bit span ID."""
span_id, self._next_span_id = self._next_span_id, None
span_id = _next_span_id.get()
# Consume once: the deterministic ID applies only to the next span
# created in this context; subsequent spans fall back to random.
_next_span_id.set(None)
return span_id or self._fallback_id_generator.generate_span_id()
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import asyncio
import threading
from datetime import UTC, datetime

from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator
Expand Down Expand Up @@ -203,3 +205,78 @@ def test_deterministic_id_generator_prefers_next_span_id_over_fallback():
assert generator.generate_span_id() == deterministic_span_id
# Subsequent calls fall back to the provided generator.
assert generator.generate_span_id() == int("b" * 16, 16)


def test_pending_span_id_is_isolated_across_threads():
"""Verify a span ID set in one thread is not consumed by another thread.

The pending span ID is stored in a ContextVar, so each worker thread has
its own value. Without this isolation a concurrent operation could steal
another operation's deterministic span ID, producing the wrong span ID.
"""
random_span_id = int("f" * 16, 16)
fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=random_span_id)
generator = DeterministicIdGenerator(fallback_id_generator=fallback)

# The main thread sets a deterministic span ID but never consumes it.
main_deterministic_span_id = int("1" * 16, 16)
generator.set_next_span_id(main_deterministic_span_id)

barrier = threading.Barrier(2)
results: dict[str, int] = {}

def worker(name: str, span_id: int) -> None:
# Each worker starts with a fresh context (default None), so it must
# not see the main thread's pending span ID.
barrier.wait()
results[f"{name}-before-set"] = generator.generate_span_id()
generator.set_next_span_id(span_id)
results[f"{name}-after-set"] = generator.generate_span_id()

worker_a_span_id = int("2" * 16, 16)
worker_b_span_id = int("3" * 16, 16)
thread_a = threading.Thread(target=worker, args=("a", worker_a_span_id))
thread_b = threading.Thread(target=worker, args=("b", worker_b_span_id))
thread_a.start()
thread_b.start()
thread_a.join()
thread_b.join()

# Workers never observed the main thread's value; they fell back to random.
assert results["a-before-set"] == random_span_id
assert results["b-before-set"] == random_span_id
# Each worker consumed only its own deterministic span ID.
assert results["a-after-set"] == worker_a_span_id
assert results["b-after-set"] == worker_b_span_id
# The main thread's pending span ID was untouched by the workers.
assert generator.generate_span_id() == main_deterministic_span_id


def test_pending_span_id_is_isolated_across_async_tasks():
"""Verify a span ID set in one async task is not consumed by another.

Each asyncio task runs with its own copied context, so the pending span ID
stays scoped to the task that set it even across await boundaries on the
same thread.
"""
fallback_span_id = int("e" * 16, 16)
fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=fallback_span_id)
generator = DeterministicIdGenerator(fallback_id_generator=fallback)

task_a_span_id = int("4" * 16, 16)
task_b_span_id = int("5" * 16, 16)

async def task(span_id: int) -> int:
generator.set_next_span_id(span_id)
# Yield control so the other task interleaves between set and consume.
await asyncio.sleep(0)
return generator.generate_span_id()

async def main() -> tuple[int, int]:
return await asyncio.gather(task(task_a_span_id), task(task_b_span_id))

result_a, result_b = asyncio.run(main())

# Despite interleaving, each task consumed only its own deterministic ID.
assert result_a == task_a_span_id
assert result_b == task_b_span_id