diff --git a/.gitignore b/.gitignore
index 5f96167..aa49fe5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -103,7 +103,7 @@ ipython_config.py
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
-#uv.lock
+uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
@@ -187,4 +187,4 @@ examples/checkpoints/
examples/outputs/
.codex/
-openspec/
\ No newline at end of file
+.agents/
\ No newline at end of file
diff --git a/README.md b/README.md
index 6521a97..90503e5 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ Datafast is a python library for synthetic data generation using llms.
The old dataset-class API has been removed. The canonical package is `datafast`, and the primary model is:
- create records with `Source` or `Seed`
-- transform them with composable steps
+- transform them with composable steps such as `AddUUID`, `Map`, and `Filter`
- call LLMs with `LLMStep`, `Classify`, `Score`, `Compare`, `Rewrite`, or `Extract`
- persist results with `Sink`
@@ -53,7 +53,7 @@ pipeline.run(batch_size=4)
- `Source`: load records from Python lists, files, or Hugging Face datasets
- `Seed`: generate record combinations declaratively
-- `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations
+- `AddUUID`, `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations
- `LLMStep`: free-form generation
- `Classify`, `Score`, `Compare`, `Rewrite`, `Extract`: higher-level LLM transforms
- `Branch` and `JoinBranches`: multi-path pipelines
@@ -105,6 +105,7 @@ configure_langfuse_tracing()
- `datafast/`: canonical source package
- `examples/scripts/`: runnable pipeline examples
+- `examples/providers/`: direct provider usage examples
- `docs/`: pipeline-first documentation
- `datafast_new_design_document.md`: retained design reference
diff --git a/datafast/__init__.py b/datafast/__init__.py
index 19bd452..5ffcae8 100644
--- a/datafast/__init__.py
+++ b/datafast/__init__.py
@@ -15,12 +15,14 @@
MistralProvider,
OpenRouterProvider,
OllamaProvider,
+ OpenAICompatibleProvider,
openai,
anthropic,
gemini,
mistral,
openrouter,
ollama,
+ openai_compatible,
)
from datafast.logger_config import configure_logger
from datafast.sinks.sink import Sink, JSONLSink, CSVSink, ListSink, ParquetSink, HubSink
@@ -31,7 +33,7 @@
is_langfuse_tracing_enabled,
)
from datafast.transforms.branch import Branch, JoinBranches
-from datafast.transforms.data_ops import Map, FlatMap, Filter, Group, Pair, Concat, Join
+from datafast.transforms.data_ops import AddUUID, Map, FlatMap, Filter, Group, Pair, Concat, Join
from datafast.transforms.llm_eval import Classify, Score, Compare
from datafast.transforms.llm_extract import Extract
from datafast.transforms.llm_step import LLMStep
@@ -64,6 +66,7 @@ def get_version() -> str:
"Seed",
"SeedDimension",
"Sample",
+ "AddUUID",
"Map",
"FlatMap",
"Filter",
@@ -92,12 +95,14 @@ def get_version() -> str:
"MistralProvider",
"OpenRouterProvider",
"OllamaProvider",
+ "OpenAICompatibleProvider",
"openai",
"anthropic",
"gemini",
"mistral",
"openrouter",
"ollama",
+ "openai_compatible",
"configure_logger",
"configure_langfuse_tracing",
"get_version",
diff --git a/datafast/core/runner.py b/datafast/core/runner.py
index 0a28ba2..c7d280f 100644
--- a/datafast/core/runner.py
+++ b/datafast/core/runner.py
@@ -3,6 +3,7 @@
import time
import uuid
from collections import defaultdict
+from dataclasses import dataclass
from typing import TYPE_CHECKING
from loguru import logger
@@ -19,7 +20,6 @@
if TYPE_CHECKING:
from datafast.core.step import Pipeline, Step
- from datafast.llm.provider import LLMProvider
from datafast.transforms.llm_step import LLMStep
@@ -29,6 +29,12 @@ def chunked(iterable: list, size: int):
yield iterable[i : i + size]
+@dataclass
+class _LLMBatchStats:
+    """Per-batch counters returned by the runner's LLM batch helpers."""
+
+    # Calls whose result was applied and appended to the output records.
+    generated: int = 0
+    # Calls that failed, either while generating or while applying a result.
+    errors: int = 0
+
+
class Runner:
"""
Execution engine for pipelines.
@@ -218,77 +224,206 @@ def _execute_llm_step(
if self._checkpoint_mgr and not skip_call_ids:
self._checkpoint_mgr.clear_step_file(step_index, step_name)
- completed_in_batch = 0
+ completed_since_checkpoint = 0
errors = 0
generated_total = len(skip_call_ids)
for batch in chunked(calls, self.config.batch_size):
batch_start = time.perf_counter()
- batch_generated = 0
- batch_model_id = batch[0].model_id if batch else "unknown"
-
- for call in batch:
- model = models_map[call.model_id]
- batch_model_id = call.model_id
-
- try:
- result = model.generate(
- call.messages,
- metadata=build_trace_metadata(
- model=model,
- component="pipeline.step",
- trace_name=f"datafast.{step_name}",
- session_id=self._trace_session_id,
- step_name=step_name,
- step_type=step.__class__.__name__,
- record_index=call.record_index,
- prompt_index=call.prompt_index,
- output_index=call.output_index,
- language_code=call.language_code or None,
- call_id=call.call_id,
- ),
- )
- output_record = step.apply_result(call, result, model)
- output_records.append(output_record)
- progress.completed_call_ids.append(call.call_id)
- completed_in_batch += 1
- batch_generated += 1
- generated_total += 1
-
- # Append record immediately to JSONL
- if self._checkpoint_mgr:
- self._checkpoint_mgr.append_record(
- step_index, step_name, output_record
- )
-
- except Exception as e:
- errors += 1
- logger.warning(
- f"LLM call failed | Model: {call.model_id} | "
- f"Call: {call.call_id} | Error: {e}"
- )
+ batch_model_id = (
+ batch[0].model_id
+ if len({call.model_id for call in batch}) == 1
+ else "mixed"
+ )
+ stats = self._execute_llm_batch(
+ step=step,
+ step_name=step_name,
+ step_index=step_index,
+ batch=batch,
+ models_map=models_map,
+ progress=progress,
+ output_records=output_records,
+ )
+ completed_since_checkpoint += stats.generated
+ errors += stats.errors
+ generated_total += stats.generated
batch_duration = time.perf_counter() - batch_start
logger.info(
- f"Generated {batch_generated} samples (total: {generated_total}) | "
+ f"Generated {stats.generated} samples (total: {generated_total}) | "
f"model: {batch_model_id} | duration: {batch_duration:.2f}s"
)
if (
self._checkpoint_mgr
and manifest
- and completed_in_batch >= self.config.checkpoint_every
+ and completed_since_checkpoint >= self.config.checkpoint_every
):
self._checkpoint_mgr.save_llm_progress(
step_index, step_name, progress, output_records
)
- completed_in_batch = 0
+ completed_since_checkpoint = 0
logger.info(
f"LLMStep complete: {len(output_records)} outputs, {errors} errors"
)
return output_records
+    def _execute_llm_batch(
+        self,
+        *,
+        step: "LLMStep",
+        step_name: str,
+        step_index: int,
+        batch: list[LLMCall],
+        models_map: dict[str, object],
+        progress: LLMStepProgress,
+        output_records: list[Record],
+    ) -> _LLMBatchStats:
+        """Run one runner batch end to end and return its counters.
+
+        Results for the whole batch are generated first and then applied in
+        the batch's original call order. Failures from both phases are summed
+        into the returned stats.
+        """
+        generated_results, generation_errors = self._collect_llm_batch_results(
+            step=step,
+            step_name=step_name,
+            batch=batch,
+            models_map=models_map,
+        )
+        batch_stats = self._apply_llm_batch_results(
+            step=step,
+            step_name=step_name,
+            step_index=step_index,
+            batch=batch,
+            batch_results=generated_results,
+            models_map=models_map,
+            progress=progress,
+            output_records=output_records,
+        )
+        # Fold generation-phase failures into the apply-phase counters.
+        batch_stats.errors = batch_stats.errors + generation_errors
+        return batch_stats
+
+    def _collect_llm_batch_results(
+        self,
+        *,
+        step: "LLMStep",
+        step_name: str,
+        batch: list[LLMCall],
+        models_map: dict[str, object],
+    ) -> tuple[list[object | None], int]:
+        """Generate results for one batch, one model group at a time.
+
+        Calls are grouped by ``model_id`` so each model receives all of its
+        calls in a single ``_generate_llm_group`` invocation. Every result is
+        written back at the position of the call it belongs to, so the
+        returned list lines up index-for-index with ``batch``.
+
+        Args:
+            step: The LLM step being executed (forwarded to the group helper).
+            step_name: Name of the step (forwarded to the group helper).
+            batch: Calls to execute, in runner order.
+            models_map: Maps each ``model_id`` to its model object.
+
+        Returns:
+            ``(batch_results, errors)``. ``batch_results[i]`` is the result
+            for ``batch[i]``, or ``None`` when its model group failed.
+            ``errors`` is the number of calls belonging to failed groups.
+
+        Raises:
+            KeyError: If a call references a ``model_id`` missing from
+                ``models_map``.
+        """
+        batch_results: list[object | None] = [None] * len(batch)
+        errors = 0
+        grouped_indexes: dict[str, list[int]] = defaultdict(list)
+
+        for index, call in enumerate(batch):
+            grouped_indexes[call.model_id].append(index)
+
+        for model_id, indexes in grouped_indexes.items():
+            group_calls = [batch[index] for index in indexes]
+            model = models_map[model_id]
+            try:
+                group_results = self._generate_llm_group(
+                    step=step,
+                    step_name=step_name,
+                    model=model,
+                    group_calls=group_calls,
+                )
+                # A short (or long) result list cannot be aligned with its
+                # calls. Without this check zip() would silently truncate and
+                # the unmatched calls would vanish without being counted.
+                if len(group_results) != len(group_calls):
+                    raise RuntimeError(
+                        f"Expected {len(group_calls)} results from model "
+                        f"{model_id}, got {len(group_results)}"
+                    )
+            except Exception as e:
+                errors += len(group_calls)
+                self._log_llm_failures(group_calls, e)
+                continue
+
+            for result_index, result in zip(indexes, group_results):
+                batch_results[result_index] = result
+
+        return batch_results, errors
+
+    def _generate_llm_group(
+        self,
+        *,
+        step: "LLMStep",
+        step_name: str,
+        model: object,
+        group_calls: list[LLMCall],
+    ) -> list[object]:
+        """Generate results for calls that all target the same model.
+
+        Trace metadata is built for every call up front. Models exposing
+        ``generate_batch`` receive all message lists in one call; any other
+        model is invoked once per call through ``generate``.
+        """
+        metadata_per_call = [
+            self._build_llm_call_metadata(step, step_name, call, model)
+            for call in group_calls
+        ]
+
+        if not hasattr(model, "generate_batch"):
+            # No batch entry point: issue the calls one by one, in order.
+            outputs: list[object] = []
+            for call, call_metadata in zip(group_calls, metadata_per_call):
+                outputs.append(
+                    model.generate(  # type: ignore[attr-defined]
+                        messages=call.messages,
+                        metadata=call_metadata,
+                    )
+                )
+            return outputs
+
+        message_lists = [call.messages for call in group_calls]
+        batched = model.generate_batch(  # type: ignore[attr-defined]
+            message_lists,
+            metadata=metadata_per_call,
+        )
+        return list(batched)
+
+    def _apply_llm_batch_results(
+        self,
+        *,
+        step: "LLMStep",
+        step_name: str,
+        step_index: int,
+        batch: list[LLMCall],
+        batch_results: list[object | None],
+        models_map: dict[str, object],
+        progress: LLMStepProgress,
+        output_records: list[Record],
+    ) -> _LLMBatchStats:
+        """Turn generated results into output records, in batch order.
+
+        Entries whose result is ``None`` are skipped. For every other entry
+        the step's ``apply_result`` builds the record, which is appended to
+        ``output_records``, marked complete in ``progress`` and, when a
+        checkpoint manager is configured, appended to the step's checkpoint.
+        Any exception raised along the way is logged and counted as an error.
+        """
+        tally = _LLMBatchStats()
+
+        for call, result in zip(batch, batch_results):
+            if result is not None:
+                try:
+                    record = step.apply_result(
+                        call, result, models_map[call.model_id]
+                    )
+                    output_records.append(record)
+                    progress.completed_call_ids.append(call.call_id)
+                    tally.generated += 1
+
+                    if self._checkpoint_mgr:
+                        self._checkpoint_mgr.append_record(
+                            step_index, step_name, record
+                        )
+                except Exception as failure:
+                    tally.errors += 1
+                    self._log_llm_failures([call], failure)
+
+        return tally
+
+    def _build_llm_call_metadata(
+        self,
+        step: "LLMStep",
+        step_name: str,
+        call: LLMCall,
+        model: object,
+    ) -> dict[str, object]:
+        """Build the trace metadata attached to a single LLM call."""
+        trace_fields = {
+            "model": model,
+            "component": "pipeline.step",
+            "trace_name": f"datafast.{step_name}",
+            "session_id": self._trace_session_id,
+            "step_name": step_name,
+            "step_type": step.__class__.__name__,
+            "record_index": call.record_index,
+            "prompt_index": call.prompt_index,
+            "output_index": call.output_index,
+            # A falsy language code is passed on as None.
+            "language_code": call.language_code or None,
+            "call_id": call.call_id,
+        }
+        return build_trace_metadata(**trace_fields)
+
+    @staticmethod
+    def _log_llm_failures(calls: list[LLMCall], error: Exception) -> None:
+        """Emit one warning per failed call, all reporting the same error."""
+        for failed_call in calls:
+            message = (
+                f"LLM call failed | Model: {failed_call.model_id} | "
+                f"Call: {failed_call.call_id} | Error: {error}"
+            )
+            logger.warning(message)
+
def _order_calls(self, calls: list[LLMCall]) -> list[LLMCall]:
"""Order calls according to execution strategy."""
strategy = self.config.llm_strategy
diff --git a/datafast/llm/__init__.py b/datafast/llm/__init__.py
index 725ece6..eba520d 100644
--- a/datafast/llm/__init__.py
+++ b/datafast/llm/__init__.py
@@ -8,12 +8,27 @@
MistralProvider,
OpenRouterProvider,
OllamaProvider,
+ OpenAICompatibleProvider,
openai,
anthropic,
gemini,
mistral,
openrouter,
ollama,
+ openai_compatible,
+)
+from datafast.llm.types import (
+ BatchMode,
+ CacheMode,
+ ContentPart,
+ EndpointMode,
+ Modality,
+ NormalizedResponse,
+ RetryPolicy,
+ StructuredOutputMode,
+ TargetCapabilities,
+ TargetConfig,
+ UnsupportedParamsPolicy,
)
from datafast.llm.parsing import (
OutputParser,
@@ -30,12 +45,25 @@
"MistralProvider",
"OpenRouterProvider",
"OllamaProvider",
+ "OpenAICompatibleProvider",
"openai",
"anthropic",
"gemini",
"mistral",
"openrouter",
"ollama",
+ "openai_compatible",
+ "BatchMode",
+ "CacheMode",
+ "ContentPart",
+ "EndpointMode",
+ "Modality",
+ "NormalizedResponse",
+ "RetryPolicy",
+ "StructuredOutputMode",
+ "TargetCapabilities",
+ "TargetConfig",
+ "UnsupportedParamsPolicy",
"OutputParser",
"TextParser",
"JSONParser",
diff --git a/datafast/llm/capabilities.py b/datafast/llm/capabilities.py
new file mode 100644
index 0000000..14d6b26
--- /dev/null
+++ b/datafast/llm/capabilities.py
@@ -0,0 +1,273 @@
+"""Capability resolution for Datafast LLM targets."""
+
+from __future__ import annotations
+
+from datafast.llm.types import (
+ BatchMode,
+ CacheMode,
+ EndpointMode,
+ Modality,
+ StructuredOutputMode,
+ TargetCapabilities,
+)
+
+
+# Parameter names shared by every chat profile defined in this module.
+COMMON_CHAT_PARAMS = frozenset(
+    ("temperature", "max_completion_tokens", "timeout")
+)
+
+# Sampling controls added on top of the common set by some profiles.
+SAMPLING_CHAT_PARAMS = frozenset(("top_p", "frequency_penalty"))
+
+# Reasoning controls.
+REASONING_PARAMS = frozenset(("thinking", "reasoning_effort"))
+
+# Responses-endpoint parameters: the common and reasoning names plus
+# response chaining via previous_response_id.
+RESPONSES_PARAMS = (
+    COMMON_CHAT_PARAMS
+    | REASONING_PARAMS
+    | frozenset(("previous_response_id",))
+)
+
+
+# Chat-only profile with JSON-schema structured output and LiteLLM batching.
+# Used below as the default for the gemini and mistral providers.
+HOSTED_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.LITELLM_BATCH,
+    cache_mode=CacheMode.PROVIDER_PROMPT,
+)
+
+
+# OpenAI profile that defaults to the Responses endpoint and accepts the
+# reasoning parameters; returned for reasoning-style OpenAI model ids.
+OPENAI_RESPONSES = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}),
+    default_endpoint_mode=EndpointMode.RESPONSES,
+    supported_params=RESPONSES_PARAMS | SAMPLING_CHAT_PARAMS,
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+    cache_mode=CacheMode.PROVIDER_PROMPT,
+    supports_reasoning=True,
+)
+
+
+# OpenAI profile that defaults to chat completions; returned for OpenAI model
+# ids that do not look like reasoning models.
+OPENAI_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.LITELLM_BATCH,
+    cache_mode=CacheMode.PROVIDER_PROMPT,
+)
+
+
+# Anthropic chat profile: reasoning parameters are accepted in place of the
+# sampling parameters, and both reasoning and thinking are flagged supported.
+ANTHROPIC_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | REASONING_PARAMS,
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.LITELLM_BATCH,
+    cache_mode=CacheMode.PROVIDER_PROMPT,
+    supports_reasoning=True,
+    supports_thinking=True,
+)
+
+
+# OpenRouter chat profile with text and image modalities and router caching.
+OPENROUTER_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    modalities=frozenset({Modality.TEXT, Modality.IMAGE}),
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.LITELLM_BATCH,
+    cache_mode=CacheMode.ROUTER,
+    notes=(
+        "OpenRouter capabilities remain model and routed-provider dependent.",
+        "Reasoning controls are omitted by default; pass provider_params for "
+        "model-specific OpenRouter/LiteLLM escape hatches.",
+    ),
+)
+
+
+# Ollama chat profile: JSON-object structured output, no API key required.
+OLLAMA_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    structured_output=StructuredOutputMode.JSON_OBJECT,
+    batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+    cache_mode=CacheMode.LOCAL_KV,
+    no_api_key=True,
+    notes=("Structured output uses Ollama JSON mode plus Datafast validation.",),
+)
+
+
+# vLLM profile: chat by default with the Responses endpoint also listed,
+# text/image/video modalities, no API key, chat template required.
+VLLM_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    modalities=frozenset({Modality.TEXT, Modality.IMAGE, Modality.VIDEO}),
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+    cache_mode=CacheMode.LOCAL_KV,
+    no_api_key=True,
+    requires_chat_template=True,
+    notes=(
+        "vLLM exposes OpenAI-compatible chat and Responses endpoints, but "
+        "feature coverage remains model and server-version dependent.",
+        "Multimodal support depends on the served model; stable media UUIDs "
+        "can be passed with ContentPart.media_id.",
+    ),
+)
+
+
+# llama.cpp server profile: chat only, the widest modality set in this module,
+# no API key, chat template required.
+LLAMACPP_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS,
+    modalities=frozenset({
+        Modality.TEXT,
+        Modality.IMAGE,
+        Modality.AUDIO,
+        Modality.VIDEO,
+        Modality.FILE,
+    }),
+    structured_output=StructuredOutputMode.JSON_SCHEMA,
+    batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+    cache_mode=CacheMode.LOCAL_KV,
+    no_api_key=True,
+    requires_chat_template=True,
+    notes=(
+        "llama.cpp server is OpenAI-compatible for chat, with JSON schema "
+        "support through response_format.",
+        "Multimodal inputs and reasoning controls are model and build dependent; "
+        "use provider_params for llama.cpp-specific extra_body fields.",
+    ),
+)
+
+
+# Conservative profile for generic OpenAI-compatible servers: only "timeout"
+# is a supported parameter and structured output relies on prompted JSON.
+OPENAI_COMPATIBLE_CHAT = TargetCapabilities(
+    endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}),
+    default_endpoint_mode=EndpointMode.CHAT,
+    supported_params=frozenset({"timeout"}),
+    structured_output=StructuredOutputMode.PROMPTED_JSON,
+    batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+    cache_mode=CacheMode.LOCAL_KV,
+    no_api_key=True,
+    requires_chat_template=True,
+    notes=("OpenAI-compatible transport does not imply OpenAI feature support.",),
+)
+
+
+# Exact (provider, model_id) matches; keys are lowercase because the resolver
+# lowercases both parts before looking them up.
+# NOTE(review): the ministral-* entries map to OPENAI_COMPATIBLE_CHAT, which
+# sets no_api_key=True and supports only "timeout" — confirm this is intended
+# for models listed under the "mistral" provider.
+_CATALOG: dict[tuple[str, str], TargetCapabilities] = {
+    ("openai", "gpt-5.5"): OPENAI_RESPONSES,
+    ("openai", "gpt-5.4"): OPENAI_RESPONSES,
+    ("openai", "gpt-5.4-mini"): OPENAI_RESPONSES,
+    ("openai", "gpt-5.4-nano"): OPENAI_RESPONSES,
+    ("anthropic", "claude-sonnet-4-6"): ANTHROPIC_CHAT,
+    ("anthropic", "claude-haiku-4-5"): ANTHROPIC_CHAT,
+    ("gemini", "gemini-2.5-pro"): HOSTED_CHAT,
+    ("gemini", "gemini-3.5-flash"): HOSTED_CHAT,
+    ("gemini", "gemini-3.1-flash-lite"): HOSTED_CHAT,
+    ("mistral", "mistral-medium-3-5"): HOSTED_CHAT,
+    ("mistral", "mistral-large-2512"): HOSTED_CHAT,
+    ("mistral", "mistral-small-2603"): HOSTED_CHAT,
+    ("mistral", "ministral-14b-2512"): OPENAI_COMPATIBLE_CHAT,
+    ("mistral", "ministral-8b-2512"): OPENAI_COMPATIBLE_CHAT,
+    ("mistral", "ministral-3b-2512"): OPENAI_COMPATIBLE_CHAT,
+}
+
+# Per-provider fallback used when the catalog has no exact model match.
+# "openai" is absent: the resolver picks its profile from the model id.
+_PROVIDER_DEFAULTS: dict[str, TargetCapabilities] = {
+    "anthropic": ANTHROPIC_CHAT,
+    "gemini": HOSTED_CHAT,
+    "llamacpp": LLAMACPP_CHAT,
+    "mistral": HOSTED_CHAT,
+    "ollama": OLLAMA_CHAT,
+    "openrouter": OPENROUTER_CHAT,
+    "vllm": VLLM_CHAT,
+}
+
+# Provider names that always resolve to OPENAI_COMPATIBLE_CHAT.
+_OPENAI_COMPATIBLE_PROVIDERS = frozenset({
+    "openai_compatible",
+})
+
+
+def resolve_capabilities(
+    provider: str,
+    model_id: str,
+    *,
+    api_base_url: str | None = None,
+    explicit: TargetCapabilities | None = None,
+) -> TargetCapabilities:
+    """Pick the capability profile for a provider/model pair.
+
+    Resolution order, first match wins:
+
+    1. ``explicit`` capabilities supplied by the caller.
+    2. An exact catalog entry for the lowercased ``(provider, model_id)``.
+    3. For ``openai``, a profile chosen from the model id.
+    4. The provider's default profile.
+    5. The OpenAI-compatible profile, when the provider is registered as
+       OpenAI-compatible or a custom ``api_base_url`` is given.
+    6. A minimal profile for unknown targets.
+    """
+    if explicit is not None:
+        return explicit
+
+    provider_key = provider.lower()
+    model_key = model_id.lower()
+
+    known = _CATALOG.get((provider_key, model_key))
+    if known is not None:
+        return known
+
+    if provider_key == "openai":
+        return _resolve_openai_capabilities(model_key)
+
+    fallback = _PROVIDER_DEFAULTS.get(provider_key)
+    if fallback is not None:
+        return fallback
+
+    if provider_key in _OPENAI_COMPATIBLE_PROVIDERS or api_base_url:
+        return OPENAI_COMPATIBLE_CHAT
+
+    return _unknown_capabilities()
+
+
+def _resolve_openai_capabilities(model_id: str) -> TargetCapabilities:
+    """Return the Responses profile for reasoning-style ids, else the chat one."""
+    is_reasoning = _looks_like_openai_reasoning_model(model_id)
+    return OPENAI_RESPONSES if is_reasoning else OPENAI_CHAT
+
+
+def _looks_like_openai_reasoning_model(model_id: str) -> bool:
+    """Tell whether ``model_id`` starts with a known reasoning-model prefix."""
+    reasoning_prefixes = ("gpt-5", "o1", "o3", "o4")
+    return model_id.startswith(reasoning_prefixes)
+
+
+def _unknown_capabilities() -> TargetCapabilities:
+    """Build the minimal chat-only profile used for unrecognised targets."""
+    chat_only = frozenset({EndpointMode.CHAT})
+    timeout_only = frozenset({"timeout"})
+    explanation = (
+        "Unknown target; optional Datafast parameters are omitted by default.",
+    )
+    return TargetCapabilities(
+        endpoint_modes=chat_only,
+        default_endpoint_mode=EndpointMode.CHAT,
+        supported_params=timeout_only,
+        structured_output=StructuredOutputMode.PROMPTED_JSON,
+        batch_mode=BatchMode.FALLBACK_CONCURRENCY,
+        notes=explanation,
+    )
+
+
+# Exported names: the capability profiles and the resolver. The parameter-name
+# sets, lookup tables and underscore-prefixed helpers are not listed.
+__all__ = [
+    "ANTHROPIC_CHAT",
+    "HOSTED_CHAT",
+    "LLAMACPP_CHAT",
+    "OLLAMA_CHAT",
+    "OPENAI_CHAT",
+    "OPENAI_COMPATIBLE_CHAT",
+    "OPENAI_RESPONSES",
+    "OPENROUTER_CHAT",
+    "VLLM_CHAT",
+    "resolve_capabilities",
+]
diff --git a/datafast/llm/provider.py b/datafast/llm/provider.py
index 4768b24..4ef0ef8 100644
--- a/datafast/llm/provider.py
+++ b/datafast/llm/provider.py
@@ -1,52 +1,1463 @@
-"""Provider exports for the pipeline-first datafast API."""
-
-from datafast.llms import (
- LLMProvider,
- OpenAIProvider,
- AnthropicProvider,
- GeminiProvider,
- MistralProvider,
- OpenRouterProvider,
- OllamaProvider,
+"""Capability-aware LLM providers for Datafast."""
+
+from __future__ import annotations
+
+import copy
+import os
+import random
+import time
+import traceback
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from threading import Lock
+from typing import Any, TypeVar
+
+from loguru import logger
+from pydantic import BaseModel
+
+import litellm
+from litellm import exceptions as litellm_exceptions
+
+from datafast.llm.capabilities import resolve_capabilities
+from datafast.llm.types import (
+ BatchMode,
+ ContentPart,
+ EndpointMode,
+ Message,
+ Messages,
+ Modality,
+ NormalizedRequest,
+ NormalizedResponse,
+ RetryPolicy,
+ StructuredOutputMode,
+ TargetCapabilities,
+ TargetConfig,
+ UnsupportedParamsPolicy,
+)
+from datafast.tracing import (
+ build_trace_metadata,
+ load_env_once,
+ maybe_configure_langfuse_tracing,
)
-def openai(model_id: str = "gpt-5-mini-2025-08-07", **kwargs) -> OpenAIProvider:
- """Create an OpenAI provider instance."""
+# Pydantic model type used to annotate ``response_format`` arguments and the
+# parsed results returned for them.
+T = TypeVar("T", bound=BaseModel)
+
+# Plain-text instruction asking the model for bare, well-formed JSON.
+# NOTE(review): presumably appended to prompts when structured output is
+# requested — the usage is outside this excerpt; confirm.
+JSON_INSTRUCTIONS = (
+    "\nReturn only valid JSON. Do not include markdown fences. Use double quotes "
+    "for keys and string values, escape internal newlines, and avoid trailing commas."
+)
+
+
+class LLMProvider:
+ """One Datafast provider target resolved to LiteLLM request adapters."""
+
+    def __init__(
+        self,
+        provider: str,
+        model_id: str,
+        *,
+        litellm_provider: str,
+        env_key_name: str | None,
+        endpoint_mode: str | EndpointMode = EndpointMode.AUTO,
+        temperature: float | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        thinking: bool | None = None,
+        reasoning_effort: str | None = None,
+        rpm_limit: int | None = None,
+        timeout: float | None = None,
+        api_key: str | None = None,
+        api_base_url: str | None = None,
+        api_base: str | None = None,
+        retry_limit: int | None = None,
+        retry_policy: RetryPolicy | None = None,
+        unsupported_params: str | UnsupportedParamsPolicy = UnsupportedParamsPolicy.WARN,
+        provider_params: dict[str, Any] | None = None,
+        max_concurrent: int = 4,
+        capabilities: TargetCapabilities | None = None,
+        **extra_provider_params: Any,
+    ) -> None:
+        """Configure one provider/model target.
+
+        Args:
+            provider: Datafast provider name, also used to resolve capabilities.
+            model_id: Model identifier for this target.
+            litellm_provider: Provider name stored in the target config for
+                LiteLLM.
+            env_key_name: Environment variable consulted for the API key when
+                ``api_key`` is not given; ``None`` disables the lookup.
+            endpoint_mode: Requested endpoint mode, resolved against the
+                target's capabilities.
+            max_tokens: Alias used only when ``max_completion_tokens`` is
+                ``None``.
+            api_base: Alias used only when ``api_base_url`` is ``None``.
+            retry_limit: ``max_retries`` for the default retry policy
+                (3 when omitted); ignored when ``retry_policy`` is given.
+            provider_params: Extra request parameters. Keyword arguments not
+                named in the signature are merged on top of them.
+            capabilities: Explicit capabilities that bypass catalog lookup.
+        """
+        # Load environment configuration before anything reads os.environ, so
+        # the API-key lookup below can see values load_env_once() provides.
+        load_env_once()
+
+        if max_completion_tokens is None and max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if api_base_url is None:
+            api_base_url = api_base
+
+        # Unknown keyword arguments win over same-named provider_params keys.
+        merged_provider_params = dict(provider_params or {})
+        merged_provider_params.update(extra_provider_params)
+
+        if retry_policy is None:
+            retry_policy = RetryPolicy(
+                max_retries=retry_limit if retry_limit is not None else 3
+            )
+
+        unsupported_policy = _coerce_unsupported_policy(unsupported_params)
+
+        self.config = TargetConfig(
+            provider=provider,
+            model_id=model_id,
+            litellm_provider=litellm_provider,
+            env_key_name=env_key_name,
+            endpoint_mode=_coerce_endpoint_mode(endpoint_mode),
+            temperature=temperature,
+            max_completion_tokens=max_completion_tokens,
+            thinking=thinking,
+            reasoning_effort=reasoning_effort,
+            rpm_limit=rpm_limit,
+            timeout=timeout,
+            api_key=api_key,
+            api_base_url=api_base_url,
+            retry_policy=retry_policy,
+            unsupported_params=unsupported_policy,
+            provider_params=merged_provider_params,
+            max_concurrent=max_concurrent,
+        )
+        self.capabilities = resolve_capabilities(
+            provider,
+            model_id,
+            api_base_url=api_base_url,
+            explicit=capabilities,
+        )
+        self.endpoint_mode = self._resolve_endpoint_mode(self.config.endpoint_mode)
+
+        self.provider_name = provider
+        self.model_id = model_id
+        self.env_key_name = env_key_name
+        # An explicit key wins; otherwise fall back to the environment.
+        self.api_key = api_key or (os.getenv(env_key_name) if env_key_name else None)
+        self.api_base_url = api_base_url
+        self.temperature = temperature
+        self.max_completion_tokens = max_completion_tokens
+        self.reasoning_effort = reasoning_effort
+        self.rpm_limit = rpm_limit
+        self.timeout = timeout
+        self.unsupported_params = unsupported_policy.value
+
+        self._request_timestamps: list[float] = []
+        self._rate_lock = Lock()
+        self._sleep = time.sleep
+        # Names of the common parameters the caller actually set.
+        self._configured_common_params = {
+            name
+            for name, value in {
+                "temperature": temperature,
+                "max_completion_tokens": max_completion_tokens,
+                "thinking": thinking,
+                "reasoning_effort": reasoning_effort,
+                "timeout": timeout,
+            }.items()
+            if value is not None
+        }
+
+        maybe_configure_langfuse_tracing(load_env=False)
+        logger.info(
+            "Initialized {} | Model: {} | Endpoint: {}",
+            self.provider_name,
+            self.model_id,
+            self.endpoint_mode.value,
+        )
+
+    def generate(
+        self,
+        prompt: str | list[str] | None = None,
+        messages: Messages | list[Messages] | None = None,
+        response_format: type[T] | None = None,
+        metadata: dict[str, Any] | None = None,
+        previous_response_id: str | None = None,
+    ) -> str | list[str] | T | list[T]:
+        """Generate one response, or an ordered list for batched input.
+
+        Raises:
+            ValueError: Propagated unchanged from request execution.
+            RuntimeError: Wraps any other failure, with the formatted
+                traceback included in the message.
+        """
+        normalized, is_single = self._normalize_inputs(
+            prompt=prompt,
+            messages=messages,
+            metadata=metadata,
+            previous_response_id=previous_response_id,
+            response_format=response_format,
+        )
+        try:
+            outputs = self._generate_requests(
+                normalized,
+                response_format=response_format,
+            )
+        except ValueError:
+            # Validation-style errors reach the caller untouched.
+            raise
+        except Exception as exc:
+            formatted_trace = traceback.format_exc()
+            logger.error(
+                "Generation failed | Provider: {} | Model: {} | Error: {}",
+                self.provider_name,
+                self.model_id,
+                exc,
+            )
+            failure_message = (
+                f"Error generating response with {self.provider_name}:\n{formatted_trace}"
+            )
+            raise RuntimeError(failure_message) from exc
+
+        return outputs[0] if is_single else outputs
+
+    def generate_batch(
+        self,
+        messages: list[Messages],
+        *,
+        response_format: type[T] | None = None,
+        metadata: list[dict[str, Any] | None] | dict[str, Any] | None = None,
+        previous_response_ids: list[str | None] | None = None,
+    ) -> list[str] | list[T]:
+        """Generate parsed outputs for pre-built message lists, in input order.
+
+        Raises:
+            ValueError: If ``previous_response_ids`` is given with a length
+                different from ``messages``.
+        """
+        if not messages:
+            return []
+
+        total = len(messages)
+        per_request_metadata = _normalize_metadata(metadata, total)
+        prior_ids = previous_response_ids or [None] * total
+        if len(prior_ids) != total:
+            raise ValueError("previous_response_ids length must match messages length")
+
+        requests = []
+        for position, message_list in enumerate(messages):
+            prepared = self._prepare_messages(
+                message_list,
+                response_format=response_format,
+            )
+            requests.append(
+                NormalizedRequest(
+                    messages=prepared,
+                    metadata=per_request_metadata[position],
+                    previous_response_id=prior_ids[position],
+                )
+            )
+        return self._generate_requests(requests, response_format=response_format)
+
+    def generate_response(
+        self,
+        prompt: str | list[str] | None = None,
+        messages: Messages | list[Messages] | None = None,
+        metadata: dict[str, Any] | None = None,
+        previous_response_id: str | None = None,
+    ) -> NormalizedResponse | list[NormalizedResponse]:
+        """Return normalized response objects instead of parsed text.
+
+        No response format is applied. A single input yields a single
+        ``NormalizedResponse``; batched input yields an ordered list.
+        """
+        normalized, is_single = self._normalize_inputs(
+            prompt=prompt,
+            messages=messages,
+            metadata=metadata,
+            previous_response_id=previous_response_id,
+            response_format=None,
+        )
+        normalized_responses = self._generate_normalized_responses(
+            normalized,
+            response_format=None,
+        )
+        return normalized_responses[0] if is_single else normalized_responses
+
+    def generate_batch_response(
+        self,
+        messages: list[Messages],
+        *,
+        metadata: list[dict[str, Any] | None] | dict[str, Any] | None = None,
+        previous_response_ids: list[str | None] | None = None,
+    ) -> list[NormalizedResponse]:
+        """Return normalized responses for pre-built message lists, in order.
+
+        Raises:
+            ValueError: If ``previous_response_ids`` is given with a length
+                different from ``messages``.
+        """
+        if not messages:
+            return []
+
+        total = len(messages)
+        per_request_metadata = _normalize_metadata(metadata, total)
+        prior_ids = previous_response_ids or [None] * total
+        if len(prior_ids) != total:
+            raise ValueError("previous_response_ids length must match messages length")
+
+        requests = []
+        for position, message_list in enumerate(messages):
+            prepared = self._prepare_messages(message_list, response_format=None)
+            requests.append(
+                NormalizedRequest(
+                    messages=prepared,
+                    metadata=per_request_metadata[position],
+                    previous_response_id=prior_ids[position],
+                )
+            )
+        return self._generate_normalized_responses(requests, response_format=None)
+
+    def _generate_requests(
+        self,
+        requests: list[NormalizedRequest],
+        *,
+        response_format: type[T] | None,
+    ) -> list[str] | list[T]:
+        """Execute the requests and parse each normalized response in order."""
+        normalized_responses = self._generate_normalized_responses(
+            requests,
+            response_format=response_format,
+        )
+        parsed = []
+        for item in normalized_responses:
+            parsed.append(
+                self._parse_response(item, response_format=response_format)
+            )
+        return parsed
+
+    def _generate_normalized_responses(
+        self,
+        requests: list[NormalizedRequest],
+        *,
+        response_format: type[T] | None,
+    ) -> list[NormalizedResponse]:
+        """Execute requests and return their responses in request order.
+
+        A single request is executed directly. Multiple requests use LiteLLM
+        batching when this target runs on the chat endpoint and its
+        capabilities declare LiteLLM batch support; otherwise a warning is
+        issued and the requests run individually on a bounded thread pool.
+        """
+        if not requests:
+            return []
+
+        def run_one(request: NormalizedRequest) -> NormalizedResponse:
+            return self._execute_single(request, response_format=response_format)
+
+        if len(requests) == 1:
+            return [run_one(requests[0])]
+
+        uses_native_batch = (
+            self.endpoint_mode == EndpointMode.CHAT
+            and self.capabilities.batch_mode == BatchMode.LITELLM_BATCH
+        )
+        if uses_native_batch:
+            return self._execute_litellm_batch(
+                requests,
+                response_format=response_format,
+            )
+
+        fallback_notice = (
+            f"{self.provider_name}/{self.model_id} does not expose native "
+            "same-target batching for this endpoint. Falling back to bounded "
+            "parallel single requests."
+        )
+        warnings.warn(fallback_notice, UserWarning, stacklevel=2)
+
+        # At least one worker, at most one per request, capped by the config.
+        worker_count = max(1, min(self.config.max_concurrent, len(requests)))
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
+            # map() yields results in submission order.
+            return list(pool.map(run_one, requests))
+
+    def _execute_single(
+        self,
+        request: NormalizedRequest,
+        *,
+        response_format: type[T] | None,
+    ) -> NormalizedResponse:
+        """Send one request through LiteLLM and normalize the reply.
+
+        The Responses endpoint is used when this target resolved to it;
+        every other mode goes through chat completions.
+        """
+        if self.endpoint_mode != EndpointMode.RESPONSES:
+            chat_params = self._build_chat_params(request, response_format)
+            completion = self._call_litellm(
+                litellm.completion,
+                chat_params,
+                request_count=1,
+            )
+            return NormalizedResponse(
+                text=_extract_chat_text(completion),
+                raw=completion,
+                reasoning_content=_extract_chat_reasoning_content(completion),
+                thinking_blocks=_extract_chat_thinking_blocks(completion),
+                images=_extract_chat_images(completion),
+                audio=_extract_chat_audio(completion),
+            )
+
+        responses_params = self._build_responses_params(request, response_format)
+        reply = self._call_litellm(
+            litellm.responses,
+            responses_params,
+            request_count=1,
+        )
+        return NormalizedResponse(
+            text=_extract_responses_text(reply),
+            raw=reply,
+            reasoning_content=_extract_responses_reasoning(reply),
+            images=_extract_responses_images(reply),
+            audio=_extract_responses_audio(reply),
+            output_items=_extract_responses_output_items(reply),
+        )
+
+    def _execute_litellm_batch(
+        self,
+        requests: list[NormalizedRequest],
+        *,
+        response_format: type[T] | None,
+    ) -> list[NormalizedResponse]:
+        """Run the requests through ``litellm.batch_completion``.
+
+        Chat parameters are built once from a placeholder request carrying
+        the combined batch metadata; its ``messages`` entry is then replaced
+        by the list of per-request message lists.
+
+        Raises:
+            RuntimeError: If any item in the batch reply is an exception.
+        """
+        placeholder = NormalizedRequest(
+            messages=[],
+            metadata=_combine_batch_metadata(requests),
+        )
+        params = self._build_chat_params(placeholder, response_format)
+        params["messages"] = [request.messages for request in requests]
+
+        reply = self._call_litellm(
+            litellm.batch_completion,
+            params,
+            request_count=len(requests),
+        )
+        items = reply if isinstance(reply, list) else list(reply)
+
+        normalized: list[NormalizedResponse] = []
+        for position, item in enumerate(items):
+            if isinstance(item, Exception):
+                # One failed item aborts the whole batch.
+                raise RuntimeError(f"Batch item {position} failed: {item}") from item
+            normalized.append(
+                NormalizedResponse(
+                    text=_extract_chat_text(item),
+                    raw=item,
+                    reasoning_content=_extract_chat_reasoning_content(item),
+                    thinking_blocks=_extract_chat_thinking_blocks(item),
+                    images=_extract_chat_images(item),
+                    audio=_extract_chat_audio(item),
+                )
+            )
+        return normalized
+
+    def _build_chat_params(
+        self,
+        request: NormalizedRequest,
+        response_format: type[T] | None,
+    ) -> dict[str, Any]:
+        """Assemble the keyword arguments for a chat-completions call.
+
+        Parameters are layered in a fixed order: model/messages/metadata,
+        the optional previous response id, transport settings, common
+        generation settings, structured output, and finally the configured
+        provider params, which therefore override anything set earlier.
+        Keys whose value is ``None`` are dropped from the result.
+        """
+        chat = EndpointMode.CHAT
+        params: dict[str, Any] = {}
+        params["model"] = self._get_model_string()
+        params["messages"] = request.messages
+        params["metadata"] = self._build_request_metadata(request.metadata)
+
+        prior_id = request.previous_response_id
+        if prior_id is not None:
+            self._add_supported_param(
+                params,
+                "previous_response_id",
+                prior_id,
+                endpoint=chat,
+            )
+
+        self._add_transport_params(params, endpoint=chat)
+        self._add_common_generation_params(params, endpoint=chat)
+        self._add_chat_structured_output(params, response_format)
+        params.update(self.config.provider_params)
+        return _without_none(params)
+
+    def _build_responses_params(
+        self,
+        request: NormalizedRequest,
+        response_format: type[T] | None,
+    ) -> dict[str, Any]:
+        """Assemble the keyword arguments for a Responses-endpoint call.
+
+        Mirrors the chat builder, except the messages travel under ``"input"``
+        and structured output is attached via the Responses-specific helper.
+        ``config.provider_params`` is merged last and ``None`` values are
+        dropped from the returned dict.
+        """
+        mode = EndpointMode.RESPONSES
+        responses_params: dict[str, Any] = {
+            "model": self._get_model_string(),
+            "input": request.messages,
+            "metadata": self._build_request_metadata(request.metadata),
+        }
+
+        response_id = request.previous_response_id
+        if response_id is not None:
+            self._add_supported_param(
+                responses_params,
+                "previous_response_id",
+                response_id,
+                endpoint=mode,
+            )
+
+        self._add_transport_params(responses_params, endpoint=mode)
+        self._add_common_generation_params(responses_params, endpoint=mode)
+        self._add_responses_structured_output(responses_params, response_format)
+        responses_params.update(self.config.provider_params)
+        return _without_none(responses_params)
+
+    def _add_common_generation_params(
+        self,
+        params: dict[str, Any],
+        *,
+        endpoint: EndpointMode,
+    ) -> None:
+        """Add temperature, token limit and reasoning effort to ``params``.
+
+        The token limit is emitted as ``max_output_tokens`` on the Responses
+        endpoint and ``max_completion_tokens`` otherwise. Reasoning is skipped
+        entirely when ``config.thinking`` is ``False``; when it is ``True`` and
+        no effort is configured, the effort defaults to ``"low"``. On the
+        Responses endpoint the effort is wrapped as ``{"effort": ...}`` under
+        the ``reasoning`` key.
+        """
+        on_responses = endpoint == EndpointMode.RESPONSES
+        settings = self.config
+
+        self._add_supported_param(
+            params,
+            "temperature",
+            settings.temperature,
+            endpoint=endpoint,
+        )
+        self._add_supported_param(
+            params,
+            "max_completion_tokens",
+            settings.max_completion_tokens,
+            endpoint=endpoint,
+            target_name="max_output_tokens" if on_responses else "max_completion_tokens",
+        )
+
+        thinking = settings.thinking
+        if thinking is False:
+            return
+
+        effort = settings.reasoning_effort
+        if thinking is True and effort is None:
+            effort = "low"
+
+        if on_responses and effort is not None:
+            self._add_supported_param(
+                params,
+                "reasoning_effort",
+                {"effort": effort},
+                endpoint=endpoint,
+                target_name="reasoning",
+            )
+        else:
+            # A None effort is ignored by _add_supported_param.
+            self._add_supported_param(
+                params,
+                "reasoning_effort",
+                effort,
+                endpoint=endpoint,
+            )
+
+    def _add_chat_structured_output(
+        self,
+        params: dict[str, Any],
+        response_format: type[T] | None,
+    ) -> None:
+        """Configure chat ``params`` for structured output, if requested.
+
+        Behaviour follows ``capabilities.structured_output``: JSON_SCHEMA passes
+        the response class through, JSON_OBJECT requests a JSON object (and sets
+        ``format="json"`` for the ``ollama`` provider), PROMPTED_JSON only emits
+        a ``UserWarning``.
+
+        Raises:
+            ValueError: for any other structured-output mode.
+        """
+        if response_format is None:
+            return
+
+        mode = self.capabilities.structured_output
+        if mode == StructuredOutputMode.JSON_SCHEMA:
+            params["response_format"] = response_format
+            return
+
+        if mode == StructuredOutputMode.JSON_OBJECT:
+            params["response_format"] = {"type": "json_object"}
+            if self.provider_name == "ollama":
+                params["format"] = "json"
+            return
+
+        if mode != StructuredOutputMode.PROMPTED_JSON:
+            raise ValueError(
+                f"{self.provider_name}/{self.model_id} does not support structured output"
+            )
+
+        warnings.warn(
+            (
+                f"{self.provider_name}/{self.model_id} has no declared native "
+                "schema support. Using prompted JSON plus Pydantic validation."
+            ),
+            UserWarning,
+            stacklevel=3,
+        )
+
+    def _add_responses_structured_output(
+        self,
+        params: dict[str, Any],
+        response_format: type[T] | None,
+    ) -> None:
+        """Attach ``response_format`` as ``text_format`` for a Responses call.
+
+        Does nothing when no response format was requested.
+
+        Raises:
+            ValueError: if the target's structured-output mode is not
+                JSON_SCHEMA.
+        """
+        if response_format is None:
+            return
+
+        if self.capabilities.structured_output == StructuredOutputMode.JSON_SCHEMA:
+            params["text_format"] = response_format
+            return
+
+        raise ValueError(
+            f"{self.provider_name}/{self.model_id} does not support native "
+            "Responses structured output"
+        )
+
+    def _add_transport_params(
+        self,
+        params: dict[str, Any],
+        *,
+        endpoint: EndpointMode,
+    ) -> None:
+        """Add timeout, API base URL and API key to ``params``.
+
+        An explicit ``self.api_key`` wins. Otherwise, when an environment
+        variable name is configured and the target does not declare
+        ``no_api_key``, the key is read from that variable.
+
+        Raises:
+            ValueError: if the key must come from the environment but the
+                variable is unset or empty.
+        """
+        timeout = self.config.timeout
+        if timeout is not None:
+            self._add_supported_param(params, "timeout", timeout, endpoint=endpoint)
+
+        base_url = self.api_base_url
+        if base_url is not None:
+            params["api_base"] = base_url
+
+        if self.api_key is not None:
+            params["api_key"] = self.api_key
+            return
+
+        if not self.env_key_name or self.capabilities.no_api_key:
+            return
+
+        key_from_env = os.getenv(self.env_key_name)
+        if not key_from_env:
+            raise ValueError(
+                f"{self.env_key_name} environment variable not set. "
+                "Set it or provide api_key when initializing the provider."
+            )
+        params["api_key"] = key_from_env
+
+    def _add_supported_param(
+        self,
+        params: dict[str, Any],
+        source_name: str,
+        value: Any,
+        *,
+        endpoint: EndpointMode,
+        target_name: str | None = None,
+    ) -> None:
+        """Store ``value`` in ``params`` if the resolved target supports it.
+
+        ``source_name`` is the name checked against the capabilities;
+        ``target_name`` (defaulting to ``source_name``) is the key written.
+        ``None`` values are ignored. Unsupported parameters are omitted; the
+        unsupported-parameter policy is only invoked for parameters that were
+        configured explicitly, for ``previous_response_id``, or for reasoning
+        effort when ``config.thinking`` is ``True``.
+        """
+        if value is None:
+            return
+
+        if source_name not in self.capabilities.supported_params:
+            explicitly_requested = (
+                source_name in self._configured_common_params
+                or source_name == "previous_response_id"
+                or (source_name == "reasoning_effort" and self.config.thinking is True)
+            )
+            if explicitly_requested:
+                self._handle_unsupported_param(source_name)
+            return
+
+        if (
+            source_name == "reasoning_effort" and not self.capabilities.supports_reasoning
+        ) or (
+            endpoint == EndpointMode.RESPONSES
+            and not self.capabilities.supports_endpoint(EndpointMode.RESPONSES)
+        ):
+            self._handle_unsupported_param(source_name)
+            return
+
+        params[target_name or source_name] = value
+
+    def _handle_unsupported_param(self, name: str) -> None:
+        """Apply ``config.unsupported_params`` to an omitted parameter.
+
+        FAIL raises ``ValueError``, WARN emits a ``UserWarning``, and any other
+        policy does nothing.
+        """
+        policy = self.config.unsupported_params
+        notice = (
+            f"Parameter '{name}' is not supported by resolved target "
+            f"{self.provider_name}/{self.model_id} and will be omitted."
+        )
+        if policy == UnsupportedParamsPolicy.FAIL:
+            raise ValueError(notice)
+        if policy == UnsupportedParamsPolicy.WARN:
+            warnings.warn(notice, UserWarning, stacklevel=3)
+
+    def _normalize_inputs(
+        self,
+        *,
+        prompt: str | list[str] | None,
+        messages: Messages | list[Messages] | None,
+        metadata: dict[str, Any] | None,
+        previous_response_id: str | None,
+        response_format: type[T] | None,
+    ) -> tuple[list[NormalizedRequest], bool]:
+        """Turn ``prompt`` or ``messages`` into a list of normalized requests.
+
+        Exactly one of ``prompt`` and ``messages`` must be given. A string
+        prompt or a single conversation counts as single input; a list of
+        prompts or a list of conversations counts as a batch.
+
+        Returns:
+            ``(requests, single_input)`` where ``single_input`` is ``True`` when
+            the caller passed one prompt/conversation rather than a batch.
+
+        Raises:
+            ValueError: if both or neither input is given, if the input is
+                empty, or if its shape is not recognised.
+        """
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided")
+        if prompt is not None and messages is not None:
+            raise ValueError("Provide either prompt or messages, not both")
+
+        single_input = False
+        batch_messages: list[Messages]
+
+        if prompt is not None:
+            if isinstance(prompt, str):
+                batch_messages = [[{"role": "user", "content": prompt}]]
+                single_input = True
+            elif isinstance(prompt, list) and all(isinstance(item, str) for item in prompt):
+                if not prompt:
+                    raise ValueError("prompt list cannot be empty")
+                batch_messages = [
+                    [{"role": "user", "content": item}]
+                    for item in prompt
+                ]
+            else:
+                raise ValueError("prompt must be a string or list of strings")
+        elif isinstance(messages, list) and not messages:
+            # Checked before the shape predicates: both of them reject an empty
+            # list, which previously surfaced as "Invalid messages format".
+            raise ValueError("messages cannot be empty")
+        elif _is_single_messages(messages):
+            batch_messages = [messages]  # type: ignore[list-item]
+            single_input = True
+        elif _is_batch_messages(messages):
+            batch_messages = messages  # type: ignore[assignment]
+        else:
+            raise ValueError("Invalid messages format")
+
+        return (
+            [
+                NormalizedRequest(
+                    messages=self._prepare_messages(
+                        item,
+                        response_format=response_format,
+                    ),
+                    metadata=metadata,
+                    previous_response_id=previous_response_id,
+                )
+                for item in batch_messages
+            ],
+            single_input,
+        )
+
+    def _prepare_messages(
+        self,
+        messages: Messages,
+        *,
+        response_format: type[T] | None,
+    ) -> Messages:
+        """Return a normalized, validated copy of one conversation.
+
+        The input is deep-copied before normalization, so the caller's messages
+        are never mutated. When structured output is requested and the target
+        relies on JSON_OBJECT or PROMPTED_JSON mode, JSON instructions are
+        appended to the copy.
+
+        Raises:
+            ValueError: if ``messages`` is empty (or a message/modality is
+                rejected by the helpers).
+        """
+        if not messages:
+            raise ValueError("messages cannot be empty")
+
+        prepared = []
+        for original in copy.deepcopy(messages):
+            prepared.append(_normalize_message(original))
+        self._validate_modalities(prepared)
+
+        needs_prompted_json = (
+            response_format is not None
+            and self.capabilities.structured_output
+            in {
+                StructuredOutputMode.JSON_OBJECT,
+                StructuredOutputMode.PROMPTED_JSON,
+            }
+        )
+        if needs_prompted_json:
+            _append_json_instructions(prepared)
+
+        return prepared
+
+    def _validate_modalities(self, messages: Messages) -> None:
+        """Reject content parts whose modality the target does not declare.
+
+        Only messages whose ``content`` is a list of parts are inspected.
+
+        Raises:
+            ValueError: on the first part with an unsupported modality.
+        """
+        allowed = self.capabilities.modalities
+        for message in messages:
+            parts = message.get("content")
+            if isinstance(parts, list):
+                for part in parts:
+                    kind = _modality_for_part(part)
+                    if kind in allowed:
+                        continue
+                    raise ValueError(
+                        f"Modality '{kind.value}' is not supported by "
+                        f"{self.provider_name}/{self.model_id}"
+                    )
+
+    def _resolve_endpoint_mode(self, endpoint_mode: EndpointMode) -> EndpointMode:
+        """Resolve AUTO to the target's default endpoint and validate the rest.
+
+        Raises:
+            ValueError: if an explicit endpoint mode is not supported by the
+                target's capabilities.
+        """
+        caps = self.capabilities
+        if endpoint_mode == EndpointMode.AUTO:
+            return caps.default_endpoint_mode
+        if caps.supports_endpoint(endpoint_mode):
+            return endpoint_mode
+        raise ValueError(
+            f"{self.provider_name}/{self.model_id} does not support "
+            f"endpoint_mode='{endpoint_mode.value}'"
+        )
+
+ # Invoke a LiteLLM callable with retries, falling back once to drop_params=True.
+ #
+ # `func` is called as `func(**params)` through `_call_with_retries`;
+ # `request_count` is forwarded to it unchanged. If the call raises and
+ # `_should_retry_with_drop_params` approves, the call is repeated with a copy
+ # of `params` that sets `drop_params=True`. Any other exception propagates.
+ def _call_litellm(self, func, params: dict[str, Any], *, request_count: int) -> Any:
+ try:
+ return self._call_with_retries(
+ lambda: func(**params),
+ request_count=request_count,
+ )
+ except Exception as exc:
+ # Only unsupported-parameter rejections qualify for the fallback.
+ if not self._should_retry_with_drop_params(exc, params):
+ raise
+
+ # Copy so the caller's params dict is left untouched.
+ retry_params = dict(params)
+ retry_params["drop_params"] = True
+ # The warning is emitted only under the WARN policy.
+ if self.config.unsupported_params == UnsupportedParamsPolicy.WARN:
+ warnings.warn(
+ (
+ "LiteLLM rejected one or more request parameters as "
+ "unsupported. Retrying once with drop_params=True because "
+ f"unsupported_params='{self.config.unsupported_params.value}'."
+ ),
+ UserWarning,
+ stacklevel=3,
+ )
+ return self._call_with_retries(
+ lambda: func(**retry_params),
+ request_count=request_count,
+ )
+
+    def _should_retry_with_drop_params(
+        self,
+        exc: Exception,
+        params: dict[str, Any],
+    ) -> bool:
+        """Decide whether a failed call may be repeated with ``drop_params``.
+
+        Never under the FAIL policy, never when ``drop_params`` was already
+        ``True``; otherwise only for unsupported-parameter errors.
+        """
+        strict = self.config.unsupported_params == UnsupportedParamsPolicy.FAIL
+        if strict or params.get("drop_params") is True:
+            return False
+        return _is_unsupported_params_error(exc)
+
+ # Call `func()` under the configured retry policy and return its result.
+ #
+ # Each attempt first waits for rate-limit capacity for `request_count`
+ # requests. Requests are recorded against the rate limit only after a
+ # successful call. Non-retryable errors, and the error from the final
+ # attempt, are re-raised unchanged.
+ def _call_with_retries(self, func, *, request_count: int) -> Any:
+ retry_policy = self.config.retry_policy
+ # `max_retries` is used as the total number of attempts, with a floor of one.
+ attempts = max(1, retry_policy.max_retries)
+
+ for attempt in range(attempts):
+ self._respect_rate_limit(request_count)
+ try:
+ response = func()
+ self._record_requests(request_count)
+ return response
+ except Exception as exc:
+ if attempt >= attempts - 1 or not _is_retryable_error(exc):
+ raise
+ # Exponential backoff: base_delay doubles per attempt, capped at max_delay.
+ delay = min(
+ retry_policy.max_delay,
+ retry_policy.base_delay * (2 ** attempt),
+ )
+ # Jitter is added after the cap, so the wait can exceed max_delay
+ # by up to `delay * jitter`.
+ if retry_policy.jitter > 0:
+ delay += random.uniform(0, delay * retry_policy.jitter)
+ logger.warning(
+ "Retryable LLM error | Provider: {} | Model: {} | "
+ "Attempt: {}/{} | Waiting: {:.2f}s | Error: {}",
+ self.provider_name,
+ self.model_id,
+ attempt + 1,
+ attempts,
+ delay,
+ exc,
+ )
+ self._sleep(delay)
+
+ # Defensive: every loop iteration either returns or raises.
+ raise RuntimeError("unreachable retry state")
+
+    def _respect_rate_limit(self, request_count: int = 1) -> None:
+        """Block until ``request_count`` more requests fit in the RPM window.
+
+        Does nothing when ``config.rpm_limit`` is ``None``. Otherwise recorded
+        timestamps older than 60 seconds are discarded and, while the remaining
+        ones plus ``request_count`` exceed the limit, the method sleeps until
+        the oldest timestamp leaves the window. The lock is held while waiting.
+
+        If ``request_count`` alone exceeds ``rpm_limit``, the call is let
+        through once the window is empty, because no amount of waiting could
+        make it fit.
+        """
+        if self.config.rpm_limit is None:
+            return
+
+        with self._rate_lock:
+            now = time.monotonic()
+            self._request_timestamps = [
+                timestamp
+                for timestamp in self._request_timestamps
+                if now - timestamp < 60
+            ]
+
+            while len(self._request_timestamps) + request_count > self.config.rpm_limit:
+                if not self._request_timestamps:
+                    # Nothing left to wait for: indexing [0] here used to raise
+                    # IndexError for batches larger than rpm_limit.
+                    break
+                earliest = self._request_timestamps[0]
+                sleep_time = max(0.0, 60 - (now - earliest))
+                if sleep_time > 0:
+                    logger.warning(
+                        "Rate limit reached | Provider: {} | Model: {} | "
+                        "Waiting {:.2f}s",
+                        self.provider_name,
+                        self.model_id,
+                        sleep_time,
+                    )
+                    self._sleep(sleep_time)
+                now = time.monotonic()
+                self._request_timestamps = [
+                    timestamp
+                    for timestamp in self._request_timestamps
+                    if now - timestamp < 60
+                ]
+
+    def _record_requests(self, request_count: int = 1) -> None:
+        """Record ``request_count`` requests at the current monotonic time.
+
+        Does nothing when no RPM limit is configured.
+        """
+        if self.config.rpm_limit is not None:
+            with self._rate_lock:
+                stamp = time.monotonic()
+                self._request_timestamps.extend(stamp for _ in range(request_count))
+
+    def _parse_response(
+        self,
+        response: NormalizedResponse,
+        *,
+        response_format: type[T] | None,
+    ) -> str | T:
+        """Convert a normalized response into text or a validated model.
+
+        Without ``response_format`` the stripped text is returned (an empty or
+        ``None`` text is returned as-is). With one, a pre-parsed
+        ``output_parsed`` object on the raw response is preferred; otherwise
+        the text is stripped of Markdown code fences and validated as JSON via
+        ``response_format.model_validate_json``.
+
+        Raises:
+            ValueError: if the content cannot be validated against
+                ``response_format``. The message includes the first 200
+                characters of the content.
+        """
+        if response_format is None:
+            return response.text.strip() if response.text else response.text
+
+        parsed = getattr(response.raw, "output_parsed", None)
+        if parsed is not None:
+            return parsed
+
+        # Coerce a missing text to "" so that an empty reply fails validation
+        # with the ValueError below instead of a TypeError from len(None) while
+        # the preview is being built.
+        content = self._strip_code_fences(response.text or "")
+        try:
+            return response_format.model_validate_json(content)
+        except Exception as validation_error:
+            content_preview = (
+                content[:200] + "..." if len(content) > 200 else content
+            )
+            raise ValueError(
+                f"Failed to parse JSON response into {response_format.__name__}.\n"
+                f"Validation error: {validation_error}\n"
+                f"Content received (first 200 chars):\n{content_preview}"
+            ) from validation_error
+
+    def _build_request_metadata(
+        self,
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Build trace metadata for a provider call via ``build_trace_metadata``.
+
+        The trace is named ``datafast.<provider_name>`` and tagged with the
+        ``provider.generate`` component; ``metadata`` is forwarded unchanged.
+        """
+        trace_name = f"datafast.{self.provider_name}"
+        return build_trace_metadata(
+            model=self,
+            component="provider.generate",
+            trace_name=trace_name,
+            metadata=metadata,
+        )
+
+    def _get_model_string(self) -> str:
+        """Return the model id prefixed with ``<litellm_provider>/``.
+
+        A model id that already carries that prefix is returned unchanged.
+        """
+        model = self.model_id
+        provider_prefix = f"{self.config.litellm_provider}/"
+        return model if model.startswith(provider_prefix) else f"{provider_prefix}{model}"
+
+ # Remove a surrounding Markdown code fence from `content`.
+ # Falsy input (empty string or None) is returned unchanged.
+ @staticmethod
+ def _strip_code_fences(content: str) -> str:
+ if not content:
+ return content
+
+ content = content.strip()
+ if content.startswith("```"):
+ # Drop the whole opening fence line (which may carry a language tag);
+ # with no newline at all, drop just the three backticks.
+ first_newline = content.find("\n")
+ content = content[first_newline + 1 :] if first_newline != -1 else content[3:]
+ # Drop the closing fence.
+ if content.endswith("```"):
+ content = content[:-3]
+ return content.strip()
+
+
+class OpenAIProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``openai``."""
+
+    def __init__(self, model_id: str = "gpt-5.5", **kwargs: Any) -> None:
+        """Initialise the base provider with the OpenAI LiteLLM prefix.
+
+        ``OPENAI_API_KEY`` is passed as the API-key environment variable name;
+        remaining ``kwargs`` are forwarded to ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "openai", "env_key_name": "OPENAI_API_KEY"}
+        super().__init__("openai", model_id, **preset, **kwargs)
+
+
+class AnthropicProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``anthropic``."""
+
+    def __init__(self, model_id: str = "claude-haiku-4-5", **kwargs: Any) -> None:
+        """Initialise the base provider with the Anthropic LiteLLM prefix.
+
+        ``ANTHROPIC_API_KEY`` is passed as the API-key environment variable
+        name; remaining ``kwargs`` are forwarded to ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "anthropic", "env_key_name": "ANTHROPIC_API_KEY"}
+        super().__init__("anthropic", model_id, **preset, **kwargs)
+
+
+class GeminiProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``gemini``."""
+
+    def __init__(self, model_id: str = "gemini-3.1-flash-lite", **kwargs: Any) -> None:
+        """Initialise the base provider with the Gemini LiteLLM prefix.
+
+        ``GEMINI_API_KEY`` is passed as the API-key environment variable name;
+        remaining ``kwargs`` are forwarded to ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "gemini", "env_key_name": "GEMINI_API_KEY"}
+        super().__init__("gemini", model_id, **preset, **kwargs)
+
+
+class MistralProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``mistral``."""
+
+    def __init__(self, model_id: str = "mistral-small-2603", **kwargs: Any) -> None:
+        """Initialise the base provider with the Mistral LiteLLM prefix.
+
+        ``MISTRAL_API_KEY`` is passed as the API-key environment variable name;
+        remaining ``kwargs`` are forwarded to ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "mistral", "env_key_name": "MISTRAL_API_KEY"}
+        super().__init__("mistral", model_id, **preset, **kwargs)
+
+
+class OpenRouterProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``openrouter``."""
+
+    def __init__(self, model_id: str = "openai/gpt-5.4-mini", **kwargs: Any) -> None:
+        """Initialise the base provider with the OpenRouter LiteLLM prefix.
+
+        ``OPENROUTER_API_KEY`` is passed as the API-key environment variable
+        name; remaining ``kwargs`` are forwarded to ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "openrouter", "env_key_name": "OPENROUTER_API_KEY"}
+        super().__init__("openrouter", model_id, **preset, **kwargs)
+
+
+class OllamaProvider(LLMProvider):
+    """``LLMProvider`` preset registered under the provider name ``ollama``."""
+
+    def __init__(self, model_id: str = "gemma3:4b", **kwargs: Any) -> None:
+        """Initialise the base provider with the ``ollama_chat`` LiteLLM prefix.
+
+        No API-key environment variable name is configured
+        (``env_key_name=None``); remaining ``kwargs`` are forwarded to
+        ``LLMProvider``.
+        """
+        preset = {"litellm_provider": "ollama_chat", "env_key_name": None}
+        super().__init__("ollama", model_id, **preset, **kwargs)
+
+
+class OpenAICompatibleProvider(LLMProvider):
+    """``LLMProvider`` for servers that speak an OpenAI-compatible API."""
+
+    def __init__(
+        self,
+        model_id: str,
+        *,
+        provider: str = "openai_compatible",
+        litellm_provider: str = "openai",
+        env_key_name: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialise the base provider for an OpenAI-compatible backend.
+
+        Args:
+            model_id: Model identifier; required, there is no default.
+            provider: Provider name handed to ``LLMProvider``.
+            litellm_provider: LiteLLM provider prefix; defaults to ``openai``.
+            env_key_name: API-key environment variable name, if any.
+            **kwargs: Forwarded to ``LLMProvider``.
+        """
+        base_options = {
+            "litellm_provider": litellm_provider,
+            "env_key_name": env_key_name,
+        }
+        super().__init__(provider, model_id, **base_options, **kwargs)
+
+
+def openai(model_id: str = "gpt-5.5", **kwargs: Any) -> OpenAIProvider:
return OpenAIProvider(model_id=model_id, **kwargs)
-def anthropic(
- model_id: str = "claude-haiku-4-5-20251001",
- **kwargs,
-) -> AnthropicProvider:
- """Create an Anthropic provider instance."""
+def anthropic(model_id: str = "claude-haiku-4-5", **kwargs: Any) -> AnthropicProvider:
return AnthropicProvider(model_id=model_id, **kwargs)
-def gemini(model_id: str = "gemini-2.0-flash", **kwargs) -> GeminiProvider:
- """Create a Gemini provider instance."""
+def gemini(model_id: str = "gemini-3.1-flash-lite", **kwargs: Any) -> GeminiProvider:
return GeminiProvider(model_id=model_id, **kwargs)
-def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider:
- """Create a Mistral provider instance."""
+def mistral(model_id: str = "mistral-small-2603", **kwargs: Any) -> MistralProvider:
return MistralProvider(model_id=model_id, **kwargs)
def openrouter(
- model_id: str = "openai/gpt-5-mini",
- **kwargs,
+ model_id: str = "openai/gpt-5.4-mini",
+ **kwargs: Any,
) -> OpenRouterProvider:
- """Create an OpenRouter provider instance."""
return OpenRouterProvider(model_id=model_id, **kwargs)
-def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider:
- """Create an Ollama provider instance."""
+def ollama(model_id: str = "gemma3:4b", **kwargs: Any) -> OllamaProvider:
return OllamaProvider(model_id=model_id, **kwargs)
+def openai_compatible(
+    model_id: str,
+    *,
+    api_base_url: str | None = None,
+    backend: str = "openai_compatible",
+    **kwargs: Any,
+) -> OpenAICompatibleProvider:
+    """Create an ``OpenAICompatibleProvider`` for the named backend.
+
+    ``backend`` is normalised to a canonical backend name, which becomes the
+    provider name. ``api_base_url`` and any extra ``kwargs`` are forwarded to
+    the provider constructor.
+
+    Raises:
+        ValueError: if ``backend`` is not a recognised backend name.
+    """
+    return OpenAICompatibleProvider(
+        model_id=model_id,
+        provider=_normalize_openai_compatible_backend(backend),
+        api_base_url=api_base_url,
+        **kwargs,
+    )
+
+
+def _normalize_openai_compatible_backend(value: str) -> str:
+    """Map a user-supplied backend name to its canonical form.
+
+    Matching ignores surrounding whitespace and case, and treats ``-`` as
+    ``_``, so ``"OpenAI-Compatible"`` and ``"llama-cpp"`` are accepted.
+
+    Returns:
+        One of ``"openai_compatible"``, ``"llamacpp"`` or ``"vllm"``.
+
+    Raises:
+        ValueError: if the name is not a known backend or alias.
+    """
+    normalized = value.strip().lower().replace("-", "_")
+    # Keys are in normalised form. Hyphens are already replaced above, so a
+    # hyphenated key such as "openai-compatible" could never match.
+    aliases = {
+        "openai_compatible": "openai_compatible",
+        "llama.cpp": "llamacpp",
+        "llama_cpp": "llamacpp",
+        "llamacpp": "llamacpp",
+        "vllm": "vllm",
+    }
+    try:
+        return aliases[normalized]
+    except KeyError as exc:
+        valid = ", ".join(sorted(set(aliases.values())))
+        raise ValueError(
+            f"Unsupported OpenAI-compatible backend '{value}'. Choose: {valid}"
+        ) from exc
+
+
+def _coerce_endpoint_mode(value: str | EndpointMode) -> EndpointMode:
+    """Return ``value`` as an ``EndpointMode`` member.
+
+    Raises:
+        ValueError: if ``value`` is not a valid endpoint mode value.
+    """
+    if not isinstance(value, EndpointMode):
+        try:
+            value = EndpointMode(value)
+        except ValueError as exc:
+            raise ValueError("endpoint_mode must be 'auto', 'chat', or 'responses'") from exc
+    return value
+
+
+def _coerce_unsupported_policy(
+    value: str | UnsupportedParamsPolicy,
+) -> UnsupportedParamsPolicy:
+    """Return ``value`` as an ``UnsupportedParamsPolicy`` member.
+
+    Raises:
+        ValueError: if ``value`` is not a valid policy value.
+    """
+    if not isinstance(value, UnsupportedParamsPolicy):
+        try:
+            value = UnsupportedParamsPolicy(value)
+        except ValueError as exc:
+            raise ValueError("unsupported_params must be 'fail', 'warn', or 'quiet'") from exc
+    return value
+
+
+def _normalize_metadata(
+    metadata: list[dict[str, Any] | None] | dict[str, Any] | None,
+    expected_length: int,
+) -> list[dict[str, Any] | None]:
+    """Expand ``metadata`` to one entry per request.
+
+    A list is returned as-is after its length is checked. Anything else is
+    repeated ``expected_length`` times; the repeated entries are the same
+    object, not copies.
+
+    Raises:
+        ValueError: if a metadata list does not have ``expected_length`` items.
+    """
+    if not isinstance(metadata, list):
+        return [metadata] * expected_length
+    if len(metadata) == expected_length:
+        return metadata
+    raise ValueError("metadata length must match messages length")
+
+
+def _combine_batch_metadata(requests: list[NormalizedRequest]) -> dict[str, Any]:
+    """Bundle per-request metadata into a single batch-level mapping.
+
+    The result holds the batch size and the list of each request's metadata,
+    in request order.
+    """
+    return {
+        "datafast_batch_size": len(requests),
+        "datafast_batch_metadata": [entry.metadata for entry in requests],
+    }
+
+
+def _is_single_messages(value: Any) -> bool:
+    """Return True for a non-empty list whose first element is a dict."""
+    if not isinstance(value, list) or not value:
+        return False
+    return isinstance(value[0], dict)
+
+
+def _is_batch_messages(value: Any) -> bool:
+    """Return True for a non-empty list whose first element is a list."""
+    if not isinstance(value, list) or not value:
+        return False
+    return isinstance(value[0], list)
+
+
+def _normalize_message(message: Message) -> Message:
+    """Return a shallow copy of ``message`` with list content normalized.
+
+    String and ``None`` content is kept as-is; each part of a list content is
+    passed through ``_normalize_content_part``.
+
+    Raises:
+        ValueError: if ``message`` is not a dict, or its content is neither a
+            string, a list, nor ``None``.
+    """
+    if not isinstance(message, dict):
+        raise ValueError("Each message must be a dictionary")
+
+    result = dict(message)
+    body = result.get("content")
+    if isinstance(body, list):
+        result["content"] = [_normalize_content_part(piece) for piece in body]
+        return result
+    if body is None or isinstance(body, str):
+        return result
+    raise ValueError("message content must be a string, list of parts, or None")
+
+
+def _normalize_content_part(part: Any) -> dict[str, Any]:
+    """Convert one content part into the wire shape used for requests.
+
+    Parts already in wire form (``image_url``, ``input_audio``, ``video_url``)
+    only have their ``None`` values removed. The ``text``, ``image``,
+    ``audio``, ``video``, ``file`` and ``document`` types go through their
+    dedicated normalizer. Any other type is returned unchanged.
+    """
+    as_dict = _content_part_to_dict(part)
+    kind = as_dict.get("type")
+
+    if kind in {"image_url", "input_audio", "video_url"}:
+        return _without_none(as_dict)
+
+    handler = {
+        "text": _normalize_text_part,
+        "image": _normalize_image_part,
+        "audio": _normalize_audio_part,
+        "video": _normalize_video_part,
+        "file": _normalize_file_part,
+        "document": _normalize_file_part,
+    }.get(kind)
+    return as_dict if handler is None else handler(as_dict)
+
+
+def _content_part_to_dict(part: Any) -> dict[str, Any]:
+    """Return ``part`` as a plain dict.
+
+    A ``ContentPart`` is expanded into its fields, with ``provider_options``
+    merged last so they can override the standard keys. A dict is returned
+    as-is.
+
+    Raises:
+        ValueError: if ``part`` is neither a ``ContentPart`` nor a dict.
+    """
+    if isinstance(part, ContentPart):
+        return {
+            "type": part.type,
+            "text": part.text,
+            "url": part.url,
+            "data": part.data,
+            "media_type": part.media_type,
+            "media_id": part.media_id,
+            **part.provider_options,
+        }
+    if isinstance(part, dict):
+        return part
+    raise ValueError("content parts must be dictionaries or ContentPart objects")
+
+
+def _normalize_text_part(part: dict[str, Any]) -> dict[str, Any]:
+    """Return a ``text`` part holding only ``type`` and a non-None ``text``."""
+    text_part = {"type": "text", "text": part.get("text")}
+    return _without_none(text_part)
+
+
+def _normalize_image_part(part: dict[str, Any]) -> dict[str, Any]:
+    """Convert an ``image`` part into an ``image_url`` part.
+
+    The URL falls back to ``data`` when ``url`` is falsy. ``format`` (or, as a
+    fallback, ``media_type``) and ``detail`` are copied when truthy, and
+    ``media_id`` is exposed as ``uuid``.
+    """
+    payload: dict[str, Any] = {"url": part.get("url") or part.get("data")}
+
+    image_format = part.get("format") or part.get("media_type")
+    if image_format:
+        payload["format"] = image_format
+
+    detail = part.get("detail")
+    if detail:
+        payload["detail"] = detail
+
+    result = {"type": "image_url", "image_url": _without_none(payload)}
+    media_id = part.get("media_id")
+    if media_id:
+        result["uuid"] = media_id
+    return result
+
+
+def _normalize_audio_part(part: dict[str, Any]) -> dict[str, Any]:
+    """Convert an ``audio`` part into an ``input_audio`` part.
+
+    The format is taken from ``format``, then ``media_type``, and defaults to
+    ``"wav"``. A missing ``data`` value is omitted.
+    """
+    audio_format = part.get("format") or part.get("media_type") or "wav"
+    audio = _without_none({"data": part.get("data"), "format": audio_format})
+    return {"type": "input_audio", "input_audio": audio}
+
+
+def _normalize_video_part(part: dict[str, Any]) -> dict[str, Any]:
+    """Convert a ``video`` part into a ``video_url`` part.
+
+    ``url`` is copied even when missing (as ``None``); a truthy ``media_id``
+    is exposed as ``uuid``.
+    """
+    result: dict[str, Any] = {
+        "type": "video_url",
+        "video_url": {"url": part.get("url")},
+    }
+    media_id = part.get("media_id")
+    if media_id:
+        result["uuid"] = media_id
+    return result
+
+
+def _normalize_file_part(part: dict[str, Any]) -> dict[str, Any]:
+    """Convert a ``file``/``document`` part into a ``file`` part.
+
+    A dict under ``file`` is used as the payload directly. Otherwise truthy
+    ``data`` becomes ``file_data``, and failing that ``url`` becomes
+    ``file_id``. ``None`` values are removed from the payload.
+    """
+    embedded = part.get("file")
+    if isinstance(embedded, dict):
+        payload = embedded
+    elif part.get("data"):
+        payload = {"file_data": part.get("data")}
+    else:
+        payload = {"file_id": part.get("url")}
+    return {"type": "file", "file": _without_none(payload)}
+
+
+def _modality_for_part(part: dict[str, Any]) -> Modality:
+    """Return the modality implied by a content part's ``type``.
+
+    Both the library's own type names and the wire names (``image_url``,
+    ``input_audio``, ``video_url``) are recognised. Unknown or missing types
+    are treated as text.
+    """
+    by_type = {
+        "text": Modality.TEXT,
+        "image": Modality.IMAGE,
+        "image_url": Modality.IMAGE,
+        "audio": Modality.AUDIO,
+        "input_audio": Modality.AUDIO,
+        "video": Modality.VIDEO,
+        "video_url": Modality.VIDEO,
+        "file": Modality.FILE,
+        "document": Modality.DOCUMENT,
+    }
+    return by_type.get(part.get("type"), Modality.TEXT)
+
+
+# Append JSON_INSTRUCTIONS to the conversation, mutating `messages` in place.
+# User messages are scanned from last to first; the instructions are added to
+# string content, or to a text part of list content. If no user message can
+# take them, a new user message holding the stripped instructions is appended.
+def _append_json_instructions(messages: Messages) -> None:
+ for message in reversed(messages):
+ if message.get("role") != "user":
+ continue
+ content = message.get("content")
+ if isinstance(content, str):
+ message["content"] = content + JSON_INSTRUCTIONS
+ return
+ if isinstance(content, list):
+ # Parts are scanned from last to first as well.
+ for part in reversed(content):
+ if part.get("type") == "text" and isinstance(part.get("text"), str):
+ part["text"] = part["text"] + JSON_INSTRUCTIONS
+ return
+ # Fallback when no existing user text could take the instructions.
+ messages.append({"role": "user", "content": JSON_INSTRUCTIONS.strip()})
+
+
+def _extract_chat_text(response: Any) -> str:
+    """Return the text of the first choice in a chat response.
+
+    The message content is flattened to text. When the choice has no message,
+    its ``text`` field is used instead (empty string if absent).
+
+    Raises:
+        RuntimeError: if the response has no choices.
+    """
+    first = _get_first_choice(response)
+    if first is None:
+        raise RuntimeError(
+            f"Unexpected chat response from LiteLLM: {type(response).__name__}"
+        )
+
+    message = _get_attr_or_key(first, "message")
+    if message is not None:
+        return _content_to_text(_get_attr_or_key(message, "content"))
+
+    fallback = _get_attr_or_key(first, "text")
+    return str(fallback) if fallback is not None else ""
+
+
+def _extract_chat_reasoning_content(response: Any) -> str | None:
+    """Return the stripped reasoning text of a chat response, if any.
+
+    Looks at the message's ``reasoning_content`` and then ``reasoning``. List
+    values are flattened to text. ``None`` is returned when there is no
+    message, no reasoning, or the reasoning is blank.
+    """
+    message = _extract_chat_message(response)
+    if message is None:
+        return None
+
+    reasoning = None
+    for field in ("reasoning_content", "reasoning"):
+        reasoning = _get_attr_or_key(message, field)
+        if reasoning is not None:
+            break
+    if reasoning is None:
+        return None
+
+    text = _content_to_text(reasoning) if isinstance(reasoning, list) else str(reasoning)
+    return text.strip() or None
+
+
+def _extract_chat_thinking_blocks(response: Any) -> list[dict[str, Any]]:
+    """Return the message's ``thinking_blocks`` as a list of plain dicts.
+
+    A single non-list block is wrapped in a list. An empty list is returned
+    when there is no message or no blocks.
+    """
+    message = _extract_chat_message(response)
+    if message is None:
+        return []
+
+    blocks = _get_attr_or_key(message, "thinking_blocks")
+    if not blocks:
+        return []
+
+    block_list = blocks if isinstance(blocks, list) else [blocks]
+    return [_normalize_mapping_block(entry) for entry in block_list]
+
+
+def _extract_chat_images(response: Any) -> list[dict[str, Any]]:
+    """Collect the images of a chat response as plain dicts.
+
+    Entries from the message's ``images`` field come first, followed by
+    content parts typed ``image``, ``image_url`` or ``output_image``.
+    """
+    message = _extract_chat_message(response)
+    if message is None:
+        return []
+
+    found = list(_normalize_optional_list(_get_attr_or_key(message, "images")))
+    content_parts = _normalize_optional_list(_get_attr_or_key(message, "content"))
+    found.extend(
+        part
+        for part in content_parts
+        if _get_attr_or_key(part, "type") in {"image", "image_url", "output_image"}
+    )
+    return [_normalize_mapping_block(entry) for entry in found]
+
+
+def _extract_chat_audio(response: Any) -> dict[str, Any] | None:
+    """Return the audio payload of a chat response as a plain dict.
+
+    A truthy ``audio`` field on the message wins; otherwise the first content
+    part typed ``audio`` or ``output_audio`` is used. ``None`` when neither
+    exists.
+    """
+    message = _extract_chat_message(response)
+    if message is None:
+        return None
+
+    direct = _get_attr_or_key(message, "audio")
+    if direct:
+        return _normalize_mapping_block(direct)
+
+    parts = _normalize_optional_list(_get_attr_or_key(message, "content"))
+    audio_parts = (
+        part
+        for part in parts
+        if _get_attr_or_key(part, "type") in {"audio", "output_audio"}
+    )
+    for part in audio_parts:
+        return _normalize_mapping_block(part)
+    return None
+
+
+def _extract_chat_message(response: Any) -> Any:
+    """Return the ``message`` of the first choice, or None without choices."""
+    first = _get_first_choice(response)
+    return None if first is None else _get_attr_or_key(first, "message")
+
+
+def _extract_responses_text(response: Any) -> str:
+    """Return the text of a Responses-endpoint result.
+
+    A truthy ``output_text`` is used directly. Otherwise text is concatenated
+    from the ``output`` items: string content as-is, plus the ``text`` of parts
+    typed ``output_text`` or ``text``. Output items that carry no text yield an
+    empty string.
+
+    Raises:
+        RuntimeError: if there is neither ``output_text`` nor any output item.
+    """
+    direct = _get_attr_or_key(response, "output_text")
+    if direct:
+        return str(direct)
+
+    items = _normalize_optional_list(_get_attr_or_key(response, "output"))
+    pieces: list[str] = []
+    for item in items:
+        content = _get_attr_or_key(item, "content") or []
+        if isinstance(content, str):
+            pieces.append(content)
+        else:
+            for part in _normalize_optional_list(content):
+                if _get_attr_or_key(part, "type") not in {"output_text", "text"}:
+                    continue
+                value = _get_attr_or_key(part, "text")
+                if value is not None:
+                    pieces.append(str(value))
+
+    if pieces or items:
+        # Joining an empty list gives "", which covers "output without text".
+        return "".join(pieces)
+    raise RuntimeError(
+        f"Unexpected Responses API response from LiteLLM: {type(response).__name__}"
+    )
+
+
+def _extract_responses_reasoning(response: Any) -> str | None:
+    """Return the reasoning text of a Responses-endpoint result, if any.
+
+    A truthy top-level ``reasoning_content`` is used directly. Otherwise every
+    output item typed ``reasoning`` contributes its ``text``, its ``content``
+    and its ``summary`` (a string, or parts with ``text``/``content``). The
+    non-blank fragments are stripped and joined with newlines.
+    """
+    top_level = _get_attr_or_key(response, "reasoning_content")
+    if top_level:
+        return str(top_level).strip() or None
+
+    fragments: list[str] = []
+    for item in _normalize_optional_list(_get_attr_or_key(response, "output")):
+        if _get_attr_or_key(item, "type") == "reasoning":
+            for field in ("text", "content"):
+                field_value = _get_attr_or_key(item, field)
+                if field_value:
+                    fragments.append(_content_to_text(field_value))
+
+            summary = _get_attr_or_key(item, "summary") or []
+            if isinstance(summary, str):
+                fragments.append(summary)
+            else:
+                for part in _normalize_optional_list(summary):
+                    snippet = _get_attr_or_key(part, "text") or _get_attr_or_key(
+                        part, "content"
+                    )
+                    if snippet:
+                        fragments.append(_content_to_text(snippet))
+
+    cleaned = [fragment.strip() for fragment in fragments if fragment and fragment.strip()]
+    return "\n".join(cleaned) or None
+
+
+def _extract_responses_output_items(response: Any) -> list[dict[str, Any]]:
+    """Return every ``output`` item of a Responses result as a plain dict."""
+    items = _normalize_optional_list(_get_attr_or_key(response, "output") or [])
+    return [_normalize_mapping_block(entry) for entry in items]
+
+
+def _extract_responses_images(response: Any) -> list[dict[str, Any]]:
+    """Collect the images of a Responses result as plain dicts.
+
+    For each output item, the item itself is taken when typed ``image``,
+    ``output_image`` or ``image_generation_call``, followed by any of its
+    content parts typed ``image``, ``image_url`` or ``output_image``.
+    """
+    item_types = {"image", "output_image", "image_generation_call"}
+    part_types = {"image", "image_url", "output_image"}
+
+    found: list[Any] = []
+    for item in _normalize_optional_list(_get_attr_or_key(response, "output")):
+        if _get_attr_or_key(item, "type") in item_types:
+            found.append(item)
+        content_parts = _normalize_optional_list(_get_attr_or_key(item, "content"))
+        found.extend(
+            part for part in content_parts if _get_attr_or_key(part, "type") in part_types
+        )
+    return [_normalize_mapping_block(entry) for entry in found]
+
+
+def _extract_responses_audio(response: Any) -> dict[str, Any] | None:
+    """Return the first audio payload of a Responses result as a plain dict.
+
+    Output items are scanned in order; for each, the item itself is checked
+    before its content parts. Matches are typed ``audio`` or ``output_audio``.
+    ``None`` when nothing matches.
+    """
+    audio_types = {"audio", "output_audio"}
+    for item in _normalize_optional_list(_get_attr_or_key(response, "output")):
+        candidates = [item, *_normalize_optional_list(_get_attr_or_key(item, "content"))]
+        for candidate in candidates:
+            if _get_attr_or_key(candidate, "type") in audio_types:
+                return _normalize_mapping_block(candidate)
+    return None
+
+
+def _get_first_choice(response: Any) -> Any:
+    """Return the first entry of ``choices``, or None if missing or empty."""
+    choices = _get_attr_or_key(response, "choices")
+    return choices[0] if choices else None
+
+
+def _content_to_text(content: Any) -> str:
+    """Flatten message content into a single string.
+
+    ``None`` becomes ``""`` and strings are returned unchanged. For a list,
+    the non-None ``text`` values of its parts are concatenated without a
+    separator. Anything else is converted with ``str``.
+    """
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if not isinstance(content, list):
+        return str(content)
+
+    part_texts = (_get_attr_or_key(part, "text") for part in content)
+    return "".join(str(text) for text in part_texts if text is not None)
+
+
+def _get_attr_or_key(value: Any, name: str) -> Any:
+    """Read ``name`` from a dict by key or from any other object by attribute.
+
+    Returns ``None`` when the key or attribute is missing.
+    """
+    return value.get(name) if isinstance(value, dict) else getattr(value, name, None)
+
+
+def _normalize_optional_list(value: Any) -> list[Any]:
+    """Coerce ``value`` into a list.
+
+    ``None`` and strings give an empty list, a list is returned as-is (not
+    copied), and any other value is wrapped in a one-element list.
+    """
+    if value is None or isinstance(value, str):
+        return []
+    return value if isinstance(value, list) else [value]
+
+
+def _normalize_mapping_block(value: Any) -> dict[str, Any]:
+    """Convert a response block into a plain dict.
+
+    Tried in order: a dict is shallow-copied; an object with ``model_dump()``
+    or ``dict()`` is dumped (when the result is a dict); otherwise the
+    non-None attributes among ``type``, ``text``, ``thinking``, ``content``
+    and ``signature`` are collected. If nothing is found, the result is
+    ``{"content": str(value)}``.
+    """
+    if isinstance(value, dict):
+        return dict(value)
+
+    for dumper_name in ("model_dump", "dict"):
+        if hasattr(value, dumper_name):
+            dumped = getattr(value, dumper_name)()
+            if isinstance(dumped, dict):
+                return dumped
+
+    collected: dict[str, Any] = {}
+    for field in ("type", "text", "thinking", "content", "signature"):
+        field_value = getattr(value, field, None)
+        if field_value is not None:
+            collected[field] = field_value
+    return collected or {"content": str(value)}
+
+
+def _without_none(values: dict[str, Any]) -> dict[str, Any]:
+    """Return a new dict without the entries whose value is ``None``.
+
+    Other falsy values (``0``, ``""``, ``False``, empty containers) are kept,
+    and key order is preserved.
+    """
+    kept: dict[str, Any] = {}
+    for key, value in values.items():
+        if value is not None:
+            kept[key] = value
+    return kept
+
+
+def _is_retryable_error(exc: Exception) -> bool:
+    """Return True if ``exc`` is one of the LiteLLM errors worth retrying.
+
+    That is: rate limit, API connection, timeout, internal server error and
+    service unavailable.
+    """
+    retryable_names = (
+        "RateLimitError",
+        "APIConnectionError",
+        "Timeout",
+        "InternalServerError",
+        "ServiceUnavailableError",
+    )
+    retryable = tuple(getattr(litellm_exceptions, name) for name in retryable_names)
+    return isinstance(exc, retryable)
+
+
+def _is_unsupported_params_error(exc: Exception) -> bool:
+    """Return True if ``exc`` is LiteLLM's ``UnsupportedParamsError``.
+
+    The class is looked up dynamically; when the installed LiteLLM does not
+    export it (or ``exc`` is not an instance), the exception's class name is
+    compared instead.
+    """
+    known_type = getattr(litellm_exceptions, "UnsupportedParamsError", None)
+    if known_type is not None and isinstance(exc, known_type):
+        return True
+    return exc.__class__.__name__ == "UnsupportedParamsError"
+
+
__all__ = [
"LLMProvider",
"OpenAIProvider",
@@ -55,10 +1466,12 @@ def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider:
"MistralProvider",
"OpenRouterProvider",
"OllamaProvider",
+ "OpenAICompatibleProvider",
"openai",
"anthropic",
"gemini",
"mistral",
"openrouter",
"ollama",
+ "openai_compatible",
]
diff --git a/datafast/llm/types.py b/datafast/llm/types.py
new file mode 100644
index 0000000..c8f260b
--- /dev/null
+++ b/datafast/llm/types.py
@@ -0,0 +1,151 @@
+"""Shared types for Datafast LLM provider targets."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Literal
+
+
+Message = dict[str, Any]
+Messages = list[Message]
+
+ContentPartType = Literal["text", "image", "audio", "video", "file", "document"]
+
+
+class EndpointMode(str, Enum):
+ AUTO = "auto"
+ CHAT = "chat"
+ RESPONSES = "responses"
+
+
+class UnsupportedParamsPolicy(str, Enum):
+ FAIL = "fail"
+ WARN = "warn"
+ QUIET = "quiet"
+
+
+class StructuredOutputMode(str, Enum):
+ NONE = "none"
+ PROMPTED_JSON = "prompted_json"
+ JSON_OBJECT = "json_object"
+ JSON_SCHEMA = "json_schema"
+
+
+class BatchMode(str, Enum):
+ NONE = "none"
+ LITELLM_BATCH = "litellm_batch"
+ FALLBACK_CONCURRENCY = "fallback_concurrency"
+
+
+class CacheMode(str, Enum):
+ NONE = "none"
+ PROVIDER_PROMPT = "provider_prompt"
+ ROUTER = "router"
+ LOCAL_KV = "local_kv"
+ CLIENT_RESULT = "client_result"
+
+
+class Modality(str, Enum):
+ TEXT = "text"
+ IMAGE = "image"
+ AUDIO = "audio"
+ VIDEO = "video"
+ FILE = "file"
+ DOCUMENT = "document"
+
+
+@dataclass(frozen=True)
+class RetryPolicy:
+ max_retries: int = 3
+ base_delay: float = 1.0
+ max_delay: float = 30.0
+ jitter: float = 0.25
+
+
+@dataclass(frozen=True)
+class TargetCapabilities:
+ endpoint_modes: frozenset[EndpointMode]
+ default_endpoint_mode: EndpointMode
+ supported_params: frozenset[str] = frozenset()
+ modalities: frozenset[Modality] = frozenset({Modality.TEXT})
+ structured_output: StructuredOutputMode = StructuredOutputMode.PROMPTED_JSON
+ batch_mode: BatchMode = BatchMode.FALLBACK_CONCURRENCY
+ cache_mode: CacheMode = CacheMode.NONE
+ supports_reasoning: bool = False
+ supports_thinking: bool = False
+ no_api_key: bool = False
+ requires_chat_template: bool = False
+ notes: tuple[str, ...] = ()
+
+ def supports_endpoint(self, endpoint_mode: EndpointMode) -> bool:
+ return endpoint_mode in self.endpoint_modes
+
+
+@dataclass(frozen=True)
+class TargetConfig:
+ provider: str
+ model_id: str
+ litellm_provider: str
+ env_key_name: str | None
+ endpoint_mode: EndpointMode = EndpointMode.AUTO
+ temperature: float | None = None
+ max_completion_tokens: int | None = None
+ thinking: bool | None = None
+ reasoning_effort: str | None = None
+ rpm_limit: int | None = None
+ timeout: float | None = None
+ api_key: str | None = None
+ api_base_url: str | None = None
+ retry_policy: RetryPolicy = field(default_factory=RetryPolicy)
+ unsupported_params: UnsupportedParamsPolicy = UnsupportedParamsPolicy.WARN
+ provider_params: dict[str, Any] = field(default_factory=dict)
+ max_concurrent: int = 4
+
+
+@dataclass(frozen=True)
+class NormalizedRequest:
+ messages: Messages
+ metadata: dict[str, Any] | None = None
+ previous_response_id: str | None = None
+
+
+@dataclass(frozen=True)
+class NormalizedResponse:
+ text: str
+ raw: Any
+ reasoning_content: str | None = None
+ thinking_blocks: list[dict[str, Any]] = field(default_factory=list)
+ images: list[dict[str, Any]] = field(default_factory=list)
+ audio: dict[str, Any] | None = None
+ output_items: list[dict[str, Any]] = field(default_factory=list)
+
+
+@dataclass(frozen=True)
+class ContentPart:
+ type: ContentPartType
+ text: str | None = None
+ url: str | None = None
+ data: str | None = None
+ media_type: str | None = None
+ media_id: str | None = None
+ provider_options: dict[str, Any] = field(default_factory=dict)
+
+
+__all__ = [
+ "BatchMode",
+ "CacheMode",
+ "ContentPart",
+ "ContentPartType",
+ "EndpointMode",
+ "Message",
+ "Messages",
+ "Modality",
+ "NormalizedRequest",
+ "NormalizedResponse",
+ "RetryPolicy",
+ "StructuredOutputMode",
+ "TargetCapabilities",
+ "TargetConfig",
+ "UnsupportedParamsPolicy",
+]
diff --git a/datafast/llm_utils.py b/datafast/llm_utils.py
index 18890cd..9aa2fb3 100644
--- a/datafast/llm_utils.py
+++ b/datafast/llm_utils.py
@@ -1,3 +1,8 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+
def get_messages(prompt: str, system_message: str = "You are a helpful assistant.") -> list[dict[str, str]]:
"""Convert a single prompt into a message list format expected by LLM APIs.
@@ -12,3 +17,48 @@ def get_messages(prompt: str, system_message: str = "You are a helpful assistant
{"role": "system", "content": system_message},
{"role": "user", "content": prompt},
]
+
+
+def format_generated_responses(
+ prompts: str | Sequence[str],
+ responses: str | Sequence[str],
+) -> str:
+ """Return a readable string for one or many prompt/response pairs."""
+ prompt_items = [prompts] if isinstance(prompts, str) else list(prompts)
+ response_items = [responses] if isinstance(responses, str) else list(responses)
+
+ if len(prompt_items) != len(response_items):
+ raise ValueError("prompts and responses must have the same length")
+
+ sections = [
+ _format_response_section(prompt, response, index, total=len(prompt_items))
+ for index, (prompt, response) in enumerate(
+ zip(prompt_items, response_items, strict=True),
+ start=1,
+ )
+ ]
+ return "\n\n".join(sections)
+
+
+def _format_response_section(
+ prompt: str,
+ response: str,
+ index: int,
+ *,
+ total: int,
+) -> str:
+ lines = []
+ if total > 1:
+ lines.append(f"Example {index}")
+ lines.extend(
+ [
+ "Prompt",
+ "------",
+ prompt,
+ "",
+ "Response",
+ "--------",
+ response,
+ ]
+ )
+ return "\n".join(lines)
diff --git a/datafast/llms.py b/datafast/llms.py
index 092346a..3478e30 100644
--- a/datafast/llms.py
+++ b/datafast/llms.py
@@ -1,892 +1,28 @@
-"""LLM providers for datafast using LiteLLM.
+"""Compatibility exports for Datafast LLM providers.
-This module provides classes for different LLM providers (OpenAI, Anthropic, Gemini, Mistral)
-with a unified interface using LiteLLM under the hood.
+The implementation lives in :mod:`datafast.llm.provider`.
"""
-from typing import Any, Type, TypeVar
-from abc import ABC, abstractmethod
-import os
-import time
-import traceback
-import warnings
-from loguru import logger
-
-# Pydantic
-from pydantic import BaseModel
-
-# LiteLLM
-import litellm
-from litellm.exceptions import RateLimitError
-from litellm.utils import ModelResponse
-
-# Internal imports
-from .llm_utils import get_messages
-from .tracing import (
- build_trace_metadata,
- load_env_once,
- maybe_configure_langfuse_tracing,
+from datafast.llm.provider import (
+ LLMProvider,
+ AnthropicProvider,
+ GeminiProvider,
+ MistralProvider,
+ OllamaProvider,
+ OpenAICompatibleProvider,
+ OpenAIProvider,
+ OpenRouterProvider,
+ anthropic,
+ gemini,
+ mistral,
+ ollama,
+ openai,
+ openai_compatible,
+ openrouter,
)
+from datafast.tracing import load_env_once, maybe_configure_langfuse_tracing
-# Type aliases for Python 3.10+
-Message = dict[str, str]
-Messages = list[Message]
-T = TypeVar('T', bound=BaseModel)
-
-
-class LLMProvider(ABC):
- """Abstract base class for LLM providers."""
-
- def __init__(
- self,
- model_id: str,
- api_key: str | None = None,
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- rpm_limit: int | None = None,
- timeout: int | None = None,
- ):
- """Initialize the LLM provider with common parameters.
-
- Args:
- model_id: The model identifier
- api_key: API key (if None, will get from environment)
- temperature: The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic
- max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
- """
- self.model_id = model_id
- load_env_once()
- maybe_configure_langfuse_tracing(load_env=False)
- self.api_key = api_key or self._get_api_key()
-
- # Set generation parameters
- self.temperature = temperature
- self.max_completion_tokens = max_completion_tokens
- self.top_p = top_p
- self.frequency_penalty = frequency_penalty
-
- # Rate limiting
- self.rpm_limit = rpm_limit
- self._request_timestamps: list[float] = []
-
- # timeout
- self.timeout = timeout
-
- # Configure environment with API key if needed
- self._configure_env()
- # Log successful initialization
- logger.info(f"Initialized {self.provider_name} | Model: {self.model_id}")
-
- def _build_request_metadata(
- self,
- metadata: dict[str, Any] | None = None,
- ) -> dict[str, Any]:
- """Build default tracing metadata for provider-level calls."""
- return build_trace_metadata(
- model=self,
- component="provider.generate",
- trace_name=f"datafast.{self.provider_name}",
- metadata=metadata,
- )
-
- @property
- @abstractmethod
- def provider_name(self) -> str:
- """Return the provider name used by LiteLLM."""
- pass
-
- @property
- @abstractmethod
- def env_key_name(self) -> str:
- """Return the environment variable name for API key."""
- pass
-
- def _get_api_key(self) -> str:
- """Get API key from environment variables."""
- api_key = os.getenv(self.env_key_name)
- if not api_key:
- logger.error(
- f"Missing API key | Set {self.env_key_name} environment variable"
- )
- raise ValueError(
- f"{self.env_key_name} environment variable not set. "
- f"Please set it or provide an API key when initializing the provider."
- )
- return api_key
-
- def _configure_env(self) -> None:
- """Configure environment variables for API key."""
- if self.api_key:
- os.environ[self.env_key_name] = self.api_key
-
- def _get_model_string(self) -> str:
- """Get the full model string for LiteLLM."""
- return f"{self.provider_name}/{self.model_id}"
-
- def _respect_rate_limit(self) -> None:
- """Block execution to ensure we do not exceed the rpm_limit."""
- if self.rpm_limit is None:
- return
- current = time.monotonic()
- # Keep only timestamps within the last minute
- self._request_timestamps = [
- ts for ts in self._request_timestamps if current - ts < 60]
-
- # Be more conservative - wait if we're at 90% of the limit
- conservative_limit = max(1, int(self.rpm_limit * 0.9))
-
- if len(self._request_timestamps) < conservative_limit:
- return
-
- # Need to wait until the earliest request is outside the 60-second window
- earliest = self._request_timestamps[0]
- # Add a 2s margin to avoid accidental rate limit exceedance
- sleep_time = 62 - (current - earliest)
- if sleep_time > 0:
- logger.warning(
- f"Rate limit approaching | Requests: {len(self._request_timestamps)}/{self.rpm_limit} | "
- f"Waiting {sleep_time:.1f}s"
- )
- time.sleep(sleep_time)
- # Clean up old timestamps after waiting
- current = time.monotonic()
- self._request_timestamps = [
- ts for ts in self._request_timestamps if current - ts < 60]
-
- @staticmethod
- def _strip_code_fences(content: str) -> str:
- """Strip markdown code fences from content if present.
-
- Args:
- content: The content string that may contain code fences
-
- Returns:
- Content with code fences removed
- """
- if not content:
- return content
-
- content = content.strip()
-
- # Check for code fences with optional language identifier
- if content.startswith('```'):
- # Find the end of the first line (language identifier)
- first_newline = content.find('\n')
- if first_newline != -1:
- content = content[first_newline + 1:]
- else:
- # No newline after opening fence, remove just the fence
- content = content[3:]
-
- # Remove closing fence
- if content.endswith('```'):
- content = content[:-3]
-
- return content.strip()
-
- def generate(
- self,
- prompt: str | list[str] | None = None,
- messages: list[Messages] | Messages | None = None,
- response_format: Type[T] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> str | list[str] | T | list[T]:
- """
- Generate responses from the LLM using single or batch inference.
-
- Args:
- prompt: Single text prompt (str) or list of text prompts for batch processing
- messages: Single message list or list of message lists for batch processing
- response_format: Optional Pydantic model class for structured output
- metadata: Optional LiteLLM metadata for tracing / observability
-
- Returns:
- Single string/model or list of strings/models depending on input type.
-
- Raises:
- ValueError: If neither prompt nor messages is provided, or if both are provided.
- RuntimeError: If there's an error during generation.
- """
- # Validate inputs
- if prompt is None and messages is None:
- raise ValueError("Either prompts or messages must be provided")
- if prompt is not None and messages is not None:
- raise ValueError("Provide either prompts or messages, not both")
-
- # Determine if this is a single input or batch input
- single_input = False
- batch_prompts = None
- batch_messages = None
-
- if prompt is not None:
- if isinstance(prompt, str):
- # Single prompt - convert to batch
- batch_prompts = [prompt]
- single_input = True
- elif isinstance(prompt, list):
- # Already a list of prompts
- batch_prompts = prompt
- single_input = False
- else:
- raise ValueError("prompt must be a string or list of strings")
-
- if messages is not None:
- if isinstance(messages, list) and len(messages) > 0:
- # Check if it's a single message list or batch
- if isinstance(messages[0], dict):
- # Single message list - convert to batch
- batch_messages = [messages]
- single_input = True
- elif isinstance(messages[0], list):
- # Already a batch of message lists
- batch_messages = messages
- single_input = False
- else:
- raise ValueError("Invalid messages format")
- else:
- raise ValueError("messages cannot be empty")
-
- try:
- # Append JSON formatting instructions if response_format is provided
- json_instructions = (
- "\nReturn only valid JSON. To do so, don't include ```json ``` markdown "
- "or code fences around the JSON. Use double quotes for all keys and values. "
- "Escape internal quotes and newlines (use \\n). Do not include trailing commas."
- )
-
- # Convert batch prompts to messages if needed
- batch_to_send = []
- if batch_prompts is not None:
- for one_prompt in batch_prompts:
- # Append JSON instructions to prompt if response_format is provided
- modified_prompt = one_prompt + json_instructions if response_format is not None else one_prompt
- batch_to_send.append(get_messages(modified_prompt))
- else:
- batch_to_send = batch_messages
- # Append JSON instructions to the last user message if response_format is provided
- if response_format is not None:
- for message_list in batch_to_send:
- for msg in reversed(message_list):
- if msg.get("role") == "user":
- msg["content"] += json_instructions
- break
-
- # Enforce rate limit per batch
- self._respect_rate_limit()
-
- # Prepare completion parameters for batch
- completion_params = {
- "model": self._get_model_string(),
- "messages": batch_to_send,
- "temperature": self.temperature,
- "max_tokens": self.max_completion_tokens,
- "top_p": self.top_p,
- "frequency_penalty": self.frequency_penalty,
- "timeout": self.timeout,
- "metadata": self._build_request_metadata(metadata),
- }
- if response_format is not None:
- completion_params["response_format"] = response_format
-
- # Call LiteLLM completion with batch messages - retry on rate limit
- max_retries = 3
- retry_delay = 5 # Start with 5 seconds
- response = None
-
- for attempt in range(max_retries):
- try:
- response: list[ModelResponse] = litellm.batch_completion(
- **completion_params)
- break # Success, exit retry loop
- except RateLimitError as e:
- if attempt < max_retries - 1:
- wait_time = retry_delay * (2 ** attempt) # Exponential backoff
- logger.warning(
- f"Rate limit hit | Provider: {self.provider_name} | Model: {self.model_id} | "
- f"Attempt {attempt + 1}/{max_retries} | Waiting {wait_time}s before retry"
- )
- time.sleep(wait_time)
- else:
- logger.error(
- f"Rate limit exceeded after {max_retries} attempts | "
- f"Provider: {self.provider_name} | Model: {self.model_id}"
- )
- raise
-
- if response is None:
- raise RuntimeError("Failed to get response after retries")
-
- # Record timestamp for rate limiting (one timestamp per batch item)
- if self.rpm_limit is not None:
- current_time = time.monotonic()
- for _ in range(len(batch_to_send)):
- self._request_timestamps.append(current_time)
-
- # Extract content from each response
- results = []
- for idx, one_response in enumerate(response):
- if isinstance(one_response, Exception):
- if isinstance(one_response, RateLimitError):
- logger.warning(
- "Rate limit error in batch item | Provider: %s | Model: %s | Item: %d",
- self.provider_name,
- self.model_id,
- idx,
- )
- raise RuntimeError(
- f"Batch item {idx} failed during generation: {one_response}"
- ) from one_response
-
- if not getattr(one_response, "choices", None):
- raise RuntimeError(
- f"Unexpected response type from LiteLLM batch completion at item {idx}: {type(one_response).__name__}"
- )
-
- content = one_response.choices[0].message.content
-
- if response_format is not None:
- # Strip code fences before validation
- content = self._strip_code_fences(content)
- try:
- results.append(
- response_format.model_validate_json(content))
- except Exception as validation_error:
- # Show the content that failed to parse for debugging
- content_preview = content[:200] + "..." if len(content) > 200 else content
- logger.warning(
- f"JSON parsing failed, skipping response | "
- f"Model: {self.model_id} | "
- f"Format: {response_format.__name__} | "
- f"Content preview: {content_preview}"
- )
- raise ValueError(
- f"Failed to parse JSON response into {response_format.__name__}.\n"
- f"Validation error: {validation_error}\n"
- f"Content received (first 200 chars):\n{content_preview}"
- ) from validation_error
- else:
- # Strip leading/trailing whitespace for text responses
- results.append(content.strip() if content else content)
-
- # Return single result for backward compatibility
- if single_input and len(results) == 1:
- return results[0]
- return results
-
- except Exception as e:
- error_trace = traceback.format_exc()
- logger.error(
- f"Generation failed | Provider: {self.provider_name} | "
- f"Model: {self.model_id} | Error: {str(e)}"
- )
- raise RuntimeError(
- f"Error generating batch response with {self.provider_name}:\n{error_trace}"
- )
-
-
-class OpenAIProvider(LLMProvider):
- """OpenAI provider using litellm.responses endpoint.
-
- Note: This provider uses the new responses endpoint which has different
- parameter support compared to the standard completion endpoint:
- - temperature, top_p, and frequency_penalty are not supported
- - Uses text_format instead of response_format
- - Supports reasoning parameter for controlling reasoning effort
- - Does not support batch operations (will process sequentially with warning)
- """
-
- @property
- def provider_name(self) -> str:
- return "openai"
-
- @property
- def env_key_name(self) -> str:
- return "OPENAI_API_KEY"
-
- def __init__(
- self,
- model_id: str = "gpt-5-mini-2025-08-07",
- api_key: str | None = None,
- max_completion_tokens: int | None = None,
- reasoning_effort: str = "low",
- temperature: float | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- timeout: int | None = None,
- ):
- """Initialize the OpenAI provider.
-
- Args:
- model_id: The model ID (defaults to gpt-5-mini)
- api_key: API key (if None, will get from environment)
- max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
- reasoning_effort: Reasoning effort level - "low", "medium", or "high" (defaults to "low")
- temperature: DEPRECATED - Not supported by responses endpoint
- top_p: DEPRECATED - Not supported by responses endpoint
- frequency_penalty: DEPRECATED - Not supported by responses endpoint
- timeout: Request timeout in seconds
- """
- # Warn about deprecated parameters
- if temperature is not None:
- warnings.warn(
- "temperature parameter is not supported by OpenAI responses endpoint and will be ignored",
- UserWarning,
- stacklevel=2
- )
- if top_p is not None:
- warnings.warn(
- "top_p parameter is not supported by OpenAI responses endpoint and will be ignored",
- UserWarning,
- stacklevel=2
- )
- if frequency_penalty is not None:
- warnings.warn(
- "frequency_penalty parameter is not supported by OpenAI responses endpoint and will be ignored",
- UserWarning,
- stacklevel=2
- )
-
- # Store reasoning effort
- self.reasoning_effort = reasoning_effort
-
- # Call parent init with None for unsupported params
- super().__init__(
- model_id=model_id,
- api_key=api_key,
- temperature=None,
- max_completion_tokens=max_completion_tokens,
- top_p=None,
- frequency_penalty=None,
- timeout=timeout,
- )
-
- def generate(
- self,
- prompt: str | list[str] | None = None,
- messages: list[Messages] | Messages | None = None,
- response_format: Type[T] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> str | list[str] | T | list[T]:
- """
- Generate responses from the LLM using the responses endpoint.
-
- Note: Batch operations are processed sequentially as the responses endpoint
- does not support native batching.
-
- Args:
- prompt: Single text prompt (str) or list of text prompts for batch processing
- messages: Single message list or list of message lists for batch processing
- response_format: Optional Pydantic model class for structured output
- metadata: Optional LiteLLM metadata for tracing / observability
-
- Returns:
- Single string/model or list of strings/models depending on input type.
-
- Raises:
- ValueError: If neither prompt nor messages is provided, or if both are provided.
- RuntimeError: If there's an error during generation.
- """
- # Validate inputs
- if prompt is None and messages is None:
- raise ValueError("Either prompts or messages must be provided")
- if prompt is not None and messages is not None:
- raise ValueError("Provide either prompts or messages, not both")
-
- # Determine if this is a single input or batch input
- single_input = False
- batch_prompts = None
- batch_messages = None
-
- if prompt is not None:
- if isinstance(prompt, str):
- # Single prompt - convert to batch
- batch_prompts = [prompt]
- single_input = True
- elif isinstance(prompt, list):
- # Already a list of prompts
- batch_prompts = prompt
- single_input = False
- else:
- raise ValueError("prompt must be a string or list of strings")
-
- if messages is not None:
- if isinstance(messages, list) and len(messages) > 0:
- # Check if it's a single message list or batch
- if isinstance(messages[0], dict):
- # Single message list - convert to batch
- batch_messages = [messages]
- single_input = True
- elif isinstance(messages[0], list):
- # Already a batch of message lists
- batch_messages = messages
- single_input = False
- else:
- raise ValueError("Invalid messages format")
- else:
- raise ValueError("messages cannot be empty")
-
- try:
- # Convert batch prompts to messages if needed
- batch_to_send = []
- if batch_prompts is not None:
- for one_prompt in batch_prompts:
- batch_to_send.append([{"role": "user", "content": one_prompt}])
- else:
- batch_to_send = batch_messages
-
- # Warn if batch processing is being used
- if len(batch_to_send) > 1:
- warnings.warn(
- f"OpenAI responses endpoint does not support batch operations. "
- f"Processing {len(batch_to_send)} requests sequentially.",
- UserWarning,
- stacklevel=2
- )
-
- # Process each request sequentially
- results = []
- for message_list in batch_to_send:
- # Enforce rate limit per request
- self._respect_rate_limit()
-
- # Prepare completion parameters
- completion_params = {
- "model": self._get_model_string(),
- "input": message_list,
- "reasoning": {"effort": self.reasoning_effort},
- "metadata": self._build_request_metadata(metadata),
- }
-
- # Add max_output_tokens if specified
- if self.max_completion_tokens is not None:
- completion_params["max_output_tokens"] = self.max_completion_tokens
-
- # Add text_format if response_format is provided
- if response_format is not None:
- completion_params["text_format"] = response_format
-
- # Call LiteLLM responses endpoint
- response = litellm.responses(**completion_params)
-
- # Record timestamp for rate limiting
- if self.rpm_limit is not None:
- self._request_timestamps.append(time.monotonic())
-
- # Extract content from response
- # Response structure: response.output[1].content[0].text
- content = response.output[1].content[0].text
-
- if response_format is not None:
- # Strip code fences before validation
- content = self._strip_code_fences(content)
- try:
- results.append(response_format.model_validate_json(content))
- except Exception as validation_error:
- # Show the content that failed to parse for debugging
- content_preview = content[:200] + "..." if len(content) > 200 else content
- logger.warning(
- f"JSON parsing failed, skipping response | "
- f"Model: {self.model_id} | "
- f"Format: {response_format.__name__} | "
- f"Content preview: {content_preview}"
- )
- raise ValueError(
- f"Failed to parse JSON response into {response_format.__name__}.\n"
- f"Validation error: {validation_error}\n"
- f"Content received (first 200 chars):\n{content_preview}"
- ) from validation_error
- else:
- # Strip leading/trailing whitespace for text responses
- results.append(content.strip() if content else content)
-
- # Return single result for backward compatibility
- if single_input and len(results) == 1:
- return results[0]
- return results
-
- except Exception as e:
- error_trace = traceback.format_exc()
- logger.error(
- f"Generation failed | Provider: {self.provider_name} | "
- f"Model: {self.model_id} | Error: {str(e)}"
- )
- raise RuntimeError(
- f"Error generating response with {self.provider_name}:\n{error_trace}"
- )
-
-
-class AnthropicProvider(LLMProvider):
- """Anthropic provider using litellm."""
-
- @property
- def provider_name(self) -> str:
- return "anthropic"
-
- @property
- def env_key_name(self) -> str:
- return "ANTHROPIC_API_KEY"
-
- def __init__(
- self,
- model_id: str = "claude-haiku-4-5-20251001",
- api_key: str | None = None,
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- timeout: int | None = None,
- # top_p: float | None = None, # Not properly supported by anthropic models 4.5
- # frequency_penalty: float | None = None, # Not supported by anthropic models 4.5
- ):
- """Initialize the Anthropic provider.
-
- Args:
- model_id: The model ID (defaults to claude-haiku-4-5-20251001)
- api_key: API key (if None, will get from environment)
- temperature: Temperature for generation (0.0 to 1.0)
- max_completion_tokens: Maximum tokens to generate
- timeout: Request timeout in seconds
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- """
- super().__init__(
- model_id=model_id,
- api_key=api_key,
- temperature=temperature,
- max_completion_tokens=max_completion_tokens,
- timeout=timeout,
- )
-
-
-class GeminiProvider(LLMProvider):
- """Google Gemini provider using litellm."""
-
- @property
- def provider_name(self) -> str:
- return "gemini"
-
- @property
- def env_key_name(self) -> str:
- return "GEMINI_API_KEY"
-
- def __init__(
- self,
- model_id: str = "gemini-2.0-flash",
- api_key: str | None = None,
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- rpm_limit: int | None = None,
- timeout: int | None = None,
- ):
- """Initialize the Gemini provider.
-
- Args:
- model_id: The model ID (defaults to gemini-2.0-flash)
- api_key: API key (if None, will get from environment)
- temperature: Temperature for generation (0.0 to 1.0)
- max_completion_tokens: Maximum tokens to generate
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
- timeout: Request timeout in seconds
- """
- super().__init__(
- model_id=model_id,
- api_key=api_key,
- temperature=temperature,
- max_completion_tokens=max_completion_tokens,
- top_p=top_p,
- frequency_penalty=frequency_penalty,
- rpm_limit=rpm_limit,
- timeout=timeout,
- )
-
-
-class OllamaProvider(LLMProvider):
- """Ollama provider using litellm.
-
- Note: Ollama typically doesn't require an API key as it's usually run locally.
- """
-
- @property
- def provider_name(self) -> str:
- return "ollama_chat"
-
- @property
- def env_key_name(self) -> str:
- return "OLLAMA_API_BASE"
-
- def _get_api_key(self) -> str:
- """Override to handle Ollama not requiring an API key.
-
- Returns an empty string since Ollama typically doesn't need an API key.
- OLLAMA_API_BASE can be used to set a custom base URL.
- """
- return ""
-
- def __init__(
- self,
- model_id: str = "gemma3:4b",
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- api_base: str | None = None,
- rpm_limit: int | None = None,
- timeout: int | None = None,
- ):
- """Initialize the Ollama provider.
-
- Args:
- model_id: The model ID (defaults to llama3)
- temperature: Temperature for generation (0.0 to 1.0)
- max_completion_tokens: Maximum tokens to generate
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
- api_base: Base URL for Ollama API (e.g., "http://localhost:11434")
- timeout: Request timeout in seconds
- """
- # Set API base URL if provided
- if api_base:
- os.environ["OLLAMA_API_BASE"] = api_base
-
- super().__init__(
- model_id=model_id,
- api_key="", # Pass empty string since parent class requires this parameter
- temperature=temperature,
- max_completion_tokens=max_completion_tokens,
- top_p=top_p,
- frequency_penalty=frequency_penalty,
- rpm_limit=rpm_limit,
- timeout=timeout,
- )
-
-
-class OpenRouterProvider(LLMProvider):
- """OpenRouter provider using litellm"""
-
- @property
- def provider_name(self) -> str:
- return "openrouter"
-
- @property
- def env_key_name(self) -> str:
- return "OPENROUTER_API_KEY"
-
- def __init__(
- self,
- model_id: str = "openai/gpt-5-mini", # for default model
- api_key: str | None = None,
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- timeout: int | None = None,
- ):
- """Initialize the OpenRouter provider.
-
- Args:
- model_id: The model ID (defaults to openai/gpt-5-mini)
- api_key: API key (if None, will get from environment)
- temperature: Temperature for generation (0.0 to 1.0)
- max_completion_tokens: Maximum tokens to generate
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
- timeout: Request timeout in seconds
- """
- super().__init__(
- model_id = model_id,
- api_key = api_key,
- temperature = temperature,
- max_completion_tokens = max_completion_tokens,
- top_p = top_p,
- frequency_penalty = frequency_penalty,
- timeout = timeout,
- )
-
-
-class MistralProvider(LLMProvider):
- """Mistral AI provider using litellm."""
-
- @property
- def provider_name(self) -> str:
- return "mistral"
-
- @property
- def env_key_name(self) -> str:
- return "MISTRAL_API_KEY"
-
- def __init__(
- self,
- model_id: str = "mistral-small-latest",
- api_key: str | None = None,
- temperature: float | None = None,
- max_completion_tokens: int | None = None,
- top_p: float | None = None,
- frequency_penalty: float | None = None,
- rpm_limit: int | None = None,
- timeout: int | None = None,
- ):
- """Initialize the Mistral provider.
-
- Args:
- model_id: The model ID (defaults to mistral-small-latest)
- api_key: API key (if None, will get from MISTRAL_API_KEY env var)
- temperature: Temperature for generation (0.0 to 1.0)
- max_completion_tokens: Maximum tokens to generate
- top_p: Nucleus sampling parameter (0.0 to 1.0)
- frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
- rpm_limit: Requests per minute limit for rate limiting
- timeout: Request timeout in seconds
- """
- super().__init__(
- model_id=model_id,
- api_key=api_key,
- temperature=temperature,
- max_completion_tokens=max_completion_tokens,
- top_p=top_p,
- frequency_penalty=frequency_penalty,
- rpm_limit=rpm_limit,
- timeout=timeout,
- )
-
-
-def openai(model_id: str = "gpt-5-mini-2025-08-07", **kwargs) -> OpenAIProvider:
- """Create an OpenAI provider instance."""
- return OpenAIProvider(model_id=model_id, **kwargs)
-
-
-def anthropic(
- model_id: str = "claude-haiku-4-5-20251001",
- **kwargs,
-) -> AnthropicProvider:
- """Create an Anthropic provider instance."""
- return AnthropicProvider(model_id=model_id, **kwargs)
-
-
-def gemini(model_id: str = "gemini-2.0-flash", **kwargs) -> GeminiProvider:
- """Create a Gemini provider instance."""
- return GeminiProvider(model_id=model_id, **kwargs)
-
-
-def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider:
- """Create an Ollama provider instance."""
- return OllamaProvider(model_id=model_id, **kwargs)
-
-
-def openrouter(
- model_id: str = "openai/gpt-5-mini",
- **kwargs,
-) -> OpenRouterProvider:
- """Create an OpenRouter provider instance."""
- return OpenRouterProvider(model_id=model_id, **kwargs)
-
-
-def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider:
- """Create a Mistral provider instance."""
- return MistralProvider(model_id=model_id, **kwargs)
+import litellm
__all__ = [
@@ -894,13 +30,18 @@ def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider
"OpenAIProvider",
"AnthropicProvider",
"GeminiProvider",
- "OllamaProvider",
- "OpenRouterProvider",
"MistralProvider",
+ "OpenRouterProvider",
+ "OllamaProvider",
+ "OpenAICompatibleProvider",
"openai",
"anthropic",
"gemini",
- "ollama",
- "openrouter",
"mistral",
+ "openrouter",
+ "ollama",
+ "openai_compatible",
+ "litellm",
+ "load_env_once",
+ "maybe_configure_langfuse_tracing",
]
diff --git a/datafast/transforms/__init__.py b/datafast/transforms/__init__.py
index 025ea3f..f7a88d2 100644
--- a/datafast/transforms/__init__.py
+++ b/datafast/transforms/__init__.py
@@ -1,7 +1,7 @@
"""Transform steps for datafast v2."""
from datafast.transforms.sample import Sample
-from datafast.transforms.data_ops import Map, FlatMap, Filter, Group, Pair, Concat, Join
+from datafast.transforms.data_ops import AddUUID, Map, FlatMap, Filter, Group, Pair, Concat, Join
from datafast.transforms.llm_step import LLMStep
from datafast.transforms.llm_eval import Classify, Score, Compare
from datafast.transforms.llm_transform import Rewrite
@@ -9,7 +9,7 @@
from datafast.transforms.branch import Branch, JoinBranches
__all__ = [
- "Sample", "Map", "FlatMap", "Filter", "Group", "Pair", "Concat", "Join",
+ "Sample", "AddUUID", "Map", "FlatMap", "Filter", "Group", "Pair", "Concat", "Join",
"LLMStep", "Classify", "Score", "Compare", "Rewrite", "Extract",
"Branch", "JoinBranches",
]
diff --git a/datafast/transforms/data_ops.py b/datafast/transforms/data_ops.py
index 3887460..fafb5cf 100644
--- a/datafast/transforms/data_ops.py
+++ b/datafast/transforms/data_ops.py
@@ -3,6 +3,7 @@
import itertools
import random
import re
+import uuid
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any
@@ -62,6 +63,34 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]:
yield from self._fn(record)
+class AddUUID(Step):
+ """Add a UUID field to each record."""
+
+ def __init__(self, column: str = "id", overwrite: bool = False) -> None:
+ """
+ Initialize an AddUUID step.
+
+ Args:
+ column: Field name to write the UUID into.
+ overwrite: If True, replace existing values in the target column.
+
+ Examples:
+ >>> AddUUID()
+ >>> AddUUID(column="example_id", overwrite=True)
+ """
+ super().__init__()
+ self._column = column
+ self._overwrite = overwrite
+
+ def process(self, records: Iterable[Record]) -> Iterable[Record]:
+ """Add UUIDs while preserving all other fields."""
+ for record in records:
+ if self._column in record and not self._overwrite:
+ yield record
+ else:
+ yield {**record, self._column: str(uuid.uuid4())}
+
+
class Filter(Step):
"""Keep or drop records based on conditions."""
diff --git a/datafast/transforms/llm_eval.py b/datafast/transforms/llm_eval.py
index b0ea320..6fb3e78 100644
--- a/datafast/transforms/llm_eval.py
+++ b/datafast/transforms/llm_eval.py
@@ -366,7 +366,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]:
try:
messages = self._build_messages(record)
raw = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
@@ -657,7 +657,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]:
try:
messages = self._build_messages(record)
raw = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
@@ -1011,7 +1011,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]:
try:
messages = self._build_messages(record)
raw = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
diff --git a/datafast/transforms/llm_extract.py b/datafast/transforms/llm_extract.py
index aa8161d..9d3e095 100644
--- a/datafast/transforms/llm_extract.py
+++ b/datafast/transforms/llm_extract.py
@@ -418,7 +418,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]:
try:
messages = self._build_messages(record)
raw = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
diff --git a/datafast/transforms/llm_step.py b/datafast/transforms/llm_step.py
index d2aae42..ad1a8fb 100644
--- a/datafast/transforms/llm_step.py
+++ b/datafast/transforms/llm_step.py
@@ -384,7 +384,7 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]:
messages = self._build_messages(prompt_template, context)
raw_output = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
diff --git a/datafast/transforms/llm_transform.py b/datafast/transforms/llm_transform.py
index 8901a03..105ce65 100644
--- a/datafast/transforms/llm_transform.py
+++ b/datafast/transforms/llm_transform.py
@@ -298,7 +298,7 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]:
try:
messages = self._build_messages(record)
raw = model.generate(
- messages,
+ messages=messages,
metadata=build_trace_metadata(
model=model,
component="step.process",
diff --git a/docs/api.md b/docs/api.md
index edef161..45857e2 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -36,6 +36,7 @@ from datafast import Source, LLMStep, Sink, openrouter
## Data Operations
- `Sample`
+- `AddUUID`
- `Map`
- `FlatMap`
- `Filter`
diff --git a/docs/cookbook/assets/index.md b/docs/cookbook/assets/index.md
new file mode 100644
index 0000000..65896be
--- /dev/null
+++ b/docs/cookbook/assets/index.md
@@ -0,0 +1,80 @@
+# Cookbook Assets
+
+Prompt files and dataset details used by cookbook examples.
+
+## Text Classification
+
+### Dataset
+
+- **Source:** seed dimensions created with `Seed.product`
+- **Dimensions:** label, trail type, style, language, and model
+- **Local output:** `examples/outputs/45_text_classification_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/45_text_classification_cookbook`
+- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1`
+
+This cookbook models variation directly as seed dimensions so the label, trail
+type, style, language, and model are all explicit in the
+pipeline.
+
+### Prompt
+
+| File | Style |
+| --- | --- |
+| [text_classification_generation.txt](text_classification_generation.txt) | One short trail comment per call, with label, trail type, style, and language injected |
+
+## Persona Generation
+
+### Dataset
+
+- **Source:** `xsum` (Hugging Face), `validation` split
+- **Fields used:** `id`, `document`, `summary`
+- **Filter:** 300–500 words, first 100 matches
+- **Local output:** `examples/outputs/43_persona_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/43_persona_cookbook`
+- **Hub output:** set `HF_REPO_ID` and the `repo_id` in `push_records_to_hub()` to repos under your own Hugging Face username or organization
+
+The example keeps first-match sampling for reproducibility. For local JSONL corpora with metadata such as `document_filename`, stratified sampling is usually a better fit.
+
+### Prompt Variants
+
+Each LLM step picks one prompt at random per record. The script also assigns random `life_stage` and `related_life_stage` values before the corresponding LLM steps. Multiple variants add diversity.
+
+#### Text-to-Persona
+
+| File | Style |
+| --- | --- |
+| [text_to_persona_v1.txt](text_to_persona_v1.txt) | Direct inference of a reader persona |
+| [text_to_persona_v2.txt](text_to_persona_v2.txt) | XML-tagged source text, writer/reader framing |
+| [text_to_persona_v3.txt](text_to_persona_v3.txt) | System-role preamble, search-interest angle |
+
+#### Persona-to-Persona
+
+| File | Style |
+| --- | --- |
+| [persona_to_persona_v1.txt](persona_to_persona_v1.txt) | Close relationship, standalone description |
+| [persona_to_persona_v2.txt](persona_to_persona_v2.txt) | Rule-list format, explicit separation of description and relationship |
+| [persona_to_persona_v3.txt](persona_to_persona_v3.txt) | XML-tagged input, concise vivid output |
+
+### Provenance
+
+- Text-to-Persona and Persona-to-Persona prompts are paper-aligned adaptations. The Persona Hub paper states its published prompts are simplified, not exact.
+- No Persona Hub code is reused. The workflow is built with datafast primitives.
+
+## Space Engineering Text Generation
+
+### Dataset
+
+- **Source:** seed dimensions created with `Seed.product`
+- **Dimensions:** document type, topic, expertise level, and language
+- **Local output:** `examples/outputs/44_space_text_generation_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/44_space_text_generation_cookbook`
+- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1`
+
+### Prompt
+
+The text-generation cookbook uses one compact prompt and relies on seed
+dimensions for variation.
+
+| File | Style |
+| --- | --- |
+| [space_text_generation.txt](space_text_generation.txt) | Minimal variable-driven request |
diff --git a/docs/cookbook/assets/persona_to_persona_v1.txt b/docs/cookbook/assets/persona_to_persona_v1.txt
new file mode 100644
index 0000000..eabb6d6
--- /dev/null
+++ b/docs/cookbook/assets/persona_to_persona_v1.txt
@@ -0,0 +1,11 @@
+Given the following persona, infer one other specific persona who is in a close relationship with them.
+
+Persona:
+{persona_description}
+
+Requirements:
+1. Use one clear relationship such as family member, colleague, friend, neighbor, coach, teacher, or married partner.
+2. Choose a related persona that adds a meaningfully different life perspective but is still likely to be in close contact with the original persona.
+3. Keep the related persona realistic and specific.
+4. Don't talk about the original person in the description of the related persona, as it should be a self-contained description.
+5. The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally.
diff --git a/docs/cookbook/assets/persona_to_persona_v2.txt b/docs/cookbook/assets/persona_to_persona_v2.txt
new file mode 100644
index 0000000..b4e4adf
--- /dev/null
+++ b/docs/cookbook/assets/persona_to_persona_v2.txt
@@ -0,0 +1,14 @@
+Think of a person who regularly interacts with the following persona in a meaningful way.
+
+Rules:
+- Do not mention the original persona in the description of the related persona.
+- Do not mention the relationship between the two personas in the description, only in the relationship_type
+- Pick a single, concrete relationship type such as mentor-mentee, colleague, neighbor, supervisor-report, or service provider-client
+- The related person should bring a distinctly different viewpoint or expertise, and some uniqueness.
+- Keep the description realistic and standalone without mentioning the original persona.
+- The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally.
+
+Original Persona:
+{persona_description}
+
+Now generate a related persona.
\ No newline at end of file
diff --git a/docs/cookbook/assets/persona_to_persona_v3.txt b/docs/cookbook/assets/persona_to_persona_v3.txt
new file mode 100644
index 0000000..9652161
--- /dev/null
+++ b/docs/cookbook/assets/persona_to_persona_v3.txt
@@ -0,0 +1,16 @@
+Here is the description of someone:
+
+{persona_description}
+
+
+Come up with one other description of an individual who could be part of this persona's life.
+We want the description to be detailed but super concise (max 2 sentences) and vivid.
+But we want to have a standalone description of that new persona without mentioning the original persona or a reason in the description.
+
+Requirements:
+1. Define a clear interpersonal link such as friend, advisor, competitor, family member, or collaborator.
+2. The new persona should offer a complementary or contrasting perspective.
+3. Make the related persona vivid and believable, avoid generic archetypes.
+4. Describe the relation in relationship_type field, not in the description.
+5. The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally.
+
diff --git a/docs/cookbook/assets/space_text_generation.txt b/docs/cookbook/assets/space_text_generation.txt
new file mode 100644
index 0000000..ca5af4b
--- /dev/null
+++ b/docs/cookbook/assets/space_text_generation.txt
@@ -0,0 +1 @@
+Write one {document_type} excerpt about {topic} for {expertise_level} in {language_name}.
diff --git a/docs/cookbook/assets/text_classification_generation.txt b/docs/cookbook/assets/text_classification_generation.txt
new file mode 100644
index 0000000..85dc0f1
--- /dev/null
+++ b/docs/cookbook/assets/text_classification_generation.txt
@@ -0,0 +1,16 @@
+Write one realistic trail comment in {language_name} that sounds like something
+an actual hiker would write after being on the trail.
+
+Target category: {label}
+Category definition: {label_description}
+
+Constraints:
+- The comment must clearly match the target category.
+- The setting must be a {trail_type}.
+- The writing style must be {style}.
+- Keep it to 1 or 2 sentences.
+- Make it sound first-hand, natural, and slightly informal when appropriate.
+- Do not sound like an official report, safety bulletin, or structured form.
+- Do not mention the category name directly.
+- Do not use bullets, numbering, or explanations.
+- Make the comment concrete and varied.
diff --git a/docs/cookbook/assets/text_to_persona_v1.txt b/docs/cookbook/assets/text_to_persona_v1.txt
new file mode 100644
index 0000000..cd09909
--- /dev/null
+++ b/docs/cookbook/assets/text_to_persona_v1.txt
@@ -0,0 +1,17 @@
+Infer one specific persona who is likely to read the following text.
+
+Source text:
+{document}
+
+Requirements:
+1. Return a single persona, not a group.
+2. Make the persona specific and fine-grained rather than generic.
+3. Ground the persona in signals from the text such as domain, expertise, context, or likely motivation.
+4. Do not quote the source text in the persona field.
+5. Only write 1 or 2 sentences maximum.
+6. The persona is not the subject of the text, but rather someone who would be reading it.
+7. Do not refer to the source text, article, or its content in the persona description. The persona must be self-contained.
+8. The persona must be {life_stage}. Do not mention a precise age, just reflect this life stage naturally.
+
+Now figure out a persona description who would be reading this text.
+
diff --git a/docs/cookbook/assets/text_to_persona_v2.txt b/docs/cookbook/assets/text_to_persona_v2.txt
new file mode 100644
index 0000000..294577d
--- /dev/null
+++ b/docs/cookbook/assets/text_to_persona_v2.txt
@@ -0,0 +1,16 @@
+<source_text>
+{document}
+</source_text>
+
+Identify one precise individual who would naturally encounter or write the <source_text>.
+
+Requirements:
+1. Describe exactly one person.
+2. Be as specific as possible: mention plausible occupation and/or life situation.
+3. Derive the persona strictly from cues in the text such as topic, jargon, tone, or implied audience as a potential writer / reader of this text.
+4. Do not copy or paraphrase the source text in the persona field.
+5. Only return 1 or 2 sentences maximum.
+6. The described person is not the subject of the text, but rather someone who would be encountering or writing such text as part of their life.
+7. Do not reference the source text, article, or its content in the persona description. The persona must stand on its own.
+8. The persona must be {life_stage}. Do not state a precise age, just reflect this life stage naturally.
+
diff --git a/docs/cookbook/assets/text_to_persona_v3.txt b/docs/cookbook/assets/text_to_persona_v3.txt
new file mode 100644
index 0000000..3ccb077
--- /dev/null
+++ b/docs/cookbook/assets/text_to_persona_v3.txt
@@ -0,0 +1,17 @@
+You are a persona inference assistant.
+
+Based on the text content below, imagine one real person who would be interested in searching about the topic from this content.
+
+Rules:
+- Output a single, concrete persona rather than a broad demographic.
+- Include details like professional background, interests, or situational context that make the persona feel authentic.
+- Don't mention the person's search or information retrieval action in the persona description, just describe the persona in a way that could explain their interest in the topic.
+- Keep it super short and concise.
+- Do not mention or refer to the source text, article, or its content in the persona description. The persona must be self-contained.
+- The persona must be {life_stage}. Do not state a precise age, just reflect this life stage naturally.
+
+Source text:
+{document}
+
+
+
diff --git a/docs/cookbook/index.md b/docs/cookbook/index.md
new file mode 100644
index 0000000..1b745ec
--- /dev/null
+++ b/docs/cookbook/index.md
@@ -0,0 +1,16 @@
+# Cookbook
+
+Cookbooks connect a runnable script to a documentation walkthrough.
+
+The Python script is the source of truth. Each cookbook page explains:
+
+- where the executable example lives
+- what inputs it uses
+- which prompt assets it depends on
+- where it writes its output artifacts
+
+## Available Cookbooks
+
+- [Text Classification](text_classification.md): generate a multilingual trail-conditions classification dataset from explicit seed dimensions.
+- [Persona Generation](persona_generation.md): infer personas from real articles and expand them through relationships using randomized prompt variants.
+- [Space Engineering Text Generation](space_text_generation.md): generate a raw multilingual technical text corpus from seed dimensions.
diff --git a/docs/cookbook/persona_generation.md b/docs/cookbook/persona_generation.md
new file mode 100644
index 0000000..f314a39
--- /dev/null
+++ b/docs/cookbook/persona_generation.md
@@ -0,0 +1,89 @@
+# Persona Generation
+
+Build personas from real articles and expand them through relationships. Inspired by the Persona Hub paper, implemented entirely with datafast.
+
+## Source
+
+- **Script:** `examples/scripts/43_cookbook_persona_generation.py`
+- **Prompt assets:** [asset index](assets/index.md)
+- **Local output:** `examples/outputs/43_persona_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/43_persona_cookbook`
+- **Hub output:** pushed to the Hugging Face Hub repo IDs configured in the script
+
+## Pipeline
+
+1. Load `xsum` articles (`validation` split), preserving the dataset `id`.
+2. Filter to documents between 300 and 500 words. Keep the first 100 matches.
+3. Assign a random life stage to the source persona.
+4. **Text-to-Persona** — infer one persona from each article and life stage.
+5. Assign a random life stage to the related persona.
+6. **Persona-to-Persona** — expand that persona into a related individual.
+7. Keep the final output fields, add a row UUID, write JSONL, checkpoint progress, and push results to Hugging Face Hub.
+
+Each LLM step randomly picks one prompt variant per record using `Sample(prompts, n=1)`. This adds diversity across generations.
+
+The cookbook keeps `Sample(n=100, strategy="first")` so runs are deterministic and easy to compare. For local corpora with source metadata, use stratified sampling, for example `Sample(n=250, strategy="stratified", by="document_filename")`, to avoid over-representing one source file.
+
+```text
+xsum article
+ │
+ ▼
+life_stage (random from configured stages)
+ │
+ ▼
+Text-to-Persona (random prompt from 3 variants)
+ │
+ ▼
+related_life_stage (random from configured stages)
+ │
+ ▼
+Persona-to-Persona (random prompt from 3 variants)
+ │
+ ▼
+Hugging Face Hub
+```
+
+## Run
+
+Prerequisites:
+
+- `OPENROUTER_API_KEY` set in a `.env` file
+- Hugging Face authentication via `HF_TOKEN` in `.env` or a cached `huggingface_hub` login
+- Base dependencies from `pyproject.toml` installed
+
+Before running, replace the example Hugging Face namespaces in the script with your own username or organization:
+
+- `HF_REPO_ID = "<your-namespace>/new-persona-cookbook-dataset"` controls the private pipeline sink.
+- `repo_id = "<your-namespace>/datafast-persona-cookbook"` inside `push_records_to_hub()` controls the public publish step.
+
+```bash
+python examples/scripts/43_cookbook_persona_generation.py
+```
+
+The run uses `checkpoint_dir` and `resume=True`, which is useful for paid or rate-limited LLM calls. If a run is interrupted, re-run the same command to continue from the saved checkpoints.
+
+The main example reads from Hugging Face. For a local JSONL corpus, replace `Source.huggingface(...)` with `Source.file(...)` and map your text column to `document` before `add_word_count`.
+
+## Prompt Variants
+
+Each step draws from multiple prompt files stored under `docs/cookbook/assets/`. See the [asset index](assets/index.md) for the full list.
+
+- **Text-to-Persona:** 3 variants (`text_to_persona_v1.txt`, `v2`, `v3`)
+- **Persona-to-Persona:** 3 variants (`persona_to_persona_v1.txt`, `v2`, `v3`)
+
+## Research Basis
+
+The Persona Hub paper introduces Text-to-Persona and Persona-to-Persona as scalable methods for building personas from web text. The paper states that its published prompts are simplified, not the exact experiment strings. This cookbook treats them as paper-aligned adaptations. It does not reuse any Persona Hub code.
+
+## Output Fields
+
+- `id` — generated row UUID
+- `source_id` — original XSum record identifier
+- `summary` — original article summary
+- `document` — source article text
+- `word_count` — whitespace token count
+- `life_stage` — randomly selected life stage for the inferred persona
+- `persona_description` — inferred persona
+- `relationship_type` — link between the two personas
+- `related_life_stage` — randomly selected life stage for the expanded persona
+- `related_persona_description` — the expanded related persona
diff --git a/docs/cookbook/space_text_generation.md b/docs/cookbook/space_text_generation.md
new file mode 100644
index 0000000..92c55dc
--- /dev/null
+++ b/docs/cookbook/space_text_generation.md
@@ -0,0 +1,103 @@
+# Space Engineering Text Generation
+
+Build a raw technical text corpus across document types, topics, expertise levels,
+languages, and model choices.
+
+## Source
+
+- **Script:** `examples/scripts/44_cookbook_space_text_generation.py`
+- **Prompt assets:** [asset index](assets/index.md)
+- **Local output:** `examples/outputs/44_space_text_generation_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/44_space_text_generation_cookbook`
+- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1`
+
+## Pipeline
+
+1. Create a seed grid with `Seed.product`.
+2. Cross document types, topics, and expertise levels explicitly.
+3. Generate one section per seed and language with `LLMStep`.
+4. Let the prompt variables drive the corpus variation.
+5. Parse `title` and `text` from JSON mode.
+6. Keep publication fields, add a row UUID, write JSONL, checkpoint progress,
+ and optionally push to Hugging Face Hub.
+
+The default model is `nvidia/nemotron-3-super-120b-a12b:nitro` through
+OpenRouter.
+
+```text
+document_type x topic x expertise_level
+ |
+ v
+LLMStep language expansion: English and French
+ |
+ v
+JSON fields: title, text
+ |
+ v
+examples/outputs/44_space_text_generation_cookbook.jsonl
+```
+
+## Row Count
+
+The default script generates:
+
+```text
+3 document types x 8 topics x 3 expertise levels x 2 languages
+x 1 generated output x 1 model = 144 rows
+```
+
+To use several models, add provider IDs to `MODEL_IDS`. `LLMStep` will run each
+seed-language combination through every model and the row count will multiply by
+the number of models.
+
+## Run
+
+Prerequisites:
+
+- `OPENROUTER_API_KEY` set in a `.env` file
+- Base dependencies from `pyproject.toml` installed
+- Hugging Face authentication only if publishing
+
+```bash
+python examples/scripts/44_cookbook_space_text_generation.py
+```
+
+To publish, replace `HF_REPO_ID` in the script with a repository under your own
+Hugging Face username or organization, then run:
+
+```bash
+DATAFAST_PUSH_TO_HUB=1 python examples/scripts/44_cookbook_space_text_generation.py
+```
+
+The run uses `checkpoint_dir` and `resume=True`. If generation is interrupted,
+run the command again to continue from saved checkpoints.
+
+## Prompt
+
+The script uses one compact prompt file:
+
+```text
+Write one {document_type} excerpt about {topic} for {expertise_level} in {language_name}.
+```
+
+## Generation Controls
+
+- `MODEL_IDS` controls which models generate each record.
+- `LANGUAGES` controls language expansion and writes the emitted language code to
+ the `language` field.
+- `NUM_OUTPUTS` controls how many generated rows are created for each
+ seed, language, and model combination.
+- `PROMPT_PATH` controls the prompt file used for generation.
+- `SEED` controls deterministic dataset splitting when publishing.
+- `HF_REPO_ID` controls the optional Hugging Face Hub destination.
+
+## Output Fields
+
+- `id` - generated row UUID
+- `document_type` - requested document style
+- `topic` - space engineering topic
+- `expertise_level` - intended reader level
+- `language` - language code emitted by `LLMStep`
+- `model` - model ID emitted by `LLMStep`
+- `title` - generated section title
+- `text` - generated corpus text
diff --git a/docs/cookbook/text_classification.md b/docs/cookbook/text_classification.md
new file mode 100644
index 0000000..dd36422
--- /dev/null
+++ b/docs/cookbook/text_classification.md
@@ -0,0 +1,118 @@
+# Text Classification
+
+Build a multilingual trail-conditions classification dataset with `datafast`.
+
+## Source
+
+- **Script:** `examples/scripts/45_cookbook_text_classification.py`
+- **Prompt assets:** [asset index](assets/index.md)
+- **Local output:** `examples/outputs/45_text_classification_cookbook.jsonl`
+- **Checkpoints:** `examples/checkpoints/45_text_classification_cookbook`
+- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1`
+
+## Use Case
+
+This cookbook generates short trail comments across four trail-condition labels
+so teams can monitor trail quality and surface issues quickly.
+
+The default setup is:
+
+- multi-class: 4 trail-condition labels
+- multi-lingual: English and French
+- multi-model: two generation models by default
+- publishable: optional push to Hugging Face Hub
+
+## Pipeline
+
+1. Create a seed grid from labels, trail types, and writing styles.
+2. Generate one short trail comment for each seed across all configured models
+ and languages.
+3. Keep the label and prompt-variation provenance in flat output columns.
+4. Add a UUID, write JSONL locally, and optionally push to Hugging Face Hub.
+
+Variation is modeled explicitly through `Seed.product(...)`, which keeps the
+generation axes inspectable and easy to count.
+
+```text
+label x trail_type x style
+ |
+ v
+LLMStep language expansion: English and French
+ |
+ v
+LLMStep model expansion
+ |
+ v
+examples/outputs/45_text_classification_cookbook.jsonl
+```
+
+## Row Count
+
+The default script generates:
+
+```text
+4 labels x 3 trail types x 2 styles x 2 languages
+x 2 models = 96 rows
+```
+
+Each extra model in `MODEL_IDS` multiplies the total row count.
+
+## Run
+
+Prerequisites:
+
+- `OPENROUTER_API_KEY` set in a `.env` file
+- Base dependencies from `pyproject.toml` installed
+- Hugging Face authentication only if publishing
+
+```bash
+python examples/scripts/45_cookbook_text_classification.py
+```
+
+To publish, replace `HF_REPO_ID` in the script with a repository under your own
+Hugging Face username or organization, then run:
+
+```bash
+DATAFAST_PUSH_TO_HUB=1 python examples/scripts/45_cookbook_text_classification.py
+```
+
+The run uses `checkpoint_dir` and `resume=True`. If generation is interrupted,
+run the command again to continue from saved checkpoints.
+
+If you want to use provider-specific clients directly, replace `MODEL_IDS` or
+the `model=MODELS` argument in `LLMStep` with providers such as `openai(...)`
+or `anthropic(...)`. The default setup uses multiple OpenRouter-backed models
+so it works with one API key.
+
+## Prompt
+
+The cookbook uses one prompt file and drives diversity through seed dimensions:
+
+```text
+Write one realistic trail comment in {language_name}.
+```
+
+See [text_classification_generation.txt](assets/text_classification_generation.txt)
+for the full prompt.
+
+## Generation Controls
+
+- `LABELS` defines the target classes and their prompt descriptions.
+- `TRAIL_TYPES` controls the trail settings used in generation.
+- `STYLES` controls the voice and format of each comment.
+- `LANGUAGES` controls language expansion.
+- `MODEL_IDS` controls which models generate records.
+- `HF_REPO_ID` controls the optional Hugging Face Hub destination.
+
+If you want an extra quality-control pass, add a downstream `Classify` and
+`Filter` stage to verify that generated comments match their intended label.
+
+## Output Fields
+
+- `id` - generated row UUID
+- `label` - target trail-condition label
+- `trail_type` - prompt expansion axis for the trail setting
+- `style` - prompt expansion axis for the comment style
+- `language` - language code emitted by `LLMStep`
+- `model` - model ID emitted by `LLMStep`
+- `text` - generated trail comment
diff --git a/docs/guides/building_pipelines.md b/docs/guides/building_pipelines.md
index 64aaaf2..b755410 100644
--- a/docs/guides/building_pipelines.md
+++ b/docs/guides/building_pipelines.md
@@ -3,11 +3,12 @@
## Minimal Pipeline
```python
-from datafast import Map, Sink, Source
+from datafast import AddUUID, Map, Sink, Source
pipeline = (
Source.list([{"text": "hello"}])
>> Map(lambda r: {**r, "length": len(r["text"])})
+ >> AddUUID()
>> Sink.list()
)
@@ -38,6 +39,7 @@ seed = Seed.product(
## Core Data Operations
+- `AddUUID`: add a UUID field to each record
- `Map`: one record in, one record out
- `FlatMap`: one record in, many records out
- `Filter`: keep or drop records
diff --git a/examples/providers/README.md b/examples/providers/README.md
new file mode 100644
index 0000000..5c9e4c7
--- /dev/null
+++ b/examples/providers/README.md
@@ -0,0 +1,8 @@
+# Provider Examples
+
+This folder contains direct, provider-focused examples.
+
+- `openrouter/`: simple OpenRouter calls with `model.generate(...)`
+
+These scripts are intentionally separate from `examples/scripts/`, which focuses on
+pipeline usage.
diff --git a/examples/providers/openrouter/01_simple_prompt.py b/examples/providers/openrouter/01_simple_prompt.py
new file mode 100644
index 0000000..fdfb279
--- /dev/null
+++ b/examples/providers/openrouter/01_simple_prompt.py
@@ -0,0 +1,22 @@
+"""Minimal OpenRouter example with a single prompt."""
+
+from dotenv import load_dotenv
+
+from datafast import openrouter
+from datafast.llm_utils import format_generated_responses
+
+
+MODEL_ID = "openai/gpt-5.4-mini"
+PROMPT = "Write one sentence explaining what OpenRouter is."
+
+
+def main() -> None:
+ load_dotenv()
+
+ model = openrouter(MODEL_ID, temperature=0)
+ response = model.generate(prompt=PROMPT)
+ print(format_generated_responses(PROMPT, response))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/providers/openrouter/02_batch_prompts.py b/examples/providers/openrouter/02_batch_prompts.py
new file mode 100644
index 0000000..765b219
--- /dev/null
+++ b/examples/providers/openrouter/02_batch_prompts.py
@@ -0,0 +1,26 @@
+"""Minimal OpenRouter example with a batch of prompts."""
+
+from dotenv import load_dotenv
+
+from datafast import openrouter
+from datafast.llm_utils import format_generated_responses
+
+
+MODEL_ID = "openai/gpt-5.4-mini"
+PROMPTS = [
+ "Give a one-sentence definition of synthetic data.",
+ "Give a one-sentence definition of retrieval-augmented generation.",
+ "Give a one-sentence definition of tool calling.",
+]
+
+
+def main() -> None:
+ load_dotenv()
+
+ model = openrouter(MODEL_ID, temperature=0)
+ responses = model.generate(prompt=PROMPTS)
+ print(format_generated_responses(PROMPTS, responses))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/providers/openrouter/03_messages_with_system_prompt.py b/examples/providers/openrouter/03_messages_with_system_prompt.py
new file mode 100644
index 0000000..546e63f
--- /dev/null
+++ b/examples/providers/openrouter/03_messages_with_system_prompt.py
@@ -0,0 +1,31 @@
+"""OpenRouter example using explicit chat messages."""
+
+from dotenv import load_dotenv
+
+from datafast import openrouter
+from datafast.llm_utils import format_generated_responses
+
+
+MODEL_ID = "openai/gpt-5.4-mini"
+MESSAGES = [
+    {
+        "role": "system",
+        "content": "You are a concise technical assistant. Answer in exactly two bullets.",
+    },
+    {
+        "role": "user",
+        "content": "Explain why teams use an LLM router.",
+    },
+]
+
+
+def main() -> None:
+    """Send MESSAGES as one chat request and print the reply next to the user turn."""
+    load_dotenv()  # read variables from a local .env file before building the client
+    reply = openrouter(MODEL_ID, temperature=0).generate(messages=MESSAGES)
+    user_prompt = MESSAGES[-1]["content"]  # the last entry is the user message
+    print(format_generated_responses(user_prompt, reply))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/providers/openrouter/04_structured_output.py b/examples/providers/openrouter/04_structured_output.py
new file mode 100644
index 0000000..1c2bd61
--- /dev/null
+++ b/examples/providers/openrouter/04_structured_output.py
@@ -0,0 +1,28 @@
+"""OpenRouter example with structured output validation."""
+
+from dotenv import load_dotenv
+from pydantic import BaseModel
+
+from datafast import openrouter
+
+
+MODEL_ID = "openai/gpt-5.4-mini"
+PROMPT = "Return a JSON object describing OpenRouter in two short sentences."
+
+
+class ProviderSummary(BaseModel):  # passed to generate() as response_format in main()
+    name: str
+    summary: str
+    best_for: str
+
+
+def main() -> None:
+    """Request a ProviderSummary-shaped answer for PROMPT and print it as indented JSON."""
+    load_dotenv()  # read variables from a local .env file before building the client
+    client = openrouter(MODEL_ID, temperature=0)
+    summary = client.generate(prompt=PROMPT, response_format=ProviderSummary)
+    print(summary.model_dump_json(indent=2))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/providers/openrouter/05_batch_messages.py b/examples/providers/openrouter/05_batch_messages.py
new file mode 100644
index 0000000..25b42d7
--- /dev/null
+++ b/examples/providers/openrouter/05_batch_messages.py
@@ -0,0 +1,44 @@
+"""OpenRouter example with a batch of message lists."""
+
+from dotenv import load_dotenv
+
+from datafast import openrouter
+from datafast.llm_utils import format_generated_responses
+
+
+MODEL_ID = "openai/gpt-5.4-mini"
+BATCH_MESSAGES = [
+    [
+        {
+            "role": "system",
+            "content": "You answer for engineers in one sentence.",
+        },
+        {
+            "role": "user",
+            "content": "What is prompt caching?",
+        },
+    ],
+    [
+        {
+            "role": "system",
+            "content": "You answer for engineers in one sentence.",
+        },
+        {
+            "role": "user",
+            "content": "What is structured output?",
+        },
+    ],
+]
+
+
+def main() -> None:
+    """Send every conversation in BATCH_MESSAGES in one call and print the formatted replies."""
+    load_dotenv()  # read variables from a local .env file before building the client
+    client = openrouter(MODEL_ID, temperature=0)
+    replies = client.generate(messages=BATCH_MESSAGES)
+    # Each conversation ends with its user turn; use that text as the display prompt.
+    user_prompts = [conversation[-1]["content"] for conversation in BATCH_MESSAGES]
+    print(format_generated_responses(user_prompts, replies))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scripts/43_cookbook_persona_generation.py b/examples/scripts/43_cookbook_persona_generation.py
new file mode 100644
index 0000000..ac4f718
--- /dev/null
+++ b/examples/scripts/43_cookbook_persona_generation.py
@@ -0,0 +1,137 @@
+"""Persona-generation cookbook: XSum article -> personas -> related personas.
+
+Demonstrates: Source.huggingface, Map, Filter, Sample, AddUUID, JSON-mode LLMSteps,
+prompt assets stored under docs/cookbook/assets, and JSONL + Hugging Face Hub sinks.
+
+Requires:
+- OPENROUTER_API_KEY
+- Hugging Face authentication via HF_TOKEN or a cached `huggingface_hub` login
+- network access to Hugging Face and OpenRouter
+"""
+
+import random
+
+from dotenv import load_dotenv
+
+from datafast import AddUUID, Filter, LLMStep, Map, Sample, Sink, Source, openrouter
+
+import litellm
+
+load_dotenv()  # load variables from a local .env file (the script requires OPENROUTER_API_KEY)
+
+litellm.suppress_debug_info = True  # presumably quiets LiteLLM's informational output — confirm
+
+
+MODEL_ID = "nvidia/nemotron-3-super-120b-a12b:nitro"  # model slug passed to openrouter() in build_pipeline()
+OUTPUT_PATH = "examples/outputs/43_persona_cookbook.jsonl"  # target of the Sink.jsonl step
+CHECKPOINT_DIR = "examples/checkpoints/43_persona_cookbook"  # given to run(checkpoint_dir=...) in main()
+HF_REPO_ID = "patrickfleith/new-persona-cookbook-dataset"  # repo used by the Sink.hub step in build_pipeline()
+TEXT_TO_PERSONA_PROMPTS = [  # prompt files offered to Sample(..., n=1) for the text_to_persona step
+ "docs/cookbook/assets/text_to_persona_v1.txt",
+ "docs/cookbook/assets/text_to_persona_v2.txt",
+ "docs/cookbook/assets/text_to_persona_v3.txt",
+]
+PERSONA_TO_PERSONA_PROMPTS = [  # prompt files offered to Sample(..., n=1) for the persona_to_persona step
+ "docs/cookbook/assets/persona_to_persona_v1.txt",
+ "docs/cookbook/assets/persona_to_persona_v2.txt",
+ "docs/cookbook/assets/persona_to_persona_v3.txt",
+]
+LIFE_STAGES = [  # pool that random.choice draws from in the two assign_*_life_stage helpers
+ "a teenager",
+ "a young adult",
+ "an adult (30s/40s)",
+ "a middle-aged person (in their 50s/60s)",
+ "a senior person (in their 70s/80s)",
+]
+
+
+def add_word_count(record: dict) -> dict:  # returns a copy; the input record is left untouched
+    return dict(record, word_count=len(record["document"].split()))
+
+
+def assign_life_stage(record: dict) -> dict:  # copy of record plus a randomly drawn "life_stage"
+    return dict(record, life_stage=random.choice(LIFE_STAGES))
+
+
+def assign_related_life_stage(record: dict) -> dict:  # copy plus a random "related_life_stage"
+    return dict(record, related_life_stage=random.choice(LIFE_STAGES))
+
+
+def keep_output_fields(record: dict) -> dict:
+    """Project a record onto the output columns.
+
+    The incoming ``id`` is exposed as ``source_id``; every other listed column
+    keeps its name. A missing column raises ``KeyError``.
+    """
+    passthrough = (
+        "summary", "document", "word_count", "life_stage",
+        "persona_description", "relationship_type",
+        "related_life_stage", "related_persona_description",
+    )
+    return {"source_id": record["id"], **{name: record[name] for name in passthrough}}
+
+
+def build_pipeline():
+    """Assemble the XSum -> persona -> related-persona pipeline; both LLM steps share one model."""
+    model = openrouter(MODEL_ID, temperature=0.7)
+    return (
+        Source.huggingface(
+            "xsum",
+            split="validation",
+            columns=["id", "document", "summary"],
+        )
+        # For a local JSONL corpus, replace the Hugging Face source with something
+        # like Source.file("data/articles.jsonl") and map your text field to
+        # "document" before add_word_count.
+        >> Map(add_word_count).as_step("add_word_count")
+        >> Filter(fn=lambda r: 300 <= r["word_count"] <= 500).as_step("filter_word_count")
+        >> Sample(n=10, strategy="first").as_step("take_first_10")  # step label mirrors n
+        >> Map(assign_life_stage).as_step("assign_life_stage")
+        >> LLMStep(
+            prompt=Sample(TEXT_TO_PERSONA_PROMPTS, n=1),
+            input_columns=["document", "life_stage"],
+            output_columns=["persona_description"],
+            model=model,
+            parse_mode="json",
+            on_parse_error="raise",
+        ).as_step("text_to_persona")
+        >> Map(assign_related_life_stage).as_step("assign_related_life_stage")
+        >> LLMStep(
+            prompt=Sample(PERSONA_TO_PERSONA_PROMPTS, n=1),
+            input_columns=["persona_description", "related_life_stage"],
+            output_columns=["relationship_type", "related_persona_description"],
+            model=model,
+            parse_mode="json",
+            on_parse_error="raise",
+        ).as_step("persona_to_persona")
+        >> Map(keep_output_fields).as_step("keep_output_fields")
+        >> AddUUID(column="id", overwrite=True).as_step("add_uuid")
+        >> Sink.jsonl(OUTPUT_PATH)
+        >> Sink.hub(HF_REPO_ID, private=True)  # NOTE(review): main() pushes again via push_records_to_hub() — confirm
+    )
+
+
+def push_records_to_hub(records: list[dict]) -> None:
+    """Publish ``records`` through a standalone ``Sink.hub`` configured with ``private=False``.
+
+    NOTE(review): the repo id is hard-coded here instead of using ``HF_REPO_ID`` — confirm.
+    """
+    hub_sink = Sink.hub(
+        repo_id="patrickfleith/datafast-persona-cookbook",
+        private=False,
+        commit_message=f"Publish cookbook 43 persona dataset with {MODEL_ID}",
+    )
+    list(hub_sink.process(records))  # drain process(); presumably it yields lazily
+
+
+def main() -> None:
+    """Run the persona pipeline with resume disabled, then publish the returned records."""
+    persona_pipeline = build_pipeline()
+    # One record per batch; checkpoints go to CHECKPOINT_DIR but are not resumed from.
+    run_options = {"batch_size": 1, "checkpoint_dir": CHECKPOINT_DIR, "resume": False}
+    records = persona_pipeline.run(**run_options)
+    push_records_to_hub(records)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scripts/44_cookbook_space_text_generation.py b/examples/scripts/44_cookbook_space_text_generation.py
new file mode 100644
index 0000000..6c5d2cb
--- /dev/null
+++ b/examples/scripts/44_cookbook_space_text_generation.py
@@ -0,0 +1,143 @@
+"""Space text-generation cookbook: seed grid -> technical text corpus.
+
+Demonstrates: Seed.product, LLMStep JSON mode, multi-language generation,
+num_outputs, checkpointing, JSONL output, and optional Hub push.
+
+Requires:
+- OPENROUTER_API_KEY
+- Hugging Face authentication only if DATAFAST_PUSH_TO_HUB=1
+- network access to OpenRouter, and to Hugging Face when publishing
+"""
+
+from __future__ import annotations
+
+import os
+
+import litellm
+from dotenv import load_dotenv
+
+from datafast import AddUUID, LLMStep, Map, Seed, Sink, openrouter
+
+load_dotenv()  # load variables from a local .env file (the script requires OPENROUTER_API_KEY)
+litellm.suppress_debug_info = True  # presumably quiets LiteLLM's informational output — confirm
+
+
+SEED = 20250304  # forwarded as seed= to Sink.hub in push_records_to_hub()
+MODEL_IDS = ["nvidia/nemotron-3-super-120b-a12b:nitro"]  # make_models() builds one client per entry
+OUTPUT_PATH = "examples/outputs/44_space_text_generation_cookbook.jsonl"  # target of the Sink.jsonl step
+CHECKPOINT_DIR = "examples/checkpoints/44_space_text_generation_cookbook"  # given to run(checkpoint_dir=...)
+HF_REPO_ID = "patrickfleith/datafast-space-text-generation-cookbook"  # repo_id for the optional Hub push
+NUM_OUTPUTS = 1  # passed as num_outputs= to the LLMStep and counted in expected_row_count()
+PROMPT_PATH = "docs/cookbook/assets/space_text_generation.txt"  # given to LLMStep(prompt=...)
+
+DOCUMENT_TYPES = [  # values of the "document_type" seed column
+ "space engineering textbook",
+ "spacecraft design justification document",
+ "personal blog of a space engineer",
+]
+
+TOPICS = [  # values of the "topic" seed column
+ "Microgravity",
+ "Vacuum",
+ "Heavy Ions",
+ "Thermal Extremes",
+ "Atomic Oxygen",
+ "Debris Impact",
+ "Electrostatic Charging",
+ "Propellant Boil-off",
+]
+
+EXPERTISE_LEVELS = [  # values of the "expertise_level" seed column
+ "executives",
+ "senior engineers",
+ "PhD candidates",
+]
+
+LANGUAGES = {  # passed as language= to the LLMStep; presumably language code -> display name
+ "en": "English",
+ "fr": "French",
+}
+
+
+def make_models():  # one openrouter() client per MODEL_IDS entry, all at temperature 0.7
+    return [openrouter(slug, temperature=0.7) for slug in MODEL_IDS]
+
+
+def expected_row_count(model_count: int | None = None) -> int:
+    """Compute how many rows this configuration is expected to produce.
+
+    ``model_count`` overrides ``len(MODEL_IDS)`` when given. The result is the
+    product of every seed dimension size, the language count, NUM_OUTPUTS and the model count.
+    """
+    if model_count is None:
+        model_count = len(MODEL_IDS)
+    seed_combinations = len(DOCUMENT_TYPES) * len(TOPICS) * len(EXPERTISE_LEVELS)
+    rows_per_combination = len(LANGUAGES) * NUM_OUTPUTS * model_count
+    return seed_combinations * rows_per_combination
+
+
+def finalize_record(record: dict) -> dict:
+    """Reduce a generated record to the columns meant for publication.
+
+    ``_language`` and ``_model`` are optional and fall back to ``""``; other columns are required.
+    """
+    published = {key: record[key] for key in ("document_type", "topic", "expertise_level")}
+    published["language"] = record.get("_language", "")
+    published["model"] = record.get("_model", "")
+    published["title"] = record["title"]
+    published["text"] = record["text"]
+    return published
+
+
+def build_pipeline():
+    """Chain seed grid -> LLM generation -> column cleanup -> UUID -> JSONL sink."""
+    seed_grid = Seed.product(
+        Seed.values("document_type", DOCUMENT_TYPES),
+        Seed.values("topic", TOPICS),
+        Seed.values("expertise_level", EXPERTISE_LEVELS),
+    ).as_step("seed_space_text_grid")
+    generate = LLMStep(
+        prompt=PROMPT_PATH,
+        input_columns=["document_type", "topic", "expertise_level"],
+        output_columns=["title", "text"],
+        parse_mode="json",
+        model=make_models(),
+        language=LANGUAGES,
+        num_outputs=NUM_OUTPUTS,
+        on_parse_error="raise",
+    ).as_step("generate_space_text")
+    finalize = Map(finalize_record).as_step("finalize_record")
+    add_uuid = AddUUID(column="id", overwrite=True).as_step("add_uuid")
+    # The steps are composed with >> in the order they are listed above.
+    return seed_grid >> generate >> finalize >> add_uuid >> Sink.jsonl(OUTPUT_PATH)
+
+
+def push_records_to_hub(records: list[dict]) -> None:
+    """Push ``records`` to the private ``HF_REPO_ID`` repo through a standalone ``Sink.hub``."""
+    hub_sink = Sink.hub(
+        repo_id=HF_REPO_ID,
+        private=True,
+        train_size=0.8,
+        seed=SEED,
+        shuffle=True,
+        commit_message=f"Publish cookbook 44 text dataset with {', '.join(MODEL_IDS)}",
+    )
+    list(hub_sink.process(records))  # drain process(); presumably it yields lazily
+
+
+def main() -> None:
+    """Generate the corpus, optionally publish it, and print expected vs. returned row counts."""
+    print(f"Expected rows: {expected_row_count()}")
+    text_pipeline = build_pipeline()
+    records = text_pipeline.run(batch_size=4, checkpoint_dir=CHECKPOINT_DIR, resume=True)
+
+    # Publishing is opt-in: only the exact value "1" triggers the Hub push.
+    should_publish = os.getenv("DATAFAST_PUSH_TO_HUB") == "1"
+    if should_publish:
+        push_records_to_hub(records)
+
+    print(f"Wrote {len(records)} records to {OUTPUT_PATH}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scripts/45_cookbook_text_classification.py b/examples/scripts/45_cookbook_text_classification.py
new file mode 100644
index 0000000..9dce7af
--- /dev/null
+++ b/examples/scripts/45_cookbook_text_classification.py
@@ -0,0 +1,158 @@
+"""Text-classification cookbook: seed grid -> multilingual trail comments.
+
+Demonstrates: Seed.product, prompt expansion via seed dimensions, multi-model
+generation, multi-language generation, checkpointing, JSONL output, and
+optional Hugging Face Hub publishing.
+
+Requires:
+- OPENROUTER_API_KEY
+- Hugging Face authentication only if DATAFAST_PUSH_TO_HUB=1
+- network access to OpenRouter, and to Hugging Face when publishing
+"""
+
+from __future__ import annotations
+
+import os
+
+import litellm
+from dotenv import load_dotenv
+
+from datafast import AddUUID, LLMStep, Map, Seed, SeedDimension, Sink, openrouter
+
+load_dotenv()  # load variables from a local .env file (the script requires OPENROUTER_API_KEY)
+litellm.suppress_debug_info = True  # presumably quiets LiteLLM's informational output — confirm
+
+
+SEED = 20250611  # forwarded as seed= to Sink.hub in main()
+MODEL_IDS = [  # one openrouter() client is built per entry (see MODELS)
+ "nvidia/nemotron-3-super-120b-a12b:nitro",
+ "mistralai/ministral-14b-2512",
+]
+OUTPUT_PATH = "examples/outputs/45_text_classification_cookbook.jsonl"  # target of the Sink.jsonl step
+CHECKPOINT_DIR = "examples/checkpoints/45_text_classification_cookbook"  # given to run(checkpoint_dir=...)
+HF_REPO_ID = "patrickfleith/datafast-text-classification-cookbook"  # repo_id for the optional Hub push
+PROMPT_PATH = "docs/cookbook/assets/text_classification_generation.txt"  # given to LLMStep(prompt=...)
+
+LABELS = [  # rows of the two-column ("label", "label_description") SeedDimension
+ {
+ "label": "trail_obstruction",
+ "label_description": (
+ "The trail is partially or fully blocked by obstacles such as "
+ "fallen trees, landslides, snow, flooding, erosion, or dense "
+ "vegetation."
+ ),
+ },
+ {
+ "label": "infrastructure_issues",
+ "label_description": (
+ "The report is about damaged or missing bridges, signs, stairs, "
+ "handrails, markers, boardwalks, or similar trail infrastructure."
+ ),
+ },
+ {
+ "label": "hazards",
+ "label_description": (
+ "The trail has immediate safety risks such as slippery surfaces, "
+ "dangerous crossings, unstable terrain, wildlife threats, or "
+ "other hazardous conditions."
+ ),
+ },
+ {
+ "label": "positive_conditions",
+ "label_description": (
+ "The report highlights clear, safe, enjoyable trail conditions "
+ "such as good maintenance, solid infrastructure, clear signage, "
+ "or scenic features."
+ ),
+ },
+]
+
+TRAIL_TYPES = [  # values of the "trail_type" seed column
+ "mountain trail",
+ "coastal path",
+ "forest walk",
+]
+
+STYLES = [  # values of the "style" seed column
+ "a brief social media post",
+ "a hiking review",
+]
+
+LANGUAGES = {  # passed as language= to the LLMStep; presumably language code -> display name
+ "en": "English",
+ "fr": "French",
+}
+
+MODELS = [openrouter(model_id, temperature=0.8) for model_id in MODEL_IDS]  # built at import time
+EXPECTED_ROWS = (  # product of every seed dimension size, language count and model count
+ len(LABELS)
+ * len(TRAIL_TYPES)
+ * len(STYLES)
+ * len(LANGUAGES)
+ * len(MODELS)
+)
+
+
+def keep_output_fields(record: dict) -> dict:
+    """Reduce a generated record to the columns meant for publication.
+
+    ``_language`` and ``_model`` are optional and fall back to ``""``; other columns are required.
+    """
+    published = {key: record[key] for key in ("label", "trail_type", "style")}
+    published["language"] = record.get("_language", "")
+    published["model"] = record.get("_model", "")
+    published["text"] = record["text"]
+    return published
+
+
+# Seed grid built from the label rows, trail types and styles.
+_seed_step = Seed.product(
+    SeedDimension(
+        columns=["label", "label_description"],
+        values=LABELS,
+    ),
+    Seed.values("trail_type", TRAIL_TYPES),
+    Seed.values("style", STYLES),
+).as_step("seed_trail_report_grid")
+_llm_step = LLMStep(
+    prompt=PROMPT_PATH,
+    input_columns=["label", "label_description", "trail_type", "style"],
+    output_column="text",
+    parse_mode="text",
+    model=MODELS,
+    language=LANGUAGES,
+).as_step("generate_trail_reports")
+_keep_step = Map(keep_output_fields).as_step("keep_output_fields")
+_uuid_step = AddUUID(column="id", overwrite=True).as_step("add_uuid")
+# The steps are composed with >> in the order they are defined above.
+pipeline = _seed_step >> _llm_step >> _keep_step >> _uuid_step >> Sink.jsonl(OUTPUT_PATH)
+
+
+def main() -> None:
+    """Generate the trail reports, optionally publish them, and print the row counts."""
+    print(f"Expected rows: {EXPECTED_ROWS}")
+    # resume=True: presumably reuses checkpoint state under CHECKPOINT_DIR — confirm.
+    records = pipeline.run(batch_size=4, checkpoint_dir=CHECKPOINT_DIR, resume=True)
+
+    # Publishing is opt-in: only DATAFAST_PUSH_TO_HUB=1 triggers the Hub upload.
+    if os.getenv("DATAFAST_PUSH_TO_HUB") == "1":
+        commit_message = (
+            "Publish cookbook 45 classification dataset with "
+            f"{', '.join(MODEL_IDS)}"
+        )
+        hub_sink = Sink.hub(
+            repo_id=HF_REPO_ID,
+            private=False,
+            train_size=0.8,
+            seed=SEED,
+            shuffle=True,
+            commit_message=commit_message,
+        )
+        # Wrap in list() to drain process(); presumably it yields lazily.
+        list(hub_sink.process(records))
+
+    print(f"Wrote {len(records)} records to {OUTPUT_PATH}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm_provider_requirements.md b/llm_provider_requirements.md
new file mode 100644
index 0000000..df82303
--- /dev/null
+++ b/llm_provider_requirements.md
@@ -0,0 +1,259 @@
+# LLM Provider Requirements (Draft)
+
+## Goal
+
+Design a clean model-provider layer for `datafast/llms.py` with one stable Datafast API, while resolving actual support per target model or deployment.
+
+The key design rule is:
+
+- The public API should provide a uniform core model.
+- The public API should also provide ergonomic provider-specific entry points.
+- Capabilities should be resolved per target: provider + endpoint + model + optional self-hosted server behavior.
+
+## Core Design Principles
+
+- Keep a small common config surface for normal usage.
+- Do not assume all models under one provider support the same parameters.
+- Do not silently pass unsupported parameters unless that behavior is explicitly enabled.
+- Preserve provider or server defaults when the user does not override them.
+- Separate Datafast-level config from provider-specific request mapping.
+
+## Common Datafast Config
+
+Every target should support these common fields when applicable:
+
+- `model_id`
+- `temperature`
+- `rpm_limit`
+- `timeout`
+
+Optional fields, only sent when supported:
+
+- `max_completion_tokens`
+- `thinking`
+- `reasoning_effort`
+- `api_key`
+- `api_base_url`
+- retry limit (maximum number of retry attempts per request)
+- `unsupported_params`
+
+`unsupported_params` should control how Datafast handles user-specified parameters that are known to be unsupported by the resolved target.
+
+- `fail`: raise a clear error before sending the request
+- `warn`: omit the unsupported parameter and emit a warning
+- `quiet`: omit the unsupported parameter silently
+
+Default:
+
+- `unsupported_params="warn"`
+
+## Public API Ergonomics
+
+The public API should expose provider-specific entry points such as:
+
+- `openai(...)`
+- `anthropic(...)`
+- `openrouter(...)`
+- `mistral(...)`
+- `ollama(...)`
+
+Requirements:
+
+- Provider-specific entry points should be the primary ergonomic API for users.
+- They should make provider choice explicit and easy to read in pipelines.
+- They should expose sensible provider-specific defaults and validation.
+- They should share the same common config surface where possible.
+- They may expose provider-specific options when needed, without forcing those options into every provider API.
+- They should remain thin wrappers over a shared internal target/config system.
+- Core execution behavior such as retries, batching, capability resolution, caching, and parsing should not live separately in each provider wrapper.
+
+## Capability Resolution
+
+Requirements should be defined around resolved target capabilities, not provider classes alone.
+
+That means:
+
+- OpenAI-compatible transport does not imply OpenAI-equivalent features.
+- OpenRouter support is model-specific, not just provider-specific.
+- Local servers such as Ollama, vLLM, and `llama.cpp` may expose different controls even when they look OpenAI-compatible.
+- Local servers may emulate an endpoint shape without matching the full upstream semantics.
+- When support is unknown, the safe default is to omit optional params rather than optimistically send them.
+
+The design should allow:
+
+- capability mapping per model or deployment
+- endpoint-mode resolution per target, especially chat completions vs Responses API
+- provider-specific parameter aliases
+- explicit escape hatches for provider-specific params
+- controlled dropping of unsupported params when intentionally enabled
+
+Unsupported-parameter handling should be explicit and user-configurable through `unsupported_params`.
+
+- The policy should apply to Datafast-known unsupported parameters for the resolved target.
+- The default behavior should be `warn`.
+- `quiet` should be allowed for users who intentionally want best-effort portability.
+- `fail` should be available for users who want strict validation.
+
+Some targets may work best through `completion()` and others through `responses()`. The public Datafast API should not force users to care about that distinction, but the internal adapter layer should.
+
+Requirements should also allow target-level compatibility notes such as:
+
+- chat endpoint requires a compatible chat template
+- a parameter is accepted but ignored
+- an endpoint is available but implemented as an internal translation layer
+
+## Request / Response Model
+
+Datafast should expose one request model that supports:
+
+- single request
+- concurrent batch requests
+- prompt input
+- message input
+- structured output via Pydantic
+
+The execution layer should support both:
+
+- native same-target batching for many inputs to one resolved model/deployment when available
+- fallback concurrency when native batching is unavailable
+
+If native batching is unavailable and Datafast falls back to parallel single requests, the user should be warned that a fallback execution path is being used.
+
+The message model should support both:
+
+- simple text messages
+- typed multimodal content parts
+
+Supported content parts should include a common shape for:
+
+- text
+- image
+- audio
+- video
+- file
+- document
+
+This keeps the public API compatible with multimodal-capable chat models without forcing separate provider APIs for each modality.
+
+Content parts should also be able to carry optional stable media IDs / UUIDs for targets that can reuse multimodal processing across requests.
+
+## Multimodal Requirements
+
+- Multimodal input support must be capability-aware per target.
+- A model that supports text-only should still work with the same public call shape.
+- A model that supports image, audio, video, document, or file inputs should accept typed content parts in `messages`.
+- The design should also allow non-text outputs when supported, especially image-generation-capable chat models.
+- Structured output and multimodal input should coexist when the target supports both.
+- The design should support targets that expose multimodal and reasoning features primarily through the Responses API.
+- The design should not assume all local backends support the same modalities. For example, support for image, audio, video, and document inputs may differ substantially between vLLM and `llama.cpp`.
+- The design should allow target-specific media options when needed, without polluting the common API surface.
+
+## Reliability and Execution
+
+Every LLM call should have a standard execution policy:
+
+- bounded retries
+- exponential backoff
+- jitter
+- retryable vs non-retryable error handling
+- consistent timeout handling
+- client-side RPM throttling
+
+Batch execution should:
+
+- preserve input order
+- apply the same retry and timeout rules as single requests
+- use native same-target batching when available
+- fall back to controlled concurrency when native batching is unavailable
+- warn the user when fallback concurrency is used instead of native batching
+
+## Endpoint Mode Requirements
+
+The design should explicitly allow multiple endpoint modes behind one public API.
+
+- Some targets should be called through chat completions.
+- Some targets should be called through the Responses API.
+- Endpoint choice should be resolved per target capability, not hardcoded per provider class.
+- Responses API support matters for targets that expose reasoning, multimodal I/O, image generation, or session continuity through that endpoint.
+- When the Responses API is used, the design should allow carrying forward response-session state such as `previous_response_id` when needed.
+- The requirements should not assume that every Responses API implementation is native. A local backend may expose `/v1/responses` by translating it into another internal request shape.
+
+## Caching Requirements
+
+Caching should be part of the design, but not assumed to behave the same across targets.
+
+The requirements should distinguish:
+
+- provider-native prompt caching
+- gateway or routing-layer caching
+- local server prefix / KV caching
+- optional client-side result caching
+
+Key requirements:
+
+- caching must be explicit and correctness-preserving
+- cache behavior must be capability-aware per target
+- cache keys or cache hints must account for model, endpoint, relevant generation params, and multimodal inputs
+- provider-specific caching controls should be supported through the mapping layer or escape hatch
+- the public API should not promise identical cache semantics across OpenAI, Anthropic, Mistral, OpenRouter, Ollama, vLLM, and `llama.cpp`
+
+The requirements should also distinguish between:
+
+- provider-side prompt caching semantics
+- prefix / KV-cache reuse for repeated prompt prefixes
+- multimodal preprocessing cache reuse keyed by stable media identity
+
+In particular, local backends may expose caching mainly as performance-oriented KV reuse rather than provider-managed prompt caching. That should be modeled explicitly.
+
+## What To Keep From The Current Design
+
+The current `llms.py` points to a few good design directions that should remain in the requirements:
+
+- one stable API for single and batch calls
+- first-class structured output
+- proactive client-side rate limiting
+- standard retry behavior
+- graceful fallback when a target lacks native batching
+- support for local backends without requiring an API key
+- tracing / metadata hooks on every request
+
+## Recommended Direction
+
+The optimal design is:
+
+- provider-specific public factories as thin entry points
+- one common Datafast request/config model
+- one target capability layer
+- one shared execution layer for retries, throttling, batching, caching, and parsing
+- thin internal provider adapters that only map Datafast requests into target-specific LiteLLM calls
+
+The capability layer should be able to describe at least:
+
+- supported endpoint modes
+- supported modalities
+- structured-output mechanism
+- cache mechanism type
+- chat-template or prompt-format requirements
+- parameter caveats such as unsupported, ignored, translated, or model-dependent
+
+This keeps the user-facing API simple while allowing model-specific behavior where it actually belongs.
+
+## References
+
+- LiteLLM provider-specific params:
+- LiteLLM drop unsupported params:
+- LiteLLM retries / fallbacks:
+- LiteLLM batching:
+- LiteLLM Responses API:
+- LiteLLM structured output / JSON mode:
+- LiteLLM reasoning content:
+- LiteLLM vision:
+- LiteLLM audio:
+- LiteLLM document understanding:
+- LiteLLM image generation in chat:
+- vLLM online serving:
+- vLLM structured outputs:
+- vLLM automatic prefix caching:
+- vLLM multimodal inputs:
+- llama.cpp server:
+- llama.cpp multimodal:
diff --git a/llm_provider_test_plan.md b/llm_provider_test_plan.md
new file mode 100644
index 0000000..71bad93
--- /dev/null
+++ b/llm_provider_test_plan.md
@@ -0,0 +1,285 @@
+# LLM Provider Test Plan (Draft)
+
+## Goal
+
+Test the provider redesign without exploding the matrix.
+
+Main idea:
+
+- Test shared behavior once at the common layer.
+- Test only provider/model deltas at the capability layer.
+- Run a meaningful live suite against selected real models.
+- Keep the live suite maintainable through a small curated model catalog.
+- Defer multimodal live coverage until after the first stable text-first provider test suite is in place.
+- Defer caching coverage until after the first stable text-first provider test suite is in place.
+
+## Test Layers
+
+| Layer | Purpose | Typical tools |
+|---|---|---|
+| Unit / contract | Validate request normalization, capability resolution, retry logic, batching decisions, parsing, caching decisions | mocked LiteLLM / fake adapters |
+| Adapter tests | Verify mapping from Datafast request to LiteLLM request per endpoint mode | mocked `completion()`, `batch_completion()`, `responses()` |
+| Live acceptance | Verify selected real models are safe for Datafast users | live API / local server |
+
+## Marker Strategy
+
+Recommended markers:
+
+- `live`: any test hitting a real provider endpoint
+- `multimodal`: reserved for later image / audio / document / video coverage
+- `ollama`: real Ollama backend
+- `vllm`: real vLLM backend
+- `llamacpp`: real `llama.cpp` backend
+
+Suggested usage:
+
+- default CI: mocked tests only
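+- provider CI / pre-release: `--run-live -m live` (tests marked `live` or `integration` are skipped unless `--run-live` is passed)
+- targeted local runs: `--run-live -m "live and ollama"` / `--run-live -m "live and vllm"` / `--run-live -m "live and llamacpp"`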
+
+## Matrix Reduction Strategy
+
+- Do not test every feature against every provider.
+- Run a compact acceptance suite against a curated list of selected models.
+- Choose one representative provider/model endpoint per endpoint mode for mocked tests.
+- Choose one representative provider/model endpoint per modality for deeper live tests.
+- For each provider, test only what is different from the shared contract.
+- Keep local-backend tests separate from hosted-provider smoke tests.
+
+## Selected Model Catalog
+
+Maintain one curated list of current supported / recommended test targets per provider.
+
+This catalog should not aim to include every available model. It should be a curated test surface for capability coverage and user confidence, not a registry of all provider inventory.
+
+Each catalog entry should record at least:
+
+- provider
+- model_id
+- endpoint mode
+- hosted vs local
+- expected modalities
+- expected structured-output support
+- expected reasoning / thinking support
+- expected batching behavior
+- expected cache mechanism type
+- test markers to apply, such as `live`, `multimodal`, `ollama`, `vllm`, `llamacpp`
+
+Design goal:
+
+- adding a new model should usually mean adding one catalog entry
+- most live tests should parametrize over that catalog
+- provider/model-specific regressions should be captured as capability expectations in the catalog
+
+Models are good candidates for the catalog when they are:
+
+- recommended to Datafast users
+- used in docs or examples
+- representative of a distinct capability shape
+- newly added and worth validating before being treated as supported
+- known to be tricky or historically unstable
+
+Models are usually not good candidates when they:
+
+- do not add meaningful new capability coverage
+- are deprecated or not intended for ongoing support
+- are only one of many near-identical variants from the same provider
+
+### Current Catalog Decisions
+
+Current agreed shortlist as of June 2026:
+
+- OpenAI: `gpt-5.5`, `gpt-5.4`, `gpt-5.4-mini`, `gpt-5.4-nano`
+- Anthropic: `claude-sonnet-4-6`, `claude-haiku-4-5`
+- Gemini: `gemini-2.5-pro`, `gemini-3.5-flash`, `gemini-3.1-flash-lite`
+- Mistral hosted: `mistral-medium-3-5`, `mistral-large-2512`, `mistral-small-2603`
+- Mistral local / self-hosted: `ministral-14b-2512`, `ministral-8b-2512`, `ministral-3b-2512`
+
+Current exclusions / constraints:
+
+- Exclude Anthropic `claude-fable-5` and `claude-opus-4-8` due to cost.
+- Exclude Gemini `gemini-2.5-flash`.
+- Keep the catalog curated for capability coverage, not exhaustive by provider inventory.
+- Keep hosted Mistral and local Mistral entries separate in the catalog.
+- Treat local-server capability expectations as backend-specific, especially for `vLLM`, `llama.cpp`, and other OpenAI-compatible servers.
+- If a compact local Mistral subset is needed later, start with `ministral-8b-2512` and `ministral-3b-2512`.
+
+## Live Acceptance Suite
+
+These should run against the curated selected-model catalog.
+
+| ID | Test |
+|---|---|
+| L01 | Basic text generation works for every selected live model |
+| L02 | Structured output works for every selected live model that claims support |
+| L03 | Batch request works for every selected live model using the expected execution path, and emits a warning if fallback batching is used |
+| L04 | Common params such as `timeout` and `temperature` are accepted or handled according to capability expectations |
+| L05 | Declared unsupported params follow `unsupported_params` policy as expected for that model |
+| L06 | Endpoint mode matches expectation: chat completions vs Responses API |
+| L07 | Provider-specific factory entry point works for that model |
+| L08 | Metadata / tracing path does not break live requests |
+
+For local backends, include:
+
+| ID | Test |
+|---|---|
+| L09 | `api_base_url` path works |
+| L10 | no-API-key path works where expected |
+
+## Core Contract Tests
+
+These should run with mocks only.
+
+| ID | Test |
+|---|---|
+| C01 | Factory functions such as `openai(...)`, `openrouter(...)`, `ollama(...)` create the expected internal target/config shape |
+| C02 | Single prompt returns a single result |
+| C03 | Batch prompts return ordered list results |
+| C04 | `messages` input works for single request |
+| C05 | Batched `messages` input works and preserves order |
+| C06 | Reject a request where both `prompt` and `messages` are missing (`None`) |
+| C07 | Reject providing both `prompt` and `messages` |
+| C08 | Structured output with Pydantic parses successfully |
+| C09 | Structured output surfaces a clear validation error on invalid JSON / schema mismatch |
+| C10 | Text responses are normalized consistently |
+| C11 | Metadata / tracing payload is attached to requests |
+
+## Capability Layer Tests
+
+These should validate the resolved target rules.
+
+| ID | Test |
+|---|---|
+| K01 | Supported params are forwarded for a target that allows them |
+| K02 | Unsupported params are omitted by default when capability is unknown |
+| K03 | `unsupported_params="warn"` omits unsupported params and emits a warning |
+| K04 | `unsupported_params="fail"` raises a clear error before request dispatch |
+| K05 | `unsupported_params="quiet"` omits unsupported params without warning |
+| K06 | Provider-specific aliases map correctly to the internal common config |
+| K07 | `thinking=False` suppresses `reasoning_effort` |
+| K08 | `thinking=True` with no explicit `reasoning_effort` uses target default |
+| K09 | Endpoint mode resolves correctly: chat completions vs Responses API |
+| K10 | Capability notes such as "accepted but ignored" or "translated internally" are represented correctly |
+| K11 | OpenAI-compatible target is not assumed to support all OpenAI features |
+| K12 | Local target requiring a chat template is flagged correctly |
+
+## Adapter Tests
+
+These verify the LiteLLM call shape.
+
+| ID | Test |
+|---|---|
+| A01 | Chat-completions target calls `litellm.completion()` for single input |
+| A02 | Native same-target batch calls `litellm.batch_completion()` when supported |
+| A03 | If native batching is unavailable, batch input is executed via bounded parallel single requests, preserves ordered batch outputs, and emits a user warning |
+| A04 | Responses target calls `litellm.responses()` |
+| A05 | Responses target forwards `previous_response_id` when present |
+| A06 | Structured output maps to the correct LiteLLM field per endpoint mode |
+| A07 | Provider-specific extra params pass only through the escape hatch |
+| A08 | `api_base_url` and optional `api_key` are passed correctly for local / self-hosted targets |
+
+## Reliability Tests
+
+| ID | Test |
+|---|---|
+| R01 | Retryable error triggers bounded retries |
+| R02 | Non-retryable error fails immediately |
+| R03 | Backoff grows across retries |
+| R04 | Jitter is applied within the expected range |
+| R05 | Timeout is forwarded and timeout failure is surfaced clearly |
+| R06 | Client-side `rpm_limit` throttles before provider error |
+| R07 | Batch retry behavior preserves output ordering |
+
+## Multimodal Tests
+
+Multimodal coverage should come later.
+
+For the first rollout:
+
+- keep multimodal tests out of the required live acceptance suite
+- allow a small number of mocked multimodal contract tests if useful
+- add real multimodal coverage only after the text-first live suite is stable
+
+| ID | Test |
+|---|---|
+| M01 | Text-only message content remains supported |
+| M02 | Image content part is accepted for a target with image input support |
+| M03 | Audio content part is accepted for a target with audio input support |
+| M04 | Video content part is accepted for a target with video input support |
+| M05 | File / document content part is accepted for a target with document support |
+| M06 | Unsupported modality is rejected clearly for a text-only target |
+| M07 | Mixed text + image multimodal message preserves part order |
+| M08 | Stable media ID / UUID is forwarded when provided |
+| M09 | Non-text output path is selected correctly for image-generation-capable chat target |
+
+## Caching Tests
+
+Caching coverage should come later.
+
+For the first rollout:
+
+- keep caching tests out of the required live acceptance suite
+- allow mocked cache-resolution tests if useful
+- add real cache-behavior coverage only after the text-first live suite is stable
+
+| ID | Test |
+|---|---|
+| H01 | Cache mode resolves correctly for provider-native prompt caching |
+| H02 | Cache mode resolves correctly for local prefix / KV caching |
+| H03 | Cache key / cache hint changes when model changes |
+| H04 | Cache key / cache hint changes when relevant generation params change |
+| H05 | Cache key / cache hint changes when multimodal input identity changes |
+| H06 | Stable media identity enables multimodal reuse hint when supported |
+| H07 | Public API does not claim cache hit semantics that the target cannot guarantee |
+
+## Provider / Model Delta Live Tests
+
+Add only when a selected model has behavior that differs meaningfully from the common suite.
+
+| ID | Example |
+|---|---|
+| D01 | Responses-only reasoning model |
+| D02 | OpenRouter model with provider-specific capability caveat |
+| D03 | vLLM deployment with structured-output expectations |
+| D04 | `llama.cpp` target with chat-template requirement |
+| D05 | Model whose handling of unsupported params differs from the shared `unsupported_params` expectations |
+| D06 | multimodal model with image input support |
+| D07 | cache-relevant local backend behavior |
+
+## Extended Live Scenarios
+
+These are later-phase tests, not required for the initial rollout.
+
+| ID | Target | Test |
+|---|---|---|
+| E01 | Multimodal hosted model | text + image input |
+| E02 | Audio or document-capable model | real multimodal request |
+| E03 | Structured-output target | real Pydantic schema validation |
+| E04 | Provider with prompt caching | repeated request with cache-relevant setup |
+| E05 | vLLM | prefix-cache-friendly repeated prompt |
+| E06 | local multimodal target | document or image input if supported |
+| E07 | Responses target | `previous_response_id` continuation |
+| E08 | selected-model sweep | run the full acceptance suite across the full catalog |
+
+## New Model Onboarding
+
+When a new model comes out:
+
+1. Add it to the selected-model catalog with expected capabilities.
+2. Run the shared live acceptance suite against it.
+3. Add a provider/model delta test only if it differs from the standard expectations.
+4. Add an extended live scenario only if it adds meaningful new capability coverage.
+
+## Suggested Priorities
+
+- Phase 1: `C*`, `K*`, `A*`, `R*`
+- Phase 2: selected-model `L*` live suite
+- Phase 3: `M*`, `H*`, `D*`
+- Phase 4: `E*`
+
+## Success Criteria
+
+- Shared behavior is covered mostly by fast mocked tests.
+- The curated live suite gives confidence against real provider endpoints.
+- Provider/model-specific logic is tested as deltas, not full re-runs of the whole matrix.
+- Adding a new model is mostly a catalog update plus, if needed, one delta test.
diff --git a/mkdocs.yml b/mkdocs.yml
index 87e795a..131400c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -43,6 +43,11 @@ nav:
- LLM Steps: guides/llm_steps.md
- Checkpointing: guides/checkpointing.md
- Langfuse Tracing: guides/langfuse_tracing.md
+ - Cookbook:
+ - cookbook/index.md
+ - Text Classification: cookbook/text_classification.md
+ - Persona Generation: cookbook/persona_generation.md
+ - Space Engineering Text Generation: cookbook/space_text_generation.md
- Providers: llms.md
- Models: models.md
- API: api.md
diff --git a/pytest.ini b/pytest.ini
index 798f789..042626f 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,6 +1,11 @@
[pytest]
markers =
integration: marks tests that require API connectivity (deselect with '-m "not integration"')
+ live: marks tests that hit a real provider endpoint (skipped unless --run-live is passed)
+ multimodal: marks tests that exercise multimodal provider behavior
+ ollama: marks tests that require a real Ollama backend
+ vllm: marks tests that require a real vLLM backend
+ llamacpp: marks tests that require a real llama.cpp backend
slow: marks tests that are slow to run
# Other pytest configurations
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..961d50b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,20 @@
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--run-live",
+ action="store_true",
+ default=False,
+ help="run tests marked live or integration",
+ )
+
+
+def pytest_collection_modifyitems(config, items):
+ if config.getoption("--run-live"):
+ return
+
+ skip_live = pytest.mark.skip(reason="requires --run-live")
+ for item in items:
+ if "live" in item.keywords or "integration" in item.keywords:
+ item.add_marker(skip_live)
diff --git a/tests/test_add_uuid.py b/tests/test_add_uuid.py
new file mode 100644
index 0000000..e89f837
--- /dev/null
+++ b/tests/test_add_uuid.py
@@ -0,0 +1,78 @@
+import uuid
+
+from datafast import AddUUID, LLMStep, Sink, Source
+
+
+def assert_valid_uuid(value: str) -> None:
+ parsed = uuid.UUID(value)
+ assert str(parsed) == value
+
+
+def test_add_uuid_adds_id_when_missing():
+ records = list(AddUUID().process([{"text": "hello"}]))
+
+ assert records[0]["text"] == "hello"
+ assert_valid_uuid(records[0]["id"])
+
+
+def test_add_uuid_preserves_existing_id_by_default():
+ records = list(AddUUID().process([{"id": "source-1", "text": "hello"}]))
+
+ assert records == [{"id": "source-1", "text": "hello"}]
+
+
+def test_add_uuid_overwrites_existing_id_when_requested():
+ records = list(
+ AddUUID(overwrite=True).process([{"id": "source-1", "text": "hello"}])
+ )
+
+ assert records[0]["text"] == "hello"
+ assert records[0]["id"] != "source-1"
+ assert_valid_uuid(records[0]["id"])
+
+
+def test_add_uuid_generates_distinct_ids_for_multiple_records():
+ records = list(AddUUID().process([{"text": "a"}, {"text": "b"}]))
+ ids = [record["id"] for record in records]
+
+ assert len(set(ids)) == 2
+ for value in ids:
+ assert_valid_uuid(value)
+
+
+def test_add_uuid_supports_custom_column_name():
+ records = list(AddUUID(column="example_id").process([{"text": "hello"}]))
+
+ assert "id" not in records[0]
+ assert_valid_uuid(records[0]["example_id"])
+
+
+def test_add_uuid_assigns_unique_ids_to_llm_num_outputs_pipeline():
+ class FakeModel:
+ model_id = "fake-model"
+ provider_name = "fake"
+
+ def generate(self, messages, metadata=None):
+ return '{"title": "Generated", "text": "Body"}'
+
+ pipeline = (
+ Source.list([{"topic": "vacuum"}])
+ >> LLMStep(
+ prompt="Write about {topic}.",
+ input_columns=["topic"],
+ output_columns=["title", "text"],
+ parse_mode="json",
+ model=FakeModel(),
+ num_outputs=2,
+ )
+ >> AddUUID()
+ >> Sink.list()
+ )
+
+ records = pipeline.run()
+ ids = [record["id"] for record in records]
+
+ assert len(records) == 2
+ assert len(set(ids)) == 2
+ for value in ids:
+ assert_valid_uuid(value)
diff --git a/tests/test_llm_provider_contract.py b/tests/test_llm_provider_contract.py
new file mode 100644
index 0000000..8928e5e
--- /dev/null
+++ b/tests/test_llm_provider_contract.py
@@ -0,0 +1,503 @@
+import pytest
+from pydantic import BaseModel
+
+import datafast.llm.provider as provider_module
+from datafast import LLMStep, ListSink, Source
+from datafast.llm import (
+ ContentPart,
+ EndpointMode,
+ Modality,
+ OpenAIProvider,
+ OpenRouterProvider,
+ openai,
+ openai_compatible,
+)
+
+
+class SimpleSchema(BaseModel):
+ answer: str
+
+
+class _DummyMessage:
+ def __init__(
+ self,
+ content,
+ reasoning_content=None,
+ thinking_blocks=None,
+ images=None,
+ audio=None,
+ ):
+ self.content = content
+ self.reasoning_content = reasoning_content
+ self.thinking_blocks = thinking_blocks
+ self.images = images
+ self.audio = audio
+
+
+class _DummyChoice:
+ def __init__(
+ self,
+ content,
+ reasoning_content=None,
+ thinking_blocks=None,
+ images=None,
+ audio=None,
+ ):
+ self.message = _DummyMessage(
+ content,
+ reasoning_content,
+ thinking_blocks,
+ images,
+ audio,
+ )
+
+
+class _DummyChatResponse:
+ def __init__(
+ self,
+ content,
+ reasoning_content=None,
+ thinking_blocks=None,
+ images=None,
+ audio=None,
+ ):
+ self.choices = [
+ _DummyChoice(
+ content,
+ reasoning_content,
+ thinking_blocks,
+ images,
+ audio,
+ )
+ ]
+
+
+class _DummyResponsesResponse:
+ def __init__(self, output_text=None, output=None, reasoning_content=None):
+ self.output_text = output_text
+ self.output = output
+ self.reasoning_content = reasoning_content
+
+
+@pytest.fixture(autouse=True)
+def _disable_provider_side_effects(monkeypatch):
+ monkeypatch.setattr(provider_module, "load_env_once", lambda: None)
+ monkeypatch.setattr(
+ provider_module,
+ "maybe_configure_langfuse_tracing",
+ lambda load_env=False: False,
+ )
+
+
+def test_factories_resolve_expected_targets():
+ hosted = openai(api_key="test-key")
+ local = openai_compatible(
+ "ministral-8b-2512",
+ api_base_url="http://localhost:8000/v1",
+ )
+
+ assert hosted.provider_name == "openai"
+ assert hosted.endpoint_mode == EndpointMode.RESPONSES
+ assert hosted._get_model_string() == "openai/gpt-5.5"
+
+ assert local.provider_name == "openai_compatible"
+ assert local.endpoint_mode == EndpointMode.CHAT
+ assert local.api_base_url == "http://localhost:8000/v1"
+
+
+def test_openai_compatible_backend_profiles_are_distinct():
+ generic = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ )
+ vllm = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ backend="vllm",
+ )
+ llamacpp = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8080/v1",
+ backend="llamacpp",
+ )
+
+ assert generic.provider_name == "openai_compatible"
+ assert generic.capabilities.modalities == frozenset({Modality.TEXT})
+
+ assert vllm.provider_name == "vllm"
+ assert vllm.capabilities.supports_endpoint(EndpointMode.RESPONSES)
+ assert Modality.IMAGE in vllm.capabilities.modalities
+ assert Modality.VIDEO in vllm.capabilities.modalities
+
+ assert llamacpp.provider_name == "llamacpp"
+ assert Modality.AUDIO in llamacpp.capabilities.modalities
+ assert Modality.FILE in llamacpp.capabilities.modalities
+
+
+def test_input_validation_rejects_missing_or_ambiguous_inputs():
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ with pytest.raises(ValueError, match="Either prompt or messages"):
+ provider.generate()
+
+ with pytest.raises(ValueError, match="either prompt or messages"):
+ provider.generate(prompt="hello", messages=[{"role": "user", "content": "hi"}])
+
+
+def test_unsupported_params_warn_and_omit(monkeypatch):
+ captured = {}
+
+ def fake_completion(**kwargs):
+ captured.update(kwargs)
+ return _DummyChatResponse("ok")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ temperature=0.7,
+ )
+
+ with pytest.warns(UserWarning, match="temperature"):
+ assert provider.generate(prompt="ping") == "ok"
+
+ assert "temperature" not in captured
+ assert captured["api_base"] == "http://localhost:8000/v1"
+
+
+def test_unsupported_params_fail_before_dispatch(monkeypatch):
+ def fake_completion(**kwargs):
+ raise AssertionError("request should not be dispatched")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ temperature=0.7,
+ unsupported_params="fail",
+ )
+
+ with pytest.raises(ValueError, match="temperature"):
+ provider.generate(prompt="ping")
+
+
+def test_chat_endpoint_warns_and_omits_previous_response_id(monkeypatch):
+ captured = {}
+
+ def fake_completion(**kwargs):
+ captured.update(kwargs)
+ return _DummyChatResponse("ok")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ with pytest.warns(UserWarning, match="previous_response_id"):
+ assert provider.generate(prompt="ping", previous_response_id="resp_old") == "ok"
+
+ assert "previous_response_id" not in captured
+
+
+def test_openrouter_thinking_warns_and_omits_reasoning_param(monkeypatch):
+ captured = {}
+
+ def fake_completion(**kwargs):
+ captured.update(kwargs)
+ return _DummyChatResponse("ok")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = OpenRouterProvider(
+ model_id="nvidia/nemotron-3-super-120b-a12b:nitro",
+ api_key="test-key",
+ thinking=True,
+ )
+
+ with pytest.warns(UserWarning, match="reasoning_effort"):
+ assert provider.generate(prompt="ping") == "ok"
+
+ assert "reasoning_effort" not in captured
+ assert "reasoning" not in captured
+
+
+def test_provider_params_escape_hatch_is_forwarded(monkeypatch):
+ captured = {}
+
+ def fake_completion(**kwargs):
+ captured.update(kwargs)
+ return _DummyChatResponse("ok")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ provider_params={"extra_body": {"backend_hint": "vllm"}},
+ )
+
+ assert provider.generate(prompt="ping") == "ok"
+ assert captured["extra_body"] == {"backend_hint": "vllm"}
+
+
+def test_content_parts_normalize_multimodal_and_document_shapes():
+ vllm = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ backend="vllm",
+ )
+ prepared = vllm._prepare_messages(
+ [
+ {
+ "role": "user",
+ "content": [
+ ContentPart(type="text", text="What is in this image?"),
+ ContentPart(
+ type="image",
+ url="https://example.com/image.png",
+ media_id="img-123",
+ ),
+ ContentPart(
+ type="video",
+ url="https://example.com/video.mp4",
+ media_id="vid-123",
+ ),
+ ],
+ }
+ ],
+ response_format=None,
+ )
+
+ assert prepared[0]["content"] == [
+ {"type": "text", "text": "What is in this image?"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "https://example.com/image.png"},
+ "uuid": "img-123",
+ },
+ {
+ "type": "video_url",
+ "video_url": {"url": "https://example.com/video.mp4"},
+ "uuid": "vid-123",
+ },
+ ]
+
+ llamacpp = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8080/v1",
+ backend="llamacpp",
+ )
+ prepared = llamacpp._prepare_messages(
+ [
+ {
+ "role": "user",
+ "content": [
+ ContentPart(
+ type="document",
+ data="data:application/pdf;base64,abc",
+ media_type="application/pdf",
+ ),
+ ],
+ }
+ ],
+ response_format=None,
+ )
+
+ assert prepared[0]["content"] == [
+ {
+ "type": "file",
+ "file": {"file_data": "data:application/pdf;base64,abc"},
+ }
+ ]
+
+
+def test_litellm_unsupported_params_can_retry_with_drop_params(monkeypatch):
+ unsupported_error = type("UnsupportedParamsError", (Exception,), {})
+ calls = []
+
+ def fake_completion(**kwargs):
+ calls.append(kwargs)
+ if len(calls) == 1:
+ raise unsupported_error("bad param")
+ return _DummyChatResponse("ok")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ with pytest.warns(UserWarning, match="drop_params=True"):
+ assert provider.generate(prompt="ping") == "ok"
+
+ assert calls[0].get("drop_params") is None
+ assert calls[1]["drop_params"] is True
+
+
+def test_generate_response_preserves_litellm_reasoning_metadata(monkeypatch):
+ monkeypatch.setattr(
+ provider_module.litellm,
+ "completion",
+ lambda **kwargs: _DummyChatResponse(
+ "final answer",
+ reasoning_content="internal summary",
+ thinking_blocks=[
+ {
+ "type": "thinking",
+ "thinking": "visible thinking block",
+ "signature": "sig",
+ }
+ ],
+ images=[{"type": "image", "url": "https://example.com/out.png"}],
+ audio={"id": "audio-1", "expires_at": 123},
+ ),
+ )
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+ response = provider.generate_response(prompt="ping")
+
+ assert response.text == "final answer"
+ assert response.reasoning_content == "internal summary"
+ assert response.thinking_blocks == [
+ {
+ "type": "thinking",
+ "thinking": "visible thinking block",
+ "signature": "sig",
+ }
+ ]
+ assert response.images == [
+ {"type": "image", "url": "https://example.com/out.png"}
+ ]
+ assert response.audio == {"id": "audio-1", "expires_at": 123}
+
+
+def test_responses_full_response_preserves_output_items_and_media(monkeypatch):
+ output = [
+ {"type": "reasoning", "summary": [{"text": "short rationale"}]},
+ {"type": "image_generation_call", "result": "base64-image"},
+ {
+ "type": "message",
+ "content": [{"type": "output_text", "text": "Here is the image."}],
+ },
+ ]
+
+ monkeypatch.setattr(
+ provider_module.litellm,
+ "responses",
+ lambda **kwargs: _DummyResponsesResponse(output=output),
+ )
+
+ provider = OpenAIProvider(model_id="gpt-5.5", api_key="test-key")
+ response = provider.generate_response(prompt="make an image")
+
+ assert response.text == "Here is the image."
+ assert response.reasoning_content == "short rationale"
+ assert response.images == [
+ {"type": "image_generation_call", "result": "base64-image"}
+ ]
+ assert response.output_items == output
+
+
+def test_responses_endpoint_maps_reasoning_state_and_structured_output(monkeypatch):
+ captured = {}
+
+ def fake_responses(**kwargs):
+ captured.update(kwargs)
+ return _DummyResponsesResponse('{"answer": "Paris"}')
+
+ monkeypatch.setattr(provider_module.litellm, "responses", fake_responses)
+
+ provider = OpenAIProvider(
+ model_id="gpt-5.5",
+ api_key="test-key",
+ thinking=True,
+ max_completion_tokens=64,
+ )
+
+ result = provider.generate(
+ messages=[{"role": "user", "content": "capital?"}],
+ response_format=SimpleSchema,
+ previous_response_id="resp_previous",
+ metadata={"purpose": "test"},
+ )
+
+ assert result == SimpleSchema(answer="Paris")
+ assert captured["model"] == "openai/gpt-5.5"
+ assert captured["previous_response_id"] == "resp_previous"
+ assert captured["reasoning"] == {"effort": "low"}
+ assert captured["max_output_tokens"] == 64
+ assert captured["text_format"] is SimpleSchema
+ assert captured["metadata"]["purpose"] == "test"
+
+
+def test_fallback_batching_preserves_order(monkeypatch):
+ calls = []
+
+ def fake_completion(**kwargs):
+ calls.append(kwargs["messages"][0]["content"])
+ return _DummyChatResponse(f"reply:{kwargs['messages'][0]['content']}")
+
+ monkeypatch.setattr(provider_module.litellm, "completion", fake_completion)
+
+ provider = openai_compatible(
+ "local-model",
+ api_base_url="http://localhost:8000/v1",
+ max_concurrent=1,
+ )
+
+ with pytest.warns(UserWarning, match="Falling back"):
+ result = provider.generate(prompt=["one", "two", "three"])
+
+ assert result == ["reply:one", "reply:two", "reply:three"]
+ assert calls == ["one", "two", "three"]
+
+
+def test_structured_output_validation_error_is_clear(monkeypatch):
+ monkeypatch.setattr(
+ provider_module.litellm,
+ "completion",
+ lambda **kwargs: _DummyChatResponse("not json"),
+ )
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ with pytest.raises(ValueError, match="Failed to parse JSON response"):
+ provider.generate(prompt="answer in json", response_format=SimpleSchema)
+
+
+def test_runner_dispatches_same_model_batches_through_generate_batch():
+ class FakeBatchModel:
+ provider_name = "fake"
+ model_id = "fake-model"
+
+ def __init__(self):
+ self.batches = []
+
+ def generate_batch(self, messages, metadata=None, response_format=None):
+ self.batches.append({"messages": messages, "metadata": metadata})
+ return ["first", "second"]
+
+ model = FakeBatchModel()
+ sink = ListSink()
+ pipeline = (
+ Source.list([{"topic": "alpha"}, {"topic": "beta"}])
+ >> LLMStep(
+ prompt="Write about {topic}.",
+ input_columns=["topic"],
+ output_column="result",
+ model=model,
+ )
+ >> sink
+ )
+
+ output = pipeline.run(batch_size=2)
+
+ assert output == [
+ {"topic": "alpha", "result": "first", "_model": "fake-model"},
+ {"topic": "beta", "result": "second", "_model": "fake-model"},
+ ]
+ assert len(model.batches) == 1
+ assert [batch[0]["content"] for batch in model.batches[0]["messages"]] == [
+ "Write about alpha.",
+ "Write about beta.",
+ ]
+ assert len(model.batches[0]["metadata"]) == 2
diff --git a/tests/test_llms_unit.py b/tests/test_llms_unit.py
new file mode 100644
index 0000000..1c67a6d
--- /dev/null
+++ b/tests/test_llms_unit.py
@@ -0,0 +1,107 @@
+import datafast.llms as llms_module
+from datafast.llms import OpenRouterProvider
+
+
+class _DummyMessage:
+ def __init__(self, content: str, **extra: object) -> None:
+ self.content = content
+ for key, value in extra.items():
+ setattr(self, key, value)
+
+
+class _DummyChoice:
+ def __init__(self, content: str, **extra: object) -> None:
+ self.message = _DummyMessage(content, **extra)
+
+
+class _DummyResponse:
+ def __init__(self, content: str, **extra: object) -> None:
+ self.choices = [_DummyChoice(content, **extra)]
+
+
+def test_openrouter_single_messages_use_completion(monkeypatch):
+ monkeypatch.setattr(llms_module, "load_env_once", lambda: None)
+ monkeypatch.setattr(
+ llms_module,
+ "maybe_configure_langfuse_tracing",
+ lambda load_env=False: False,
+ )
+
+ calls = {"completion": 0, "batch_completion": 0}
+
+ def fake_completion(**kwargs):
+ calls["completion"] += 1
+ assert kwargs["messages"] == [{"role": "user", "content": "ping"}]
+ return _DummyResponse("ok")
+
+ def fake_batch_completion(**kwargs):
+ calls["batch_completion"] += 1
+ raise AssertionError("single-message requests should not use batch_completion")
+
+ monkeypatch.setattr(llms_module.litellm, "completion", fake_completion)
+ monkeypatch.setattr(llms_module.litellm, "batch_completion", fake_batch_completion)
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ response = provider.generate(messages=[{"role": "user", "content": "ping"}])
+
+ assert response == "ok"
+ assert calls == {"completion": 1, "batch_completion": 0}
+
+
+def test_openrouter_batch_messages_use_batch_completion(monkeypatch):
+ monkeypatch.setattr(llms_module, "load_env_once", lambda: None)
+ monkeypatch.setattr(
+ llms_module,
+ "maybe_configure_langfuse_tracing",
+ lambda load_env=False: False,
+ )
+
+ calls = {"completion": 0, "batch_completion": 0}
+
+ def fake_completion(**kwargs):
+ calls["completion"] += 1
+ raise AssertionError("batched requests should not use completion")
+
+ def fake_batch_completion(**kwargs):
+ calls["batch_completion"] += 1
+ assert len(kwargs["messages"]) == 2
+ return [_DummyResponse("first"), _DummyResponse("second")]
+
+ monkeypatch.setattr(llms_module.litellm, "completion", fake_completion)
+ monkeypatch.setattr(llms_module.litellm, "batch_completion", fake_batch_completion)
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ response = provider.generate(messages=[
+ [{"role": "user", "content": "one"}],
+ [{"role": "user", "content": "two"}],
+ ])
+
+ assert response == ["first", "second"]
+ assert calls == {"completion": 0, "batch_completion": 1}
+
+
+def test_openrouter_generate_response_reads_reasoning_field(monkeypatch):
+ monkeypatch.setattr(llms_module, "load_env_once", lambda: None)
+ monkeypatch.setattr(
+ llms_module,
+ "maybe_configure_langfuse_tracing",
+ lambda load_env=False: False,
+ )
+
+ monkeypatch.setattr(
+ llms_module.litellm,
+ "completion",
+ lambda **kwargs: _DummyResponse(
+ "final answer",
+ reasoning="hidden chain of thought summary",
+ ),
+ )
+
+ provider = OpenRouterProvider(model_id="demo-model", api_key="test-key")
+
+ response = provider.generate_response(prompt="solve this")
+
+ assert response.text == "final answer"
+ assert response.reasoning_content == "hidden chain of thought summary"
diff --git a/tests/test_public_api.py b/tests/test_public_api.py
index 7eaf787..ac56477 100644
--- a/tests/test_public_api.py
+++ b/tests/test_public_api.py
@@ -1,4 +1,5 @@
from datafast import (
+ AddUUID,
Branch,
Classify,
Compare,
@@ -70,6 +71,7 @@ def test_factory_exports_are_available(monkeypatch):
assert Sink is not None
assert Seed is not None
assert Sample is not None
+ assert AddUUID is not None
assert Map is not None
assert FlatMap is not None
assert Filter is not None
diff --git a/tests/test_runner_llm_messages.py b/tests/test_runner_llm_messages.py
new file mode 100644
index 0000000..d870093
--- /dev/null
+++ b/tests/test_runner_llm_messages.py
@@ -0,0 +1,47 @@
+from datafast import LLMStep, ListSink, Source
+
+
+def test_runner_passes_llm_messages_by_keyword():
+ class FakeModel:
+ provider_name = "fake"
+ model_id = "fake-model"
+
+ def __init__(self) -> None:
+ self.calls: list[dict] = []
+
+ def generate(
+ self,
+ prompt=None,
+ messages=None,
+ metadata=None,
+ response_format=None,
+ ):
+ self.calls.append({
+ "prompt": prompt,
+ "messages": messages,
+ "metadata": metadata,
+ })
+ return "done"
+
+ model = FakeModel()
+ sink = ListSink()
+
+ pipeline = (
+ Source.list([{"topic": "robotics"}])
+ >> LLMStep(
+ prompt="Write one short line about {topic}.",
+ input_columns=["topic"],
+ output_column="result",
+ model=model,
+ ).as_step("generate_copy")
+ >> sink
+ )
+
+ output = pipeline.run()
+
+ assert output == [{"topic": "robotics", "result": "done", "_model": "fake-model"}]
+ assert len(model.calls) == 1
+ assert model.calls[0]["prompt"] is None
+ assert model.calls[0]["messages"] == [
+ {"role": "user", "content": "Write one short line about robotics."}
+ ]