Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 121 additions & 13 deletions tests/integration/_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import secrets
import string
import time
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload

import pytest

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine
from collections.abc import Awaitable, Callable, Coroutine

logger = logging.getLogger(__name__)

# Environment variable names for test configuration
TOKEN_ENV_VAR = 'APIFY_TEST_USER_API_TOKEN'
Expand Down Expand Up @@ -119,6 +123,107 @@ async def maybe_sleep(seconds: float, *, is_async: bool) -> None:
time.sleep(seconds) # noqa: ASYNC251


async def _maybe_await(value: Awaitable[T] | T) -> T:
    """Resolve `value` to a plain result, awaiting it first when it is awaitable.

    `call_with_exp_backoff` and `poll_until_condition` pass the return value of their `fn` through this helper,
    which is what lets them work with sync and async callables alike.
    """
    # Plain values need no event-loop round trip; hand them straight back.
    if not inspect.isawaitable(value):
        return cast('T', value)
    return await cast('Awaitable[T]', value)


@overload
async def call_with_exp_backoff(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool] = ...,
    *,
    max_retries: int = ...,
    base_delay: float = ...,
) -> T: ...
@overload
async def call_with_exp_backoff(
    fn: Callable[[], T],
    condition: Callable[[T], bool] = ...,
    *,
    max_retries: int = ...,
    base_delay: float = ...,
) -> T: ...
async def call_with_exp_backoff(
    fn: Callable[[], Awaitable[T] | T],
    condition: Callable[[T], bool] = bool,
    *,
    max_retries: int = 5,
    base_delay: float = 1.0,
) -> T:
    """Invoke `fn` and keep retrying with exponentially growing pauses until `condition(result)` holds.

    `fn` is always called once up front. While its result fails `condition`, the helper sleeps and calls `fn`
    again, at most `max_retries` more times. The pause before retry number `k` (counting from zero) is
    `base_delay * 2 ** k` seconds. With `max_retries=0`, `fn` is called exactly once.

    Whatever `fn` produced last is returned, whether or not the condition was ever met, leaving the final
    assertion to the caller. The result of the very last retry is returned without being checked.

    Intended for eventually-consistent APIs, where something that was just created may need a moment to show up.
    By default the condition is plain truthiness of the result. In contrast to `poll_until_condition`, the wait
    between attempts doubles each time instead of being fixed.

    Args:
        fn: No-arg callable, sync or async, producing the value to check.
        condition: Predicate applied to each result; retrying stops once it returns True.
        max_retries: Upper bound on the number of retries after the initial call.
        base_delay: Sleep in seconds before the first retry; doubled for each later retry.

    Returns:
        The most recent result of `fn`.
    """
    outcome = await _maybe_await(fn())
    for retry_index in range(max_retries):
        if condition(outcome):
            break
        wait_seconds = base_delay * 2**retry_index
        logger.info(
            'Condition not met for %r, retrying in %ss (attempt %d/%d).',
            outcome,
            wait_seconds,
            retry_index + 1,
            max_retries,
        )
        await asyncio.sleep(wait_seconds)
        outcome = await _maybe_await(fn())
    return outcome


@overload
async def poll_until_condition(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
) -> T: ...
@overload
async def poll_until_condition(
    fn: Callable[[], T],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
) -> T: ...
async def poll_until_condition(
    fn: Callable[[], Awaitable[T] | T],
    condition: Callable[[T], bool] = bool,
    *,
    timeout: float = 5,
    poll_interval: float = 1,
) -> T:
    """Repeatedly call `fn` at a fixed interval until `condition(result)` holds or time runs out.

    `fn` is called immediately and then again after every pause of `poll_interval` seconds, for as long as its
    result fails `condition` and fewer than `timeout` seconds have passed. The final pause is shortened so it
    does not overshoot the deadline. The latest result is returned even when the condition never became true,
    so the caller is expected to assert on it. By default the condition is plain truthiness of the result.

    Prefer this over a hard-coded `asyncio.sleep` when waiting on eventually-consistent state (for example a
    newly created resource showing up in a listing) whose propagation time varies. In contrast to
    `call_with_exp_backoff`, the gap between polls does not grow.

    Args:
        fn: No-arg callable, sync or async, producing the value to check.
        condition: Predicate applied to each result; polling stops once it returns True.
        timeout: Total time budget in seconds, measured on the monotonic clock.
        poll_interval: Seconds to wait between consecutive polls.

    Returns:
        The most recent result of `fn`.
    """
    give_up_at = time.monotonic() + timeout
    while True:
        latest = await _maybe_await(fn())
        if condition(latest):
            return latest
        time_left = give_up_at - time.monotonic()
        if time_left <= 0:
            # Budget exhausted: return the unmet result so the caller can report on it.
            return latest
        await asyncio.sleep(min(poll_interval, time_left))


async def collect_iterate_until_present(
iterator_factory: Callable[[], Iterator[_HasIdT] | AsyncIterator[_HasIdT]],
expected_ids: set[str],
Expand All @@ -132,7 +237,7 @@ async def collect_iterate_until_present(

Handles eventual consistency on listing endpoints: under parallel load a freshly
created resource may not appear in the listing for a short window. Each attempt
builds a fresh iterator via `iterator_factory`, drains it, and breaks early once
builds a fresh iterator via `iterator_factory`, drains it, and stops early once
`expected_ids` is a subset of the collected items' `.id` values. The most recent
collection is returned regardless of whether the condition was met, so the caller
can run its own assertion with a helpful failure message.
Expand All @@ -141,19 +246,17 @@ async def collect_iterate_until_present(
iterator_factory: No-arg callable returning a fresh iterator on each call.
expected_ids: IDs that must all appear in the collected items.
item_type: Asserted to match the runtime type of each yielded item.
is_async: Whether the iterator is async (and so are sleeps).
is_async: Whether the iterator is async.
max_attempts: Maximum number of polling rounds.
interval: Seconds to sleep before each attempt.
interval: Seconds to sleep between attempts.

Returns:
The most recently collected items.
"""
collected: list[_HasIdT] = []
for attempt in range(max_attempts):
if attempt > 0:
await maybe_sleep(interval, is_async=is_async)

async def drain() -> list[_HasIdT]:
iterator = iterator_factory()
collected = []
collected: list[_HasIdT] = []
if is_async:
assert isinstance(iterator, AsyncIterator)
async for item in iterator:
Expand All @@ -164,9 +267,14 @@ async def collect_iterate_until_present(
for item in iterator:
assert isinstance(item, item_type)
collected.append(item)
if expected_ids.issubset(item.id for item in collected):
break
return collected
return collected

return await poll_until_condition(
drain,
lambda collected: expected_ids.issubset(item.id for item in collected),
timeout=max_attempts * interval,
poll_interval=interval,
)


# ============================================================================
Expand Down
Loading