Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 113 additions & 6 deletions packages/aws-durable-execution-sdk-python-examples/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
import logging
import os
import sys
from datetime import datetime
from enum import StrEnum
from pathlib import Path
from typing import Any

import pytest
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
OperationPayload,
)
from aws_durable_execution_sdk_python.serdes import ExtendedTypeSerDes

from aws_durable_execution_sdk_python_testing.runner import (
DurableFunctionCloudTestRunner,
DurableFunctionTestResult,
DurableFunctionTestRunner,
)

from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
OperationPayload,
)
from aws_durable_execution_sdk_python.serdes import ExtendedTypeSerDes


# Add examples/src to Python path for imports
examples_src = Path(__file__).parent.parent / "src"
Expand Down Expand Up @@ -266,3 +267,109 @@ def _get_deployed_function_name(
pytest.skip(
f"Test '{lambda_function_name}' doesn't match LAMBDA_FUNCTION_TEST_NAME '{env_function_name}'"
)


# X-Ray ingestion is eventually consistent; give the backend time to receive and
# index spans before querying, then retry a few times.
_XRAY_QUERY_RETRIES = 3
_XRAY_RETRY_DELAY_SECONDS = 10


class XRaySpanFetcher:
"""Encapsulates all AWS X-Ray interaction for span-validation tests.

Wraps a boto3 X-Ray client and exposes a single high-level operation that
queries trace summaries in a time window (with retries for eventual
consistency), batch-fetches the full traces, and locates the trace whose
segment documents reference a marker span name.
"""

def __init__(self, client: Any):
"""Initialize with a boto3 X-Ray client."""
self._client = client

def _query_trace_summaries(
self, start_time: datetime, end_time: datetime
) -> list[dict]:
"""Query trace summaries in a window, retrying for consistency."""
import time

for attempt in range(_XRAY_QUERY_RETRIES):
response = self._client.get_trace_summaries(
StartTime=start_time,
EndTime=end_time,
TimeRangeType="Event",
Sampling=False,
)
summaries = response.get("TraceSummaries", [])
if summaries:
return summaries

logger.info(
"X-Ray query returned 0 traces, retrying in %ss (attempt %d/%d)",
_XRAY_RETRY_DELAY_SECONDS,
attempt + 1,
_XRAY_QUERY_RETRIES,
)
time.sleep(_XRAY_RETRY_DELAY_SECONDS)
return []

def fetch_trace_with_span(
self,
start_time: datetime,
end_time: datetime,
marker_span: str,
) -> tuple[str, str]:
"""Find the trace containing ``marker_span`` and return its segment text.

Queries trace summaries in the window, then batch-fetches full traces
(X-Ray caps BatchGetTraces at 5 trace IDs per call) and locates the
trace whose segment documents reference the marker span name.

Args:
start_time: Start of the X-Ray query window.
end_time: End of the X-Ray query window.
marker_span: A span name expected to appear in the target trace.

Returns:
A tuple of (trace_id, concatenated segment-document JSON text).
"""
summaries = self._query_trace_summaries(start_time, end_time)
assert summaries, "Expected at least one trace in X-Ray after execution"

trace_ids = [s["Id"] for s in summaries]

for i in range(0, len(trace_ids), 5):
batch = trace_ids[i : i + 5]
result = self._client.batch_get_traces(TraceIds=batch)
for trace in result.get("Traces", []):
documents = [
seg.get("Document", "") for seg in trace.get("Segments", [])
]
segment_text = "\n".join(documents)
if marker_span in segment_text:
return trace["Id"], segment_text

pytest.fail(
f"Did not find a trace containing span '{marker_span}' in the time "
f"window across {len(trace_ids)} trace(s)"
)


@pytest.fixture
def xray_spans(request):
"""Provide an XRaySpanFetcher for cloud-mode span validation tests.

The underlying boto3 X-Ray client is created in the same region as the
cloud runner (AWS_REGION, default us-west-2). In local mode there is no
X-Ray backend, so the fixture skips the test, mirroring the cloud-only
gating of the durable_runner cloud path.
"""
runner_mode: str = request.config.getoption("--runner-mode")
if runner_mode != RunnerMode.CLOUD:
pytest.skip("X-Ray span validation only runs in cloud mode")

import boto3

region = os.environ.get("AWS_REGION", "us-west-2")
return XRaySpanFetcher(boto3.client("xray", region_name=region))
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Tests for the OTel-enriched logger example."""

import time
from datetime import UTC, datetime

import pytest

from aws_durable_execution_sdk_python.execution import InvocationStatus
Expand All @@ -8,6 +11,11 @@
from test.conftest import deserialize_operation_payload


# X-Ray ingestion is eventually consistent; wait before querying so the backend
# has received and indexed the exported spans.
_XRAY_INGESTION_DELAY_SECONDS = 20


@pytest.mark.example
@pytest.mark.durable_execution(
handler=otel_logger_example.handler,
Expand All @@ -30,3 +38,35 @@ def test_otel_logger_example(durable_runner):
op for op in result.operations if op.operation_type is OperationType.CONTEXT
]
assert len(context_ops) >= 1


@pytest.mark.example
@pytest.mark.durable_execution(
handler=otel_logger_example.handler,
lambda_function_name="Otel Logger Example",
)
def test_otel_logger_example_spans_in_xray(durable_runner, xray_spans):
"""Single-invocation example: spans land in one X-Ray trace.

Runs only in cloud mode;
"""
start_time = datetime.now(UTC)

with durable_runner:
result = durable_runner.run(input="{}", timeout=60)

assert result.status is InvocationStatus.SUCCEEDED
assert deserialize_operation_payload(result.result) == "hello world | hello nested"

# Allow X-Ray time to ingest the exported spans.
time.sleep(_XRAY_INGESTION_DELAY_SECONDS)

_trace_id, segment_text = xray_spans.fetch_trace_with_span(
start_time, datetime.now(UTC), marker_span="top-greet"
)

# Expected spans for the single-invocation example.
assert "invocation" in segment_text
assert "top-greet" in segment_text
assert "child-context" in segment_text
assert "child-greet" in segment_text
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
"""Tests for step example."""
"""Tests for the OTel plugin example (execution_with_otel)."""

import time
from datetime import UTC, datetime

import pytest
from aws_durable_execution_sdk_python.execution import InvocationStatus

from aws_durable_execution_sdk_python.execution import InvocationStatus
from src.plugin import execution_with_otel
from test.conftest import deserialize_operation_payload


# X-Ray ingestion is eventually consistent; wait before querying so the backend
# has received and indexed the exported spans.
_XRAY_INGESTION_DELAY_SECONDS = 20


def _count_occurrences(text: str, substring: str) -> int:
"""Count non-overlapping occurrences of ``substring`` in ``text``."""
count = 0
index = 0
while (index := text.find(substring, index)) != -1:
count += 1
index += len(substring)
return count


@pytest.mark.example
@pytest.mark.durable_execution(
handler=execution_with_otel.handler,
Expand All @@ -22,3 +40,43 @@ def test_plugin(durable_runner):

step_result = result.get_step("final-step")
assert deserialize_operation_payload(step_result.result) == 23


@pytest.mark.example
@pytest.mark.durable_execution(
handler=execution_with_otel.handler,
lambda_function_name="Otel Plugin",
)
def test_plugin_spans_in_xray_across_invocations(durable_runner, xray_spans):
"""Multi-invocation example: spans from all invocations share one trace."""
start_time = datetime.now(UTC)

with durable_runner:
result = durable_runner.run(input="{}", timeout=120)

assert result.status is InvocationStatus.SUCCEEDED
assert deserialize_operation_payload(result.result) == 23

# Multi-invocation executions take longer to fully export; give extra time.
time.sleep(_XRAY_INGESTION_DELAY_SECONDS + 5)

trace_id, segment_text = xray_spans.fetch_trace_with_span(
start_time, datetime.now(UTC), marker_span="final-step"
)

# Spans from every child context plus the final top-level step.
for i in range(3):
assert f"context-{i}" in segment_text, f"missing span context-{i}"
assert f"step-{i}" in segment_text, f"missing span step-{i}"
assert f"wait-{i}" in segment_text, f"missing span wait-{i}"
assert "final-step" in segment_text

# The waits force multiple Lambda invocations -> multiple invocation spans.
invocation_count = _count_occurrences(segment_text, "invocation")
assert invocation_count >= 2, (
f"Expected at least 2 invocation spans (multi-invocation), "
f"got {invocation_count}"
)

# All segments belong to one trace -> deterministic trace ID worked.
assert trace_id, "Expected a single unified trace ID across invocations"