diff --git a/.gitignore b/.gitignore index 5f96167..aa49fe5 100644 --- a/.gitignore +++ b/.gitignore @@ -103,7 +103,7 @@ ipython_config.py # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. -#uv.lock +uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. @@ -187,4 +187,4 @@ examples/checkpoints/ examples/outputs/ .codex/ -openspec/ \ No newline at end of file +.agents/ \ No newline at end of file diff --git a/README.md b/README.md index 6521a97..90503e5 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Datafast is a python library for synthetic data generation using llms. The old dataset-class API has been removed. The canonical package is `datafast`, and the primary model is: - create records with `Source` or `Seed` -- transform them with composable steps +- transform them with composable steps such as `AddUUID`, `Map`, and `Filter` - call LLMs with `LLMStep`, `Classify`, `Score`, `Compare`, `Rewrite`, or `Extract` - persist results with `Sink` @@ -53,7 +53,7 @@ pipeline.run(batch_size=4) - `Source`: load records from Python lists, files, or Hugging Face datasets - `Seed`: generate record combinations declaratively -- `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations +- `AddUUID`, `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations - `LLMStep`: free-form generation - `Classify`, `Score`, `Compare`, `Rewrite`, `Extract`: higher-level LLM transforms - `Branch` and `JoinBranches`: multi-path pipelines @@ -105,6 +105,7 @@ configure_langfuse_tracing() - `datafast/`: canonical source package - `examples/scripts/`: runnable pipeline examples +- `examples/providers/`: direct provider usage examples - `docs/`: pipeline-first documentation - `datafast_new_design_document.md`: retained 
design reference diff --git a/datafast/__init__.py b/datafast/__init__.py index 19bd452..5ffcae8 100644 --- a/datafast/__init__.py +++ b/datafast/__init__.py @@ -15,12 +15,14 @@ MistralProvider, OpenRouterProvider, OllamaProvider, + OpenAICompatibleProvider, openai, anthropic, gemini, mistral, openrouter, ollama, + openai_compatible, ) from datafast.logger_config import configure_logger from datafast.sinks.sink import Sink, JSONLSink, CSVSink, ListSink, ParquetSink, HubSink @@ -31,7 +33,7 @@ is_langfuse_tracing_enabled, ) from datafast.transforms.branch import Branch, JoinBranches -from datafast.transforms.data_ops import Map, FlatMap, Filter, Group, Pair, Concat, Join +from datafast.transforms.data_ops import AddUUID, Map, FlatMap, Filter, Group, Pair, Concat, Join from datafast.transforms.llm_eval import Classify, Score, Compare from datafast.transforms.llm_extract import Extract from datafast.transforms.llm_step import LLMStep @@ -64,6 +66,7 @@ def get_version() -> str: "Seed", "SeedDimension", "Sample", + "AddUUID", "Map", "FlatMap", "Filter", @@ -92,12 +95,14 @@ def get_version() -> str: "MistralProvider", "OpenRouterProvider", "OllamaProvider", + "OpenAICompatibleProvider", "openai", "anthropic", "gemini", "mistral", "openrouter", "ollama", + "openai_compatible", "configure_logger", "configure_langfuse_tracing", "get_version", diff --git a/datafast/core/runner.py b/datafast/core/runner.py index 0a28ba2..c7d280f 100644 --- a/datafast/core/runner.py +++ b/datafast/core/runner.py @@ -3,6 +3,7 @@ import time import uuid from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING from loguru import logger @@ -19,7 +20,6 @@ if TYPE_CHECKING: from datafast.core.step import Pipeline, Step - from datafast.llm.provider import LLMProvider from datafast.transforms.llm_step import LLMStep @@ -29,6 +29,12 @@ def chunked(iterable: list, size: int): yield iterable[i : i + size] +@dataclass +class _LLMBatchStats: + generated: int = 
0 + errors: int = 0 + + class Runner: """ Execution engine for pipelines. @@ -218,77 +224,206 @@ def _execute_llm_step( if self._checkpoint_mgr and not skip_call_ids: self._checkpoint_mgr.clear_step_file(step_index, step_name) - completed_in_batch = 0 + completed_since_checkpoint = 0 errors = 0 generated_total = len(skip_call_ids) for batch in chunked(calls, self.config.batch_size): batch_start = time.perf_counter() - batch_generated = 0 - batch_model_id = batch[0].model_id if batch else "unknown" - - for call in batch: - model = models_map[call.model_id] - batch_model_id = call.model_id - - try: - result = model.generate( - call.messages, - metadata=build_trace_metadata( - model=model, - component="pipeline.step", - trace_name=f"datafast.{step_name}", - session_id=self._trace_session_id, - step_name=step_name, - step_type=step.__class__.__name__, - record_index=call.record_index, - prompt_index=call.prompt_index, - output_index=call.output_index, - language_code=call.language_code or None, - call_id=call.call_id, - ), - ) - output_record = step.apply_result(call, result, model) - output_records.append(output_record) - progress.completed_call_ids.append(call.call_id) - completed_in_batch += 1 - batch_generated += 1 - generated_total += 1 - - # Append record immediately to JSONL - if self._checkpoint_mgr: - self._checkpoint_mgr.append_record( - step_index, step_name, output_record - ) - - except Exception as e: - errors += 1 - logger.warning( - f"LLM call failed | Model: {call.model_id} | " - f"Call: {call.call_id} | Error: {e}" - ) + batch_model_id = ( + batch[0].model_id + if len({call.model_id for call in batch}) == 1 + else "mixed" + ) + stats = self._execute_llm_batch( + step=step, + step_name=step_name, + step_index=step_index, + batch=batch, + models_map=models_map, + progress=progress, + output_records=output_records, + ) + completed_since_checkpoint += stats.generated + errors += stats.errors + generated_total += stats.generated batch_duration = 
time.perf_counter() - batch_start logger.info( - f"Generated {batch_generated} samples (total: {generated_total}) | " + f"Generated {stats.generated} samples (total: {generated_total}) | " f"model: {batch_model_id} | duration: {batch_duration:.2f}s" ) if ( self._checkpoint_mgr and manifest - and completed_in_batch >= self.config.checkpoint_every + and completed_since_checkpoint >= self.config.checkpoint_every ): self._checkpoint_mgr.save_llm_progress( step_index, step_name, progress, output_records ) - completed_in_batch = 0 + completed_since_checkpoint = 0 logger.info( f"LLMStep complete: {len(output_records)} outputs, {errors} errors" ) return output_records + def _execute_llm_batch( + self, + *, + step: "LLMStep", + step_name: str, + step_index: int, + batch: list[LLMCall], + models_map: dict[str, object], + progress: LLMStepProgress, + output_records: list[Record], + ) -> _LLMBatchStats: + """Execute and apply one runner batch, preserving input order.""" + batch_results, errors = self._collect_llm_batch_results( + step=step, + step_name=step_name, + batch=batch, + models_map=models_map, + ) + stats = self._apply_llm_batch_results( + step=step, + step_name=step_name, + step_index=step_index, + batch=batch, + batch_results=batch_results, + models_map=models_map, + progress=progress, + output_records=output_records, + ) + stats.errors += errors + return stats + + def _collect_llm_batch_results( + self, + *, + step: "LLMStep", + step_name: str, + batch: list[LLMCall], + models_map: dict[str, object], + ) -> tuple[list[object | None], int]: + batch_results: list[object | None] = [None] * len(batch) + errors = 0 + grouped_indexes: dict[str, list[int]] = defaultdict(list) + + for index, call in enumerate(batch): + grouped_indexes[call.model_id].append(index) + + for model_id, indexes in grouped_indexes.items(): + group_calls = [batch[index] for index in indexes] + model = models_map[model_id] + try: + group_results = self._generate_llm_group( + step=step, + 
step_name=step_name, + model=model, + group_calls=group_calls, + ) + except Exception as e: + errors += len(group_calls) + self._log_llm_failures(group_calls, e) + continue + + for result_index, result in zip(indexes, group_results): + batch_results[result_index] = result + + return batch_results, errors + + def _generate_llm_group( + self, + *, + step: "LLMStep", + step_name: str, + model: object, + group_calls: list[LLMCall], + ) -> list[object]: + group_metadata = [ + self._build_llm_call_metadata(step, step_name, call, model) + for call in group_calls + ] + if hasattr(model, "generate_batch"): + return list( + model.generate_batch( # type: ignore[attr-defined] + [call.messages for call in group_calls], + metadata=group_metadata, + ) + ) + return [ + model.generate( # type: ignore[attr-defined] + messages=call.messages, + metadata=metadata, + ) + for call, metadata in zip(group_calls, group_metadata) + ] + + def _apply_llm_batch_results( + self, + *, + step: "LLMStep", + step_name: str, + step_index: int, + batch: list[LLMCall], + batch_results: list[object | None], + models_map: dict[str, object], + progress: LLMStepProgress, + output_records: list[Record], + ) -> _LLMBatchStats: + stats = _LLMBatchStats() + + for call, result in zip(batch, batch_results): + if result is None: + continue + try: + output_record = step.apply_result(call, result, models_map[call.model_id]) + output_records.append(output_record) + progress.completed_call_ids.append(call.call_id) + stats.generated += 1 + + if self._checkpoint_mgr: + self._checkpoint_mgr.append_record( + step_index, step_name, output_record + ) + except Exception as e: + stats.errors += 1 + self._log_llm_failures([call], e) + + return stats + + def _build_llm_call_metadata( + self, + step: "LLMStep", + step_name: str, + call: LLMCall, + model: object, + ) -> dict[str, object]: + return build_trace_metadata( + model=model, + component="pipeline.step", + trace_name=f"datafast.{step_name}", + 
session_id=self._trace_session_id, + step_name=step_name, + step_type=step.__class__.__name__, + record_index=call.record_index, + prompt_index=call.prompt_index, + output_index=call.output_index, + language_code=call.language_code or None, + call_id=call.call_id, + ) + + @staticmethod + def _log_llm_failures(calls: list[LLMCall], error: Exception) -> None: + for call in calls: + logger.warning( + f"LLM call failed | Model: {call.model_id} | " + f"Call: {call.call_id} | Error: {error}" + ) + def _order_calls(self, calls: list[LLMCall]) -> list[LLMCall]: """Order calls according to execution strategy.""" strategy = self.config.llm_strategy diff --git a/datafast/llm/__init__.py b/datafast/llm/__init__.py index 725ece6..eba520d 100644 --- a/datafast/llm/__init__.py +++ b/datafast/llm/__init__.py @@ -8,12 +8,27 @@ MistralProvider, OpenRouterProvider, OllamaProvider, + OpenAICompatibleProvider, openai, anthropic, gemini, mistral, openrouter, ollama, + openai_compatible, +) +from datafast.llm.types import ( + BatchMode, + CacheMode, + ContentPart, + EndpointMode, + Modality, + NormalizedResponse, + RetryPolicy, + StructuredOutputMode, + TargetCapabilities, + TargetConfig, + UnsupportedParamsPolicy, ) from datafast.llm.parsing import ( OutputParser, @@ -30,12 +45,25 @@ "MistralProvider", "OpenRouterProvider", "OllamaProvider", + "OpenAICompatibleProvider", "openai", "anthropic", "gemini", "mistral", "openrouter", "ollama", + "openai_compatible", + "BatchMode", + "CacheMode", + "ContentPart", + "EndpointMode", + "Modality", + "NormalizedResponse", + "RetryPolicy", + "StructuredOutputMode", + "TargetCapabilities", + "TargetConfig", + "UnsupportedParamsPolicy", "OutputParser", "TextParser", "JSONParser", diff --git a/datafast/llm/capabilities.py b/datafast/llm/capabilities.py new file mode 100644 index 0000000..14d6b26 --- /dev/null +++ b/datafast/llm/capabilities.py @@ -0,0 +1,273 @@ +"""Capability resolution for Datafast LLM targets.""" + +from __future__ import 
annotations + +from datafast.llm.types import ( + BatchMode, + CacheMode, + EndpointMode, + Modality, + StructuredOutputMode, + TargetCapabilities, +) + + +COMMON_CHAT_PARAMS = frozenset({ + "temperature", + "max_completion_tokens", + "timeout", +}) + +SAMPLING_CHAT_PARAMS = frozenset({ + "top_p", + "frequency_penalty", +}) + +REASONING_PARAMS = frozenset({ + "thinking", + "reasoning_effort", +}) + +RESPONSES_PARAMS = frozenset({ + "temperature", + "max_completion_tokens", + "timeout", + "thinking", + "reasoning_effort", + "previous_response_id", +}) + + +HOSTED_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.LITELLM_BATCH, + cache_mode=CacheMode.PROVIDER_PROMPT, +) + + +OPENAI_RESPONSES = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}), + default_endpoint_mode=EndpointMode.RESPONSES, + supported_params=RESPONSES_PARAMS | SAMPLING_CHAT_PARAMS, + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + cache_mode=CacheMode.PROVIDER_PROMPT, + supports_reasoning=True, +) + + +OPENAI_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.LITELLM_BATCH, + cache_mode=CacheMode.PROVIDER_PROMPT, +) + + +ANTHROPIC_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | REASONING_PARAMS, + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.LITELLM_BATCH, + cache_mode=CacheMode.PROVIDER_PROMPT, + supports_reasoning=True, + 
supports_thinking=True, +) + + +OPENROUTER_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + modalities=frozenset({Modality.TEXT, Modality.IMAGE}), + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.LITELLM_BATCH, + cache_mode=CacheMode.ROUTER, + notes=( + "OpenRouter capabilities remain model and routed-provider dependent.", + "Reasoning controls are omitted by default; pass provider_params for " + "model-specific OpenRouter/LiteLLM escape hatches.", + ), +) + + +OLLAMA_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + structured_output=StructuredOutputMode.JSON_OBJECT, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + cache_mode=CacheMode.LOCAL_KV, + no_api_key=True, + notes=("Structured output uses Ollama JSON mode plus Datafast validation.",), +) + + +VLLM_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + modalities=frozenset({Modality.TEXT, Modality.IMAGE, Modality.VIDEO}), + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + cache_mode=CacheMode.LOCAL_KV, + no_api_key=True, + requires_chat_template=True, + notes=( + "vLLM exposes OpenAI-compatible chat and Responses endpoints, but " + "feature coverage remains model and server-version dependent.", + "Multimodal support depends on the served model; stable media UUIDs " + "can be passed with ContentPart.media_id.", + ), +) + + +LLAMACPP_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=COMMON_CHAT_PARAMS | SAMPLING_CHAT_PARAMS, + 
modalities=frozenset({ + Modality.TEXT, + Modality.IMAGE, + Modality.AUDIO, + Modality.VIDEO, + Modality.FILE, + }), + structured_output=StructuredOutputMode.JSON_SCHEMA, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + cache_mode=CacheMode.LOCAL_KV, + no_api_key=True, + requires_chat_template=True, + notes=( + "llama.cpp server is OpenAI-compatible for chat, with JSON schema " + "support through response_format.", + "Multimodal inputs and reasoning controls are model and build dependent; " + "use provider_params for llama.cpp-specific extra_body fields.", + ), +) + + +OPENAI_COMPATIBLE_CHAT = TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT, EndpointMode.RESPONSES}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=frozenset({"timeout"}), + structured_output=StructuredOutputMode.PROMPTED_JSON, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + cache_mode=CacheMode.LOCAL_KV, + no_api_key=True, + requires_chat_template=True, + notes=("OpenAI-compatible transport does not imply OpenAI feature support.",), +) + + +_CATALOG: dict[tuple[str, str], TargetCapabilities] = { + ("openai", "gpt-5.5"): OPENAI_RESPONSES, + ("openai", "gpt-5.4"): OPENAI_RESPONSES, + ("openai", "gpt-5.4-mini"): OPENAI_RESPONSES, + ("openai", "gpt-5.4-nano"): OPENAI_RESPONSES, + ("anthropic", "claude-sonnet-4-6"): ANTHROPIC_CHAT, + ("anthropic", "claude-haiku-4-5"): ANTHROPIC_CHAT, + ("gemini", "gemini-2.5-pro"): HOSTED_CHAT, + ("gemini", "gemini-3.5-flash"): HOSTED_CHAT, + ("gemini", "gemini-3.1-flash-lite"): HOSTED_CHAT, + ("mistral", "mistral-medium-3-5"): HOSTED_CHAT, + ("mistral", "mistral-large-2512"): HOSTED_CHAT, + ("mistral", "mistral-small-2603"): HOSTED_CHAT, + ("mistral", "ministral-14b-2512"): OPENAI_COMPATIBLE_CHAT, + ("mistral", "ministral-8b-2512"): OPENAI_COMPATIBLE_CHAT, + ("mistral", "ministral-3b-2512"): OPENAI_COMPATIBLE_CHAT, +} + +_PROVIDER_DEFAULTS: dict[str, TargetCapabilities] = { + "anthropic": ANTHROPIC_CHAT, + "gemini": HOSTED_CHAT, + 
"llamacpp": LLAMACPP_CHAT, + "mistral": HOSTED_CHAT, + "ollama": OLLAMA_CHAT, + "openrouter": OPENROUTER_CHAT, + "vllm": VLLM_CHAT, +} + +_OPENAI_COMPATIBLE_PROVIDERS = frozenset({ + "openai_compatible", +}) + + +def resolve_capabilities( + provider: str, + model_id: str, + *, + api_base_url: str | None = None, + explicit: TargetCapabilities | None = None, +) -> TargetCapabilities: + """Resolve target capabilities with conservative defaults.""" + if explicit is not None: + return explicit + + normalized_provider = provider.lower() + normalized_model = model_id.lower() + + catalog_match = _CATALOG.get((normalized_provider, normalized_model)) + if catalog_match is not None: + return catalog_match + + if normalized_provider == "openai": + return _resolve_openai_capabilities(normalized_model) + + provider_default = _PROVIDER_DEFAULTS.get(normalized_provider) + if provider_default is not None: + return provider_default + + if normalized_provider in _OPENAI_COMPATIBLE_PROVIDERS: + return OPENAI_COMPATIBLE_CHAT + + if api_base_url: + return OPENAI_COMPATIBLE_CHAT + + return _unknown_capabilities() + + +def _resolve_openai_capabilities(model_id: str) -> TargetCapabilities: + if _looks_like_openai_reasoning_model(model_id): + return OPENAI_RESPONSES + return OPENAI_CHAT + + +def _looks_like_openai_reasoning_model(model_id: str) -> bool: + return ( + model_id.startswith("gpt-5") + or model_id.startswith("o1") + or model_id.startswith("o3") + or model_id.startswith("o4") + ) + + +def _unknown_capabilities() -> TargetCapabilities: + return TargetCapabilities( + endpoint_modes=frozenset({EndpointMode.CHAT}), + default_endpoint_mode=EndpointMode.CHAT, + supported_params=frozenset({"timeout"}), + structured_output=StructuredOutputMode.PROMPTED_JSON, + batch_mode=BatchMode.FALLBACK_CONCURRENCY, + notes=("Unknown target; optional Datafast parameters are omitted by default.",), + ) + + +__all__ = [ + "ANTHROPIC_CHAT", + "HOSTED_CHAT", + "LLAMACPP_CHAT", + "OLLAMA_CHAT", + 
"OPENAI_CHAT", + "OPENAI_COMPATIBLE_CHAT", + "OPENAI_RESPONSES", + "OPENROUTER_CHAT", + "VLLM_CHAT", + "resolve_capabilities", +] diff --git a/datafast/llm/provider.py b/datafast/llm/provider.py index 4768b24..4ef0ef8 100644 --- a/datafast/llm/provider.py +++ b/datafast/llm/provider.py @@ -1,52 +1,1463 @@ -"""Provider exports for the pipeline-first datafast API.""" - -from datafast.llms import ( - LLMProvider, - OpenAIProvider, - AnthropicProvider, - GeminiProvider, - MistralProvider, - OpenRouterProvider, - OllamaProvider, +"""Capability-aware LLM providers for Datafast.""" + +from __future__ import annotations + +import copy +import os +import random +import time +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from typing import Any, TypeVar + +from loguru import logger +from pydantic import BaseModel + +import litellm +from litellm import exceptions as litellm_exceptions + +from datafast.llm.capabilities import resolve_capabilities +from datafast.llm.types import ( + BatchMode, + ContentPart, + EndpointMode, + Message, + Messages, + Modality, + NormalizedRequest, + NormalizedResponse, + RetryPolicy, + StructuredOutputMode, + TargetCapabilities, + TargetConfig, + UnsupportedParamsPolicy, +) +from datafast.tracing import ( + build_trace_metadata, + load_env_once, + maybe_configure_langfuse_tracing, ) -def openai(model_id: str = "gpt-5-mini-2025-08-07", **kwargs) -> OpenAIProvider: - """Create an OpenAI provider instance.""" +T = TypeVar("T", bound=BaseModel) + +JSON_INSTRUCTIONS = ( + "\nReturn only valid JSON. Do not include markdown fences. Use double quotes " + "for keys and string values, escape internal newlines, and avoid trailing commas." 
+) + + +class LLMProvider: + """One Datafast provider target resolved to LiteLLM request adapters.""" + + def __init__( + self, + provider: str, + model_id: str, + *, + litellm_provider: str, + env_key_name: str | None, + endpoint_mode: str | EndpointMode = EndpointMode.AUTO, + temperature: float | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + thinking: bool | None = None, + reasoning_effort: str | None = None, + rpm_limit: int | None = None, + timeout: float | None = None, + api_key: str | None = None, + api_base_url: str | None = None, + api_base: str | None = None, + retry_limit: int | None = None, + retry_policy: RetryPolicy | None = None, + unsupported_params: str | UnsupportedParamsPolicy = UnsupportedParamsPolicy.WARN, + provider_params: dict[str, Any] | None = None, + max_concurrent: int = 4, + capabilities: TargetCapabilities | None = None, + **extra_provider_params: Any, + ) -> None: + if max_completion_tokens is None and max_tokens is not None: + max_completion_tokens = max_tokens + if api_base_url is None: + api_base_url = api_base + + merged_provider_params = dict(provider_params or {}) + merged_provider_params.update(extra_provider_params) + + if retry_policy is None: + retry_policy = RetryPolicy( + max_retries=retry_limit if retry_limit is not None else 3 + ) + + unsupported_policy = _coerce_unsupported_policy(unsupported_params) + + self.config = TargetConfig( + provider=provider, + model_id=model_id, + litellm_provider=litellm_provider, + env_key_name=env_key_name, + endpoint_mode=_coerce_endpoint_mode(endpoint_mode), + temperature=temperature, + max_completion_tokens=max_completion_tokens, + thinking=thinking, + reasoning_effort=reasoning_effort, + rpm_limit=rpm_limit, + timeout=timeout, + api_key=api_key, + api_base_url=api_base_url, + retry_policy=retry_policy, + unsupported_params=unsupported_policy, + provider_params=merged_provider_params, + max_concurrent=max_concurrent, + ) + self.capabilities = 
resolve_capabilities( + provider, + model_id, + api_base_url=api_base_url, + explicit=capabilities, + ) + self.endpoint_mode = self._resolve_endpoint_mode(self.config.endpoint_mode) + + self.provider_name = provider + self.model_id = model_id + self.env_key_name = env_key_name + self.api_key = api_key or (os.getenv(env_key_name) if env_key_name else None) + self.api_base_url = api_base_url + self.temperature = temperature + self.max_completion_tokens = max_completion_tokens + self.reasoning_effort = reasoning_effort + self.rpm_limit = rpm_limit + self.timeout = timeout + self.unsupported_params = unsupported_policy.value + + self._request_timestamps: list[float] = [] + self._rate_lock = Lock() + self._sleep = time.sleep + self._configured_common_params = { + name + for name, value in { + "temperature": temperature, + "max_completion_tokens": max_completion_tokens, + "thinking": thinking, + "reasoning_effort": reasoning_effort, + "timeout": timeout, + }.items() + if value is not None + } + + load_env_once() + maybe_configure_langfuse_tracing(load_env=False) + logger.info( + "Initialized {} | Model: {} | Endpoint: {}", + self.provider_name, + self.model_id, + self.endpoint_mode.value, + ) + + def generate( + self, + prompt: str | list[str] | None = None, + messages: Messages | list[Messages] | None = None, + response_format: type[T] | None = None, + metadata: dict[str, Any] | None = None, + previous_response_id: str | None = None, + ) -> str | list[str] | T | list[T]: + """Generate a single response or ordered batch of responses.""" + requests, single_input = self._normalize_inputs( + prompt=prompt, + messages=messages, + metadata=metadata, + previous_response_id=previous_response_id, + response_format=response_format, + ) + try: + results = self._generate_requests(requests, response_format=response_format) + except ValueError: + raise + except Exception as exc: + error_trace = traceback.format_exc() + logger.error( + "Generation failed | Provider: {} | Model: {} | 
Error: {}", + self.provider_name, + self.model_id, + exc, + ) + raise RuntimeError( + f"Error generating response with {self.provider_name}:\n{error_trace}" + ) from exc + + if single_input: + return results[0] + return results + + def generate_batch( + self, + messages: list[Messages], + *, + response_format: type[T] | None = None, + metadata: list[dict[str, Any] | None] | dict[str, Any] | None = None, + previous_response_ids: list[str | None] | None = None, + ) -> list[str] | list[T]: + """Generate an ordered batch from pre-built message lists.""" + if not messages: + return [] + + metadata_items = _normalize_metadata(metadata, len(messages)) + previous_ids = previous_response_ids or [None] * len(messages) + if len(previous_ids) != len(messages): + raise ValueError("previous_response_ids length must match messages length") + + requests = [ + NormalizedRequest( + messages=self._prepare_messages( + item, + response_format=response_format, + ), + metadata=metadata_items[index], + previous_response_id=previous_ids[index], + ) + for index, item in enumerate(messages) + ] + return self._generate_requests(requests, response_format=response_format) + + def generate_response( + self, + prompt: str | list[str] | None = None, + messages: Messages | list[Messages] | None = None, + metadata: dict[str, Any] | None = None, + previous_response_id: str | None = None, + ) -> NormalizedResponse | list[NormalizedResponse]: + """Generate response metadata, including LiteLLM reasoning fields when present.""" + requests, single_input = self._normalize_inputs( + prompt=prompt, + messages=messages, + metadata=metadata, + previous_response_id=previous_response_id, + response_format=None, + ) + responses = self._generate_normalized_responses( + requests, + response_format=None, + ) + if single_input: + return responses[0] + return responses + + def generate_batch_response( + self, + messages: list[Messages], + *, + metadata: list[dict[str, Any] | None] | dict[str, Any] | None = None, + 
previous_response_ids: list[str | None] | None = None, + ) -> list[NormalizedResponse]: + """Generate ordered batch responses with metadata preserved.""" + if not messages: + return [] + + metadata_items = _normalize_metadata(metadata, len(messages)) + previous_ids = previous_response_ids or [None] * len(messages) + if len(previous_ids) != len(messages): + raise ValueError("previous_response_ids length must match messages length") + + requests = [ + NormalizedRequest( + messages=self._prepare_messages(item, response_format=None), + metadata=metadata_items[index], + previous_response_id=previous_ids[index], + ) + for index, item in enumerate(messages) + ] + return self._generate_normalized_responses(requests, response_format=None) + + def _generate_requests( + self, + requests: list[NormalizedRequest], + *, + response_format: type[T] | None, + ) -> list[str] | list[T]: + responses = self._generate_normalized_responses( + requests, + response_format=response_format, + ) + return [ + self._parse_response(response, response_format=response_format) + for response in responses + ] + + def _generate_normalized_responses( + self, + requests: list[NormalizedRequest], + *, + response_format: type[T] | None, + ) -> list[NormalizedResponse]: + if not requests: + return [] + + if len(requests) == 1: + return [self._execute_single(requests[0], response_format=response_format)] + + if ( + self.endpoint_mode == EndpointMode.CHAT + and self.capabilities.batch_mode == BatchMode.LITELLM_BATCH + ): + return self._execute_litellm_batch( + requests, + response_format=response_format, + ) + + warnings.warn( + ( + f"{self.provider_name}/{self.model_id} does not expose native " + "same-target batching for this endpoint. Falling back to bounded " + "parallel single requests." 
+ ), + UserWarning, + stacklevel=2, + ) + with ThreadPoolExecutor( + max_workers=max(1, min(self.config.max_concurrent, len(requests))) + ) as executor: + responses = list( + executor.map( + lambda request: self._execute_single( + request, + response_format=response_format, + ), + requests, + ) + ) + return responses + + def _execute_single( + self, + request: NormalizedRequest, + *, + response_format: type[T] | None, + ) -> NormalizedResponse: + if self.endpoint_mode == EndpointMode.RESPONSES: + params = self._build_responses_params(request, response_format) + response = self._call_litellm( + litellm.responses, + params, + request_count=1, + ) + return NormalizedResponse( + text=_extract_responses_text(response), + raw=response, + reasoning_content=_extract_responses_reasoning(response), + images=_extract_responses_images(response), + audio=_extract_responses_audio(response), + output_items=_extract_responses_output_items(response), + ) + + params = self._build_chat_params(request, response_format) + response = self._call_litellm( + litellm.completion, + params, + request_count=1, + ) + return NormalizedResponse( + text=_extract_chat_text(response), + raw=response, + reasoning_content=_extract_chat_reasoning_content(response), + thinking_blocks=_extract_chat_thinking_blocks(response), + images=_extract_chat_images(response), + audio=_extract_chat_audio(response), + ) + + def _execute_litellm_batch( + self, + requests: list[NormalizedRequest], + *, + response_format: type[T] | None, + ) -> list[NormalizedResponse]: + params = self._build_chat_params( + NormalizedRequest( + messages=[], + metadata=_combine_batch_metadata(requests), + ), + response_format, + ) + params["messages"] = [request.messages for request in requests] + response = self._call_litellm( + litellm.batch_completion, + params, + request_count=len(requests), + ) + if not isinstance(response, list): + response = list(response) + + normalized: list[NormalizedResponse] = [] + for index, item in 
enumerate(response): + if isinstance(item, Exception): + raise RuntimeError(f"Batch item {index} failed: {item}") from item + normalized.append( + NormalizedResponse( + text=_extract_chat_text(item), + raw=item, + reasoning_content=_extract_chat_reasoning_content(item), + thinking_blocks=_extract_chat_thinking_blocks(item), + images=_extract_chat_images(item), + audio=_extract_chat_audio(item), + ) + ) + return normalized + + def _build_chat_params( + self, + request: NormalizedRequest, + response_format: type[T] | None, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "model": self._get_model_string(), + "messages": request.messages, + "metadata": self._build_request_metadata(request.metadata), + } + if request.previous_response_id is not None: + self._add_supported_param( + params, + "previous_response_id", + request.previous_response_id, + endpoint=EndpointMode.CHAT, + ) + self._add_transport_params(params, endpoint=EndpointMode.CHAT) + self._add_common_generation_params(params, endpoint=EndpointMode.CHAT) + self._add_chat_structured_output(params, response_format) + params.update(self.config.provider_params) + return _without_none(params) + + def _build_responses_params( + self, + request: NormalizedRequest, + response_format: type[T] | None, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "model": self._get_model_string(), + "input": request.messages, + "metadata": self._build_request_metadata(request.metadata), + } + if request.previous_response_id is not None: + self._add_supported_param( + params, + "previous_response_id", + request.previous_response_id, + endpoint=EndpointMode.RESPONSES, + ) + self._add_transport_params(params, endpoint=EndpointMode.RESPONSES) + self._add_common_generation_params(params, endpoint=EndpointMode.RESPONSES) + self._add_responses_structured_output(params, response_format) + params.update(self.config.provider_params) + return _without_none(params) + + def _add_common_generation_params( + self, + params: dict[str, 
Any], + *, + endpoint: EndpointMode, + ) -> None: + self._add_supported_param( + params, + "temperature", + self.config.temperature, + endpoint=endpoint, + ) + + token_param = ( + "max_output_tokens" + if endpoint == EndpointMode.RESPONSES + else "max_completion_tokens" + ) + self._add_supported_param( + params, + "max_completion_tokens", + self.config.max_completion_tokens, + endpoint=endpoint, + target_name=token_param, + ) + + if self.config.thinking is False: + return + + effort = self.config.reasoning_effort + if effort is None and self.config.thinking is True: + effort = "low" + + if endpoint == EndpointMode.RESPONSES and effort is not None: + self._add_supported_param( + params, + "reasoning_effort", + {"effort": effort}, + endpoint=endpoint, + target_name="reasoning", + ) + return + + self._add_supported_param( + params, + "reasoning_effort", + effort, + endpoint=endpoint, + ) + + def _add_chat_structured_output( + self, + params: dict[str, Any], + response_format: type[T] | None, + ) -> None: + if response_format is None: + return + + mode = self.capabilities.structured_output + if mode == StructuredOutputMode.JSON_SCHEMA: + params["response_format"] = response_format + elif mode == StructuredOutputMode.JSON_OBJECT: + params["response_format"] = {"type": "json_object"} + if self.provider_name == "ollama": + params["format"] = "json" + elif mode == StructuredOutputMode.PROMPTED_JSON: + warnings.warn( + ( + f"{self.provider_name}/{self.model_id} has no declared native " + "schema support. Using prompted JSON plus Pydantic validation." 
+ ), + UserWarning, + stacklevel=3, + ) + else: + raise ValueError( + f"{self.provider_name}/{self.model_id} does not support structured output" + ) + + def _add_responses_structured_output( + self, + params: dict[str, Any], + response_format: type[T] | None, + ) -> None: + if response_format is None: + return + + if self.capabilities.structured_output != StructuredOutputMode.JSON_SCHEMA: + raise ValueError( + f"{self.provider_name}/{self.model_id} does not support native " + "Responses structured output" + ) + params["text_format"] = response_format + + def _add_transport_params( + self, + params: dict[str, Any], + *, + endpoint: EndpointMode, + ) -> None: + if self.config.timeout is not None: + self._add_supported_param( + params, + "timeout", + self.config.timeout, + endpoint=endpoint, + ) + if self.api_base_url is not None: + params["api_base"] = self.api_base_url + if self.api_key is not None: + params["api_key"] = self.api_key + elif self.env_key_name and not self.capabilities.no_api_key: + env_key = os.getenv(self.env_key_name) + if env_key: + params["api_key"] = env_key + else: + raise ValueError( + f"{self.env_key_name} environment variable not set. " + "Set it or provide api_key when initializing the provider." 
+ ) + + def _add_supported_param( + self, + params: dict[str, Any], + source_name: str, + value: Any, + *, + endpoint: EndpointMode, + target_name: str | None = None, + ) -> None: + if value is None: + return + + if source_name not in self.capabilities.supported_params: + if ( + source_name in self._configured_common_params + or source_name == "previous_response_id" + or ( + source_name == "reasoning_effort" + and self.config.thinking is True + ) + ): + self._handle_unsupported_param(source_name) + return + + if source_name == "reasoning_effort" and not self.capabilities.supports_reasoning: + self._handle_unsupported_param(source_name) + return + + if endpoint == EndpointMode.RESPONSES and not self.capabilities.supports_endpoint( + EndpointMode.RESPONSES + ): + self._handle_unsupported_param(source_name) + return + + params[target_name or source_name] = value + + def _handle_unsupported_param(self, name: str) -> None: + message = ( + f"Parameter '{name}' is not supported by resolved target " + f"{self.provider_name}/{self.model_id} and will be omitted." 
+ ) + if self.config.unsupported_params == UnsupportedParamsPolicy.FAIL: + raise ValueError(message) + if self.config.unsupported_params == UnsupportedParamsPolicy.WARN: + warnings.warn(message, UserWarning, stacklevel=3) + + def _normalize_inputs( + self, + *, + prompt: str | list[str] | None, + messages: Messages | list[Messages] | None, + metadata: dict[str, Any] | None, + previous_response_id: str | None, + response_format: type[T] | None, + ) -> tuple[list[NormalizedRequest], bool]: + if prompt is None and messages is None: + raise ValueError("Either prompt or messages must be provided") + if prompt is not None and messages is not None: + raise ValueError("Provide either prompt or messages, not both") + + single_input = False + batch_messages: list[Messages] + + if prompt is not None: + if isinstance(prompt, str): + batch_messages = [[{"role": "user", "content": prompt}]] + single_input = True + elif isinstance(prompt, list) and all(isinstance(item, str) for item in prompt): + if not prompt: + raise ValueError("prompt list cannot be empty") + batch_messages = [ + [{"role": "user", "content": item}] + for item in prompt + ] + else: + raise ValueError("prompt must be a string or list of strings") + elif _is_single_messages(messages): + batch_messages = [messages] # type: ignore[list-item] + single_input = True + elif _is_batch_messages(messages): + batch_messages = messages # type: ignore[assignment] + if not batch_messages: + raise ValueError("messages cannot be empty") + else: + raise ValueError("Invalid messages format") + + return ( + [ + NormalizedRequest( + messages=self._prepare_messages( + item, + response_format=response_format, + ), + metadata=metadata, + previous_response_id=previous_response_id, + ) + for item in batch_messages + ], + single_input, + ) + + def _prepare_messages( + self, + messages: Messages, + *, + response_format: type[T] | None, + ) -> Messages: + if not messages: + raise ValueError("messages cannot be empty") + + normalized = 
[_normalize_message(message) for message in copy.deepcopy(messages)] + self._validate_modalities(normalized) + + if response_format is not None and self.capabilities.structured_output in { + StructuredOutputMode.JSON_OBJECT, + StructuredOutputMode.PROMPTED_JSON, + }: + _append_json_instructions(normalized) + + return normalized + + def _validate_modalities(self, messages: Messages) -> None: + supported = self.capabilities.modalities + for message in messages: + content = message.get("content") + if not isinstance(content, list): + continue + for part in content: + modality = _modality_for_part(part) + if modality not in supported: + raise ValueError( + f"Modality '{modality.value}' is not supported by " + f"{self.provider_name}/{self.model_id}" + ) + + def _resolve_endpoint_mode(self, endpoint_mode: EndpointMode) -> EndpointMode: + if endpoint_mode == EndpointMode.AUTO: + return self.capabilities.default_endpoint_mode + if not self.capabilities.supports_endpoint(endpoint_mode): + raise ValueError( + f"{self.provider_name}/{self.model_id} does not support " + f"endpoint_mode='{endpoint_mode.value}'" + ) + return endpoint_mode + + def _call_litellm(self, func, params: dict[str, Any], *, request_count: int) -> Any: + try: + return self._call_with_retries( + lambda: func(**params), + request_count=request_count, + ) + except Exception as exc: + if not self._should_retry_with_drop_params(exc, params): + raise + + retry_params = dict(params) + retry_params["drop_params"] = True + if self.config.unsupported_params == UnsupportedParamsPolicy.WARN: + warnings.warn( + ( + "LiteLLM rejected one or more request parameters as " + "unsupported. Retrying once with drop_params=True because " + f"unsupported_params='{self.config.unsupported_params.value}'." 
+ ), + UserWarning, + stacklevel=3, + ) + return self._call_with_retries( + lambda: func(**retry_params), + request_count=request_count, + ) + + def _should_retry_with_drop_params( + self, + exc: Exception, + params: dict[str, Any], + ) -> bool: + if self.config.unsupported_params == UnsupportedParamsPolicy.FAIL: + return False + if params.get("drop_params") is True: + return False + return _is_unsupported_params_error(exc) + + def _call_with_retries(self, func, *, request_count: int) -> Any: + retry_policy = self.config.retry_policy + attempts = max(1, retry_policy.max_retries) + + for attempt in range(attempts): + self._respect_rate_limit(request_count) + try: + response = func() + self._record_requests(request_count) + return response + except Exception as exc: + if attempt >= attempts - 1 or not _is_retryable_error(exc): + raise + delay = min( + retry_policy.max_delay, + retry_policy.base_delay * (2 ** attempt), + ) + if retry_policy.jitter > 0: + delay += random.uniform(0, delay * retry_policy.jitter) + logger.warning( + "Retryable LLM error | Provider: {} | Model: {} | " + "Attempt: {}/{} | Waiting: {:.2f}s | Error: {}", + self.provider_name, + self.model_id, + attempt + 1, + attempts, + delay, + exc, + ) + self._sleep(delay) + + raise RuntimeError("unreachable retry state") + + def _respect_rate_limit(self, request_count: int = 1) -> None: + if self.config.rpm_limit is None: + return + + with self._rate_lock: + now = time.monotonic() + self._request_timestamps = [ + timestamp + for timestamp in self._request_timestamps + if now - timestamp < 60 + ] + + while len(self._request_timestamps) + request_count > self.config.rpm_limit: + earliest = self._request_timestamps[0] + sleep_time = max(0.0, 60 - (now - earliest)) + if sleep_time > 0: + logger.warning( + "Rate limit reached | Provider: {} | Model: {} | " + "Waiting {:.2f}s", + self.provider_name, + self.model_id, + sleep_time, + ) + self._sleep(sleep_time) + now = time.monotonic() + self._request_timestamps 
= [ + timestamp + for timestamp in self._request_timestamps + if now - timestamp < 60 + ] + + def _record_requests(self, request_count: int = 1) -> None: + if self.config.rpm_limit is None: + return + with self._rate_lock: + now = time.monotonic() + self._request_timestamps.extend([now] * request_count) + + def _parse_response( + self, + response: NormalizedResponse, + *, + response_format: type[T] | None, + ) -> str | T: + if response_format is None: + return response.text.strip() if response.text else response.text + + parsed = getattr(response.raw, "output_parsed", None) + if parsed is not None: + return parsed + + content = self._strip_code_fences(response.text) + try: + return response_format.model_validate_json(content) + except Exception as validation_error: + content_preview = ( + content[:200] + "..." if len(content) > 200 else content + ) + raise ValueError( + f"Failed to parse JSON response into {response_format.__name__}.\n" + f"Validation error: {validation_error}\n" + f"Content received (first 200 chars):\n{content_preview}" + ) from validation_error + + def _build_request_metadata( + self, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return build_trace_metadata( + model=self, + component="provider.generate", + trace_name=f"datafast.{self.provider_name}", + metadata=metadata, + ) + + def _get_model_string(self) -> str: + prefix = f"{self.config.litellm_provider}/" + if self.model_id.startswith(prefix): + return self.model_id + return f"{prefix}{self.model_id}" + + @staticmethod + def _strip_code_fences(content: str) -> str: + if not content: + return content + + content = content.strip() + if content.startswith("```"): + first_newline = content.find("\n") + content = content[first_newline + 1 :] if first_newline != -1 else content[3:] + if content.endswith("```"): + content = content[:-3] + return content.strip() + + +class OpenAIProvider(LLMProvider): + def __init__(self, model_id: str = "gpt-5.5", **kwargs: Any) -> None: + 
super().__init__( + "openai", + model_id, + litellm_provider="openai", + env_key_name="OPENAI_API_KEY", + **kwargs, + ) + + +class AnthropicProvider(LLMProvider): + def __init__(self, model_id: str = "claude-haiku-4-5", **kwargs: Any) -> None: + super().__init__( + "anthropic", + model_id, + litellm_provider="anthropic", + env_key_name="ANTHROPIC_API_KEY", + **kwargs, + ) + + +class GeminiProvider(LLMProvider): + def __init__(self, model_id: str = "gemini-3.1-flash-lite", **kwargs: Any) -> None: + super().__init__( + "gemini", + model_id, + litellm_provider="gemini", + env_key_name="GEMINI_API_KEY", + **kwargs, + ) + + +class MistralProvider(LLMProvider): + def __init__(self, model_id: str = "mistral-small-2603", **kwargs: Any) -> None: + super().__init__( + "mistral", + model_id, + litellm_provider="mistral", + env_key_name="MISTRAL_API_KEY", + **kwargs, + ) + + +class OpenRouterProvider(LLMProvider): + def __init__(self, model_id: str = "openai/gpt-5.4-mini", **kwargs: Any) -> None: + super().__init__( + "openrouter", + model_id, + litellm_provider="openrouter", + env_key_name="OPENROUTER_API_KEY", + **kwargs, + ) + + +class OllamaProvider(LLMProvider): + def __init__(self, model_id: str = "gemma3:4b", **kwargs: Any) -> None: + super().__init__( + "ollama", + model_id, + litellm_provider="ollama_chat", + env_key_name=None, + **kwargs, + ) + + +class OpenAICompatibleProvider(LLMProvider): + def __init__( + self, + model_id: str, + *, + provider: str = "openai_compatible", + litellm_provider: str = "openai", + env_key_name: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__( + provider, + model_id, + litellm_provider=litellm_provider, + env_key_name=env_key_name, + **kwargs, + ) + + +def openai(model_id: str = "gpt-5.5", **kwargs: Any) -> OpenAIProvider: return OpenAIProvider(model_id=model_id, **kwargs) -def anthropic( - model_id: str = "claude-haiku-4-5-20251001", - **kwargs, -) -> AnthropicProvider: - """Create an Anthropic provider instance.""" 
+def anthropic(model_id: str = "claude-haiku-4-5", **kwargs: Any) -> AnthropicProvider: return AnthropicProvider(model_id=model_id, **kwargs) -def gemini(model_id: str = "gemini-2.0-flash", **kwargs) -> GeminiProvider: - """Create a Gemini provider instance.""" +def gemini(model_id: str = "gemini-3.1-flash-lite", **kwargs: Any) -> GeminiProvider: return GeminiProvider(model_id=model_id, **kwargs) -def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider: - """Create a Mistral provider instance.""" +def mistral(model_id: str = "mistral-small-2603", **kwargs: Any) -> MistralProvider: return MistralProvider(model_id=model_id, **kwargs) def openrouter( - model_id: str = "openai/gpt-5-mini", - **kwargs, + model_id: str = "openai/gpt-5.4-mini", + **kwargs: Any, ) -> OpenRouterProvider: - """Create an OpenRouter provider instance.""" return OpenRouterProvider(model_id=model_id, **kwargs) -def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider: - """Create an Ollama provider instance.""" +def ollama(model_id: str = "gemma3:4b", **kwargs: Any) -> OllamaProvider: return OllamaProvider(model_id=model_id, **kwargs) +def openai_compatible( + model_id: str, + *, + api_base_url: str | None = None, + backend: str = "openai_compatible", + **kwargs: Any, +) -> OpenAICompatibleProvider: + provider = _normalize_openai_compatible_backend(backend) + return OpenAICompatibleProvider( + model_id=model_id, + provider=provider, + api_base_url=api_base_url, + **kwargs, + ) + + +def _normalize_openai_compatible_backend(value: str) -> str: + normalized = value.strip().lower().replace("-", "_") + aliases = { + "openai-compatible": "openai_compatible", + "openai_compatible": "openai_compatible", + "llama.cpp": "llamacpp", + "llama_cpp": "llamacpp", + "llamacpp": "llamacpp", + "vllm": "vllm", + } + try: + return aliases[normalized] + except KeyError as exc: + valid = ", ".join(sorted(set(aliases.values()))) + raise ValueError( + f"Unsupported OpenAI-compatible 
backend '{value}'. Choose: {valid}" + ) from exc + + +def _coerce_endpoint_mode(value: str | EndpointMode) -> EndpointMode: + if isinstance(value, EndpointMode): + return value + try: + return EndpointMode(value) + except ValueError as exc: + raise ValueError("endpoint_mode must be 'auto', 'chat', or 'responses'") from exc + + +def _coerce_unsupported_policy( + value: str | UnsupportedParamsPolicy, +) -> UnsupportedParamsPolicy: + if isinstance(value, UnsupportedParamsPolicy): + return value + try: + return UnsupportedParamsPolicy(value) + except ValueError as exc: + raise ValueError("unsupported_params must be 'fail', 'warn', or 'quiet'") from exc + + +def _normalize_metadata( + metadata: list[dict[str, Any] | None] | dict[str, Any] | None, + expected_length: int, +) -> list[dict[str, Any] | None]: + if isinstance(metadata, list): + if len(metadata) != expected_length: + raise ValueError("metadata length must match messages length") + return metadata + return [metadata] * expected_length + + +def _combine_batch_metadata(requests: list[NormalizedRequest]) -> dict[str, Any]: + metadata_items = [request.metadata for request in requests] + return { + "datafast_batch_size": len(requests), + "datafast_batch_metadata": metadata_items, + } + + +def _is_single_messages(value: Any) -> bool: + return isinstance(value, list) and bool(value) and isinstance(value[0], dict) + + +def _is_batch_messages(value: Any) -> bool: + return isinstance(value, list) and bool(value) and isinstance(value[0], list) + + +def _normalize_message(message: Message) -> Message: + if not isinstance(message, dict): + raise ValueError("Each message must be a dictionary") + + normalized = dict(message) + content = normalized.get("content") + if isinstance(content, list): + normalized["content"] = [_normalize_content_part(part) for part in content] + elif content is not None and not isinstance(content, str): + raise ValueError("message content must be a string, list of parts, or None") + return 
normalized + + +def _normalize_content_part(part: Any) -> dict[str, Any]: + part = _content_part_to_dict(part) + part_type = part.get("type") + + normalizers = { + "text": _normalize_text_part, + "image": _normalize_image_part, + "audio": _normalize_audio_part, + "video": _normalize_video_part, + "file": _normalize_file_part, + "document": _normalize_file_part, + } + if part_type in {"image_url", "input_audio", "video_url"}: + return _without_none(part) + if part_type in normalizers: + return normalizers[part_type](part) + return part + + +def _content_part_to_dict(part: Any) -> dict[str, Any]: + if isinstance(part, ContentPart): + part = { + "type": part.type, + "text": part.text, + "url": part.url, + "data": part.data, + "media_type": part.media_type, + "media_id": part.media_id, + **part.provider_options, + } + + if not isinstance(part, dict): + raise ValueError("content parts must be dictionaries or ContentPart objects") + return part + + +def _normalize_text_part(part: dict[str, Any]) -> dict[str, Any]: + return _without_none({"type": "text", "text": part.get("text")}) + + +def _normalize_image_part(part: dict[str, Any]) -> dict[str, Any]: + image_url: dict[str, Any] = {"url": part.get("url") or part.get("data")} + if part.get("format") or part.get("media_type"): + image_url["format"] = part.get("format") or part.get("media_type") + if part.get("detail"): + image_url["detail"] = part["detail"] + normalized = {"type": "image_url", "image_url": _without_none(image_url)} + if part.get("media_id"): + normalized["uuid"] = part["media_id"] + return normalized + + +def _normalize_audio_part(part: dict[str, Any]) -> dict[str, Any]: + return { + "type": "input_audio", + "input_audio": _without_none({ + "data": part.get("data"), + "format": part.get("format") or part.get("media_type") or "wav", + }), + } + + +def _normalize_video_part(part: dict[str, Any]) -> dict[str, Any]: + normalized = {"type": "video_url", "video_url": {"url": part.get("url")}} + if 
part.get("media_id"): + normalized["uuid"] = part["media_id"] + return normalized + + +def _normalize_file_part(part: dict[str, Any]) -> dict[str, Any]: + if isinstance(part.get("file"), dict): + file_payload = part["file"] + elif part.get("data"): + file_payload = {"file_data": part.get("data")} + else: + file_payload = {"file_id": part.get("url")} + return {"type": "file", "file": _without_none(file_payload)} + + +def _modality_for_part(part: dict[str, Any]) -> Modality: + part_type = part.get("type") + if part_type == "text": + return Modality.TEXT + if part_type in {"image", "image_url"}: + return Modality.IMAGE + if part_type in {"audio", "input_audio"}: + return Modality.AUDIO + if part_type in {"video", "video_url"}: + return Modality.VIDEO + if part_type == "file": + return Modality.FILE + if part_type == "document": + return Modality.DOCUMENT + return Modality.TEXT + + +def _append_json_instructions(messages: Messages) -> None: + for message in reversed(messages): + if message.get("role") != "user": + continue + content = message.get("content") + if isinstance(content, str): + message["content"] = content + JSON_INSTRUCTIONS + return + if isinstance(content, list): + for part in reversed(content): + if part.get("type") == "text" and isinstance(part.get("text"), str): + part["text"] = part["text"] + JSON_INSTRUCTIONS + return + messages.append({"role": "user", "content": JSON_INSTRUCTIONS.strip()}) + + +def _extract_chat_text(response: Any) -> str: + choice = _get_first_choice(response) + if choice is None: + raise RuntimeError( + f"Unexpected chat response from LiteLLM: {type(response).__name__}" + ) + + message = _get_attr_or_key(choice, "message") + if message is None: + text = _get_attr_or_key(choice, "text") + return "" if text is None else str(text) + + content = _get_attr_or_key(message, "content") + return _content_to_text(content) + + +def _extract_chat_reasoning_content(response: Any) -> str | None: + message = _extract_chat_message(response) + if 
message is None: + return None + + reasoning_content = _get_attr_or_key(message, "reasoning_content") + if reasoning_content is None: + reasoning_content = _get_attr_or_key(message, "reasoning") + if reasoning_content is None: + return None + if isinstance(reasoning_content, list): + return _content_to_text(reasoning_content).strip() or None + return str(reasoning_content).strip() or None + + +def _extract_chat_thinking_blocks(response: Any) -> list[dict[str, Any]]: + message = _extract_chat_message(response) + if message is None: + return [] + + blocks = _get_attr_or_key(message, "thinking_blocks") + if not blocks: + return [] + if not isinstance(blocks, list): + blocks = [blocks] + return [_normalize_mapping_block(block) for block in blocks] + + +def _extract_chat_images(response: Any) -> list[dict[str, Any]]: + message = _extract_chat_message(response) + if message is None: + return [] + + images = _get_attr_or_key(message, "images") + collected = list(_normalize_optional_list(images)) + + content = _get_attr_or_key(message, "content") + for part in _normalize_optional_list(content): + part_type = _get_attr_or_key(part, "type") + if part_type in {"image", "image_url", "output_image"}: + collected.append(part) + + return [_normalize_mapping_block(image) for image in collected] + + +def _extract_chat_audio(response: Any) -> dict[str, Any] | None: + message = _extract_chat_message(response) + if message is None: + return None + + audio = _get_attr_or_key(message, "audio") + if audio: + return _normalize_mapping_block(audio) + + content = _get_attr_or_key(message, "content") + for part in _normalize_optional_list(content): + part_type = _get_attr_or_key(part, "type") + if part_type in {"audio", "output_audio"}: + return _normalize_mapping_block(part) + return None + + +def _extract_chat_message(response: Any) -> Any: + choice = _get_first_choice(response) + if choice is None: + return None + return _get_attr_or_key(choice, "message") + + +def 
_extract_responses_text(response: Any) -> str: + output_text = _get_attr_or_key(response, "output_text") + if output_text: + return str(output_text) + + output = _normalize_optional_list(_get_attr_or_key(response, "output")) + texts: list[str] = [] + for item in output: + content = _get_attr_or_key(item, "content") or [] + if isinstance(content, str): + texts.append(content) + continue + for part in _normalize_optional_list(content): + part_type = _get_attr_or_key(part, "type") + if part_type in {"output_text", "text"}: + text = _get_attr_or_key(part, "text") + if text is not None: + texts.append(str(text)) + if texts: + return "".join(texts) + if output: + return "" + raise RuntimeError( + f"Unexpected Responses API response from LiteLLM: {type(response).__name__}" + ) + + +def _extract_responses_reasoning(response: Any) -> str | None: + reasoning_content = _get_attr_or_key(response, "reasoning_content") + if reasoning_content: + return str(reasoning_content).strip() or None + + output = _normalize_optional_list(_get_attr_or_key(response, "output")) + texts: list[str] = [] + for item in output: + item_type = _get_attr_or_key(item, "type") + if item_type != "reasoning": + continue + + for field_name in ("text", "content"): + value = _get_attr_or_key(item, field_name) + if value: + texts.append(_content_to_text(value)) + + summary = _get_attr_or_key(item, "summary") or [] + if isinstance(summary, str): + texts.append(summary) + continue + for part in _normalize_optional_list(summary): + text = _get_attr_or_key(part, "text") or _get_attr_or_key(part, "content") + if text: + texts.append(_content_to_text(text)) + + joined = "\n".join(text.strip() for text in texts if text and text.strip()) + return joined or None + + +def _extract_responses_output_items(response: Any) -> list[dict[str, Any]]: + output = _get_attr_or_key(response, "output") or [] + return [_normalize_mapping_block(item) for item in _normalize_optional_list(output)] + + +def 
_extract_responses_images(response: Any) -> list[dict[str, Any]]: + images: list[Any] = [] + for item in _normalize_optional_list(_get_attr_or_key(response, "output")): + item_type = _get_attr_or_key(item, "type") + if item_type in {"image", "output_image", "image_generation_call"}: + images.append(item) + for part in _normalize_optional_list(_get_attr_or_key(item, "content")): + part_type = _get_attr_or_key(part, "type") + if part_type in {"image", "image_url", "output_image"}: + images.append(part) + return [_normalize_mapping_block(image) for image in images] + + +def _extract_responses_audio(response: Any) -> dict[str, Any] | None: + for item in _normalize_optional_list(_get_attr_or_key(response, "output")): + item_type = _get_attr_or_key(item, "type") + if item_type in {"audio", "output_audio"}: + return _normalize_mapping_block(item) + for part in _normalize_optional_list(_get_attr_or_key(item, "content")): + part_type = _get_attr_or_key(part, "type") + if part_type in {"audio", "output_audio"}: + return _normalize_mapping_block(part) + return None + + +def _get_first_choice(response: Any) -> Any: + choices = _get_attr_or_key(response, "choices") + if not choices: + return None + return choices[0] + + +def _content_to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for part in content: + text = _get_attr_or_key(part, "text") + if text is not None: + texts.append(str(text)) + return "".join(texts) + return str(content) + + +def _get_attr_or_key(value: Any, name: str) -> Any: + if isinstance(value, dict): + return value.get(name) + return getattr(value, name, None) + + +def _normalize_optional_list(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, str): + return [] + if isinstance(value, list): + return value + return [value] + + +def _normalize_mapping_block(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + 
return dict(value) + + if hasattr(value, "model_dump"): + dumped = value.model_dump() + if isinstance(dumped, dict): + return dumped + + if hasattr(value, "dict"): + dumped = value.dict() + if isinstance(dumped, dict): + return dumped + + result: dict[str, Any] = {} + for name in ("type", "text", "thinking", "content", "signature"): + attr = getattr(value, name, None) + if attr is not None: + result[name] = attr + if result: + return result + return {"content": str(value)} + + +def _without_none(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} + + +def _is_retryable_error(exc: Exception) -> bool: + retryable_types = ( + litellm_exceptions.RateLimitError, + litellm_exceptions.APIConnectionError, + litellm_exceptions.Timeout, + litellm_exceptions.InternalServerError, + litellm_exceptions.ServiceUnavailableError, + ) + return isinstance(exc, retryable_types) + + +def _is_unsupported_params_error(exc: Exception) -> bool: + unsupported_type = getattr(litellm_exceptions, "UnsupportedParamsError", None) + if unsupported_type is not None and isinstance(exc, unsupported_type): + return True + return exc.__class__.__name__ == "UnsupportedParamsError" + + __all__ = [ "LLMProvider", "OpenAIProvider", @@ -55,10 +1466,12 @@ def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider: "MistralProvider", "OpenRouterProvider", "OllamaProvider", + "OpenAICompatibleProvider", "openai", "anthropic", "gemini", "mistral", "openrouter", "ollama", + "openai_compatible", ] diff --git a/datafast/llm/types.py b/datafast/llm/types.py new file mode 100644 index 0000000..c8f260b --- /dev/null +++ b/datafast/llm/types.py @@ -0,0 +1,151 @@ +"""Shared types for Datafast LLM provider targets.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Literal + + +Message = dict[str, Any] +Messages = list[Message] + +ContentPartType = Literal["text", 
"image", "audio", "video", "file", "document"] + + +class EndpointMode(str, Enum): + AUTO = "auto" + CHAT = "chat" + RESPONSES = "responses" + + +class UnsupportedParamsPolicy(str, Enum): + FAIL = "fail" + WARN = "warn" + QUIET = "quiet" + + +class StructuredOutputMode(str, Enum): + NONE = "none" + PROMPTED_JSON = "prompted_json" + JSON_OBJECT = "json_object" + JSON_SCHEMA = "json_schema" + + +class BatchMode(str, Enum): + NONE = "none" + LITELLM_BATCH = "litellm_batch" + FALLBACK_CONCURRENCY = "fallback_concurrency" + + +class CacheMode(str, Enum): + NONE = "none" + PROVIDER_PROMPT = "provider_prompt" + ROUTER = "router" + LOCAL_KV = "local_kv" + CLIENT_RESULT = "client_result" + + +class Modality(str, Enum): + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + FILE = "file" + DOCUMENT = "document" + + +@dataclass(frozen=True) +class RetryPolicy: + max_retries: int = 3 + base_delay: float = 1.0 + max_delay: float = 30.0 + jitter: float = 0.25 + + +@dataclass(frozen=True) +class TargetCapabilities: + endpoint_modes: frozenset[EndpointMode] + default_endpoint_mode: EndpointMode + supported_params: frozenset[str] = frozenset() + modalities: frozenset[Modality] = frozenset({Modality.TEXT}) + structured_output: StructuredOutputMode = StructuredOutputMode.PROMPTED_JSON + batch_mode: BatchMode = BatchMode.FALLBACK_CONCURRENCY + cache_mode: CacheMode = CacheMode.NONE + supports_reasoning: bool = False + supports_thinking: bool = False + no_api_key: bool = False + requires_chat_template: bool = False + notes: tuple[str, ...] 
= () + + def supports_endpoint(self, endpoint_mode: EndpointMode) -> bool: + return endpoint_mode in self.endpoint_modes + + +@dataclass(frozen=True) +class TargetConfig: + provider: str + model_id: str + litellm_provider: str + env_key_name: str | None + endpoint_mode: EndpointMode = EndpointMode.AUTO + temperature: float | None = None + max_completion_tokens: int | None = None + thinking: bool | None = None + reasoning_effort: str | None = None + rpm_limit: int | None = None + timeout: float | None = None + api_key: str | None = None + api_base_url: str | None = None + retry_policy: RetryPolicy = field(default_factory=RetryPolicy) + unsupported_params: UnsupportedParamsPolicy = UnsupportedParamsPolicy.WARN + provider_params: dict[str, Any] = field(default_factory=dict) + max_concurrent: int = 4 + + +@dataclass(frozen=True) +class NormalizedRequest: + messages: Messages + metadata: dict[str, Any] | None = None + previous_response_id: str | None = None + + +@dataclass(frozen=True) +class NormalizedResponse: + text: str + raw: Any + reasoning_content: str | None = None + thinking_blocks: list[dict[str, Any]] = field(default_factory=list) + images: list[dict[str, Any]] = field(default_factory=list) + audio: dict[str, Any] | None = None + output_items: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass(frozen=True) +class ContentPart: + type: ContentPartType + text: str | None = None + url: str | None = None + data: str | None = None + media_type: str | None = None + media_id: str | None = None + provider_options: dict[str, Any] = field(default_factory=dict) + + +__all__ = [ + "BatchMode", + "CacheMode", + "ContentPart", + "ContentPartType", + "EndpointMode", + "Message", + "Messages", + "Modality", + "NormalizedRequest", + "NormalizedResponse", + "RetryPolicy", + "StructuredOutputMode", + "TargetCapabilities", + "TargetConfig", + "UnsupportedParamsPolicy", +] diff --git a/datafast/llm_utils.py b/datafast/llm_utils.py index 18890cd..9aa2fb3 100644 --- 
a/datafast/llm_utils.py +++ b/datafast/llm_utils.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from collections.abc import Sequence + + def get_messages(prompt: str, system_message: str = "You are a helpful assistant.") -> list[dict[str, str]]: """Convert a single prompt into a message list format expected by LLM APIs. @@ -12,3 +17,48 @@ def get_messages(prompt: str, system_message: str = "You are a helpful assistant {"role": "system", "content": system_message}, {"role": "user", "content": prompt}, ] + + +def format_generated_responses( + prompts: str | Sequence[str], + responses: str | Sequence[str], +) -> str: + """Return a readable string for one or many prompt/response pairs.""" + prompt_items = [prompts] if isinstance(prompts, str) else list(prompts) + response_items = [responses] if isinstance(responses, str) else list(responses) + + if len(prompt_items) != len(response_items): + raise ValueError("prompts and responses must have the same length") + + sections = [ + _format_response_section(prompt, response, index, total=len(prompt_items)) + for index, (prompt, response) in enumerate( + zip(prompt_items, response_items, strict=True), + start=1, + ) + ] + return "\n\n".join(sections) + + +def _format_response_section( + prompt: str, + response: str, + index: int, + *, + total: int, +) -> str: + lines = [] + if total > 1: + lines.append(f"Example {index}") + lines.extend( + [ + "Prompt", + "------", + prompt, + "", + "Response", + "--------", + response, + ] + ) + return "\n".join(lines) diff --git a/datafast/llms.py b/datafast/llms.py index 092346a..3478e30 100644 --- a/datafast/llms.py +++ b/datafast/llms.py @@ -1,892 +1,28 @@ -"""LLM providers for datafast using LiteLLM. +"""Compatibility exports for Datafast LLM providers. -This module provides classes for different LLM providers (OpenAI, Anthropic, Gemini, Mistral) -with a unified interface using LiteLLM under the hood. +The implementation lives in :mod:`datafast.llm.provider`. 
""" -from typing import Any, Type, TypeVar -from abc import ABC, abstractmethod -import os -import time -import traceback -import warnings -from loguru import logger - -# Pydantic -from pydantic import BaseModel - -# LiteLLM -import litellm -from litellm.exceptions import RateLimitError -from litellm.utils import ModelResponse - -# Internal imports -from .llm_utils import get_messages -from .tracing import ( - build_trace_metadata, - load_env_once, - maybe_configure_langfuse_tracing, +from datafast.llm.provider import ( + LLMProvider, + AnthropicProvider, + GeminiProvider, + MistralProvider, + OllamaProvider, + OpenAICompatibleProvider, + OpenAIProvider, + OpenRouterProvider, + anthropic, + gemini, + mistral, + ollama, + openai, + openai_compatible, + openrouter, ) +from datafast.tracing import load_env_once, maybe_configure_langfuse_tracing -# Type aliases for Python 3.10+ -Message = dict[str, str] -Messages = list[Message] -T = TypeVar('T', bound=BaseModel) - - -class LLMProvider(ABC): - """Abstract base class for LLM providers.""" - - def __init__( - self, - model_id: str, - api_key: str | None = None, - temperature: float | None = None, - max_completion_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - rpm_limit: int | None = None, - timeout: int | None = None, - ): - """Initialize the LLM provider with common parameters. - - Args: - model_id: The model identifier - api_key: API key (if None, will get from environment) - temperature: The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic - max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. 
- top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) - """ - self.model_id = model_id - load_env_once() - maybe_configure_langfuse_tracing(load_env=False) - self.api_key = api_key or self._get_api_key() - - # Set generation parameters - self.temperature = temperature - self.max_completion_tokens = max_completion_tokens - self.top_p = top_p - self.frequency_penalty = frequency_penalty - - # Rate limiting - self.rpm_limit = rpm_limit - self._request_timestamps: list[float] = [] - - # timeout - self.timeout = timeout - - # Configure environment with API key if needed - self._configure_env() - # Log successful initialization - logger.info(f"Initialized {self.provider_name} | Model: {self.model_id}") - - def _build_request_metadata( - self, - metadata: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Build default tracing metadata for provider-level calls.""" - return build_trace_metadata( - model=self, - component="provider.generate", - trace_name=f"datafast.{self.provider_name}", - metadata=metadata, - ) - - @property - @abstractmethod - def provider_name(self) -> str: - """Return the provider name used by LiteLLM.""" - pass - - @property - @abstractmethod - def env_key_name(self) -> str: - """Return the environment variable name for API key.""" - pass - - def _get_api_key(self) -> str: - """Get API key from environment variables.""" - api_key = os.getenv(self.env_key_name) - if not api_key: - logger.error( - f"Missing API key | Set {self.env_key_name} environment variable" - ) - raise ValueError( - f"{self.env_key_name} environment variable not set. " - f"Please set it or provide an API key when initializing the provider." 
- ) - return api_key - - def _configure_env(self) -> None: - """Configure environment variables for API key.""" - if self.api_key: - os.environ[self.env_key_name] = self.api_key - - def _get_model_string(self) -> str: - """Get the full model string for LiteLLM.""" - return f"{self.provider_name}/{self.model_id}" - - def _respect_rate_limit(self) -> None: - """Block execution to ensure we do not exceed the rpm_limit.""" - if self.rpm_limit is None: - return - current = time.monotonic() - # Keep only timestamps within the last minute - self._request_timestamps = [ - ts for ts in self._request_timestamps if current - ts < 60] - - # Be more conservative - wait if we're at 90% of the limit - conservative_limit = max(1, int(self.rpm_limit * 0.9)) - - if len(self._request_timestamps) < conservative_limit: - return - - # Need to wait until the earliest request is outside the 60-second window - earliest = self._request_timestamps[0] - # Add a 2s margin to avoid accidental rate limit exceedance - sleep_time = 62 - (current - earliest) - if sleep_time > 0: - logger.warning( - f"Rate limit approaching | Requests: {len(self._request_timestamps)}/{self.rpm_limit} | " - f"Waiting {sleep_time:.1f}s" - ) - time.sleep(sleep_time) - # Clean up old timestamps after waiting - current = time.monotonic() - self._request_timestamps = [ - ts for ts in self._request_timestamps if current - ts < 60] - - @staticmethod - def _strip_code_fences(content: str) -> str: - """Strip markdown code fences from content if present. 
- - Args: - content: The content string that may contain code fences - - Returns: - Content with code fences removed - """ - if not content: - return content - - content = content.strip() - - # Check for code fences with optional language identifier - if content.startswith('```'): - # Find the end of the first line (language identifier) - first_newline = content.find('\n') - if first_newline != -1: - content = content[first_newline + 1:] - else: - # No newline after opening fence, remove just the fence - content = content[3:] - - # Remove closing fence - if content.endswith('```'): - content = content[:-3] - - return content.strip() - - def generate( - self, - prompt: str | list[str] | None = None, - messages: list[Messages] | Messages | None = None, - response_format: Type[T] | None = None, - metadata: dict[str, Any] | None = None, - ) -> str | list[str] | T | list[T]: - """ - Generate responses from the LLM using single or batch inference. - - Args: - prompt: Single text prompt (str) or list of text prompts for batch processing - messages: Single message list or list of message lists for batch processing - response_format: Optional Pydantic model class for structured output - metadata: Optional LiteLLM metadata for tracing / observability - - Returns: - Single string/model or list of strings/models depending on input type. - - Raises: - ValueError: If neither prompt nor messages is provided, or if both are provided. - RuntimeError: If there's an error during generation. 
- """ - # Validate inputs - if prompt is None and messages is None: - raise ValueError("Either prompts or messages must be provided") - if prompt is not None and messages is not None: - raise ValueError("Provide either prompts or messages, not both") - - # Determine if this is a single input or batch input - single_input = False - batch_prompts = None - batch_messages = None - - if prompt is not None: - if isinstance(prompt, str): - # Single prompt - convert to batch - batch_prompts = [prompt] - single_input = True - elif isinstance(prompt, list): - # Already a list of prompts - batch_prompts = prompt - single_input = False - else: - raise ValueError("prompt must be a string or list of strings") - - if messages is not None: - if isinstance(messages, list) and len(messages) > 0: - # Check if it's a single message list or batch - if isinstance(messages[0], dict): - # Single message list - convert to batch - batch_messages = [messages] - single_input = True - elif isinstance(messages[0], list): - # Already a batch of message lists - batch_messages = messages - single_input = False - else: - raise ValueError("Invalid messages format") - else: - raise ValueError("messages cannot be empty") - - try: - # Append JSON formatting instructions if response_format is provided - json_instructions = ( - "\nReturn only valid JSON. To do so, don't include ```json ``` markdown " - "or code fences around the JSON. Use double quotes for all keys and values. " - "Escape internal quotes and newlines (use \\n). Do not include trailing commas." 
- ) - - # Convert batch prompts to messages if needed - batch_to_send = [] - if batch_prompts is not None: - for one_prompt in batch_prompts: - # Append JSON instructions to prompt if response_format is provided - modified_prompt = one_prompt + json_instructions if response_format is not None else one_prompt - batch_to_send.append(get_messages(modified_prompt)) - else: - batch_to_send = batch_messages - # Append JSON instructions to the last user message if response_format is provided - if response_format is not None: - for message_list in batch_to_send: - for msg in reversed(message_list): - if msg.get("role") == "user": - msg["content"] += json_instructions - break - - # Enforce rate limit per batch - self._respect_rate_limit() - - # Prepare completion parameters for batch - completion_params = { - "model": self._get_model_string(), - "messages": batch_to_send, - "temperature": self.temperature, - "max_tokens": self.max_completion_tokens, - "top_p": self.top_p, - "frequency_penalty": self.frequency_penalty, - "timeout": self.timeout, - "metadata": self._build_request_metadata(metadata), - } - if response_format is not None: - completion_params["response_format"] = response_format - - # Call LiteLLM completion with batch messages - retry on rate limit - max_retries = 3 - retry_delay = 5 # Start with 5 seconds - response = None - - for attempt in range(max_retries): - try: - response: list[ModelResponse] = litellm.batch_completion( - **completion_params) - break # Success, exit retry loop - except RateLimitError as e: - if attempt < max_retries - 1: - wait_time = retry_delay * (2 ** attempt) # Exponential backoff - logger.warning( - f"Rate limit hit | Provider: {self.provider_name} | Model: {self.model_id} | " - f"Attempt {attempt + 1}/{max_retries} | Waiting {wait_time}s before retry" - ) - time.sleep(wait_time) - else: - logger.error( - f"Rate limit exceeded after {max_retries} attempts | " - f"Provider: {self.provider_name} | Model: {self.model_id}" - ) - raise 
- - if response is None: - raise RuntimeError("Failed to get response after retries") - - # Record timestamp for rate limiting (one timestamp per batch item) - if self.rpm_limit is not None: - current_time = time.monotonic() - for _ in range(len(batch_to_send)): - self._request_timestamps.append(current_time) - - # Extract content from each response - results = [] - for idx, one_response in enumerate(response): - if isinstance(one_response, Exception): - if isinstance(one_response, RateLimitError): - logger.warning( - "Rate limit error in batch item | Provider: %s | Model: %s | Item: %d", - self.provider_name, - self.model_id, - idx, - ) - raise RuntimeError( - f"Batch item {idx} failed during generation: {one_response}" - ) from one_response - - if not getattr(one_response, "choices", None): - raise RuntimeError( - f"Unexpected response type from LiteLLM batch completion at item {idx}: {type(one_response).__name__}" - ) - - content = one_response.choices[0].message.content - - if response_format is not None: - # Strip code fences before validation - content = self._strip_code_fences(content) - try: - results.append( - response_format.model_validate_json(content)) - except Exception as validation_error: - # Show the content that failed to parse for debugging - content_preview = content[:200] + "..." 
if len(content) > 200 else content - logger.warning( - f"JSON parsing failed, skipping response | " - f"Model: {self.model_id} | " - f"Format: {response_format.__name__} | " - f"Content preview: {content_preview}" - ) - raise ValueError( - f"Failed to parse JSON response into {response_format.__name__}.\n" - f"Validation error: {validation_error}\n" - f"Content received (first 200 chars):\n{content_preview}" - ) from validation_error - else: - # Strip leading/trailing whitespace for text responses - results.append(content.strip() if content else content) - - # Return single result for backward compatibility - if single_input and len(results) == 1: - return results[0] - return results - - except Exception as e: - error_trace = traceback.format_exc() - logger.error( - f"Generation failed | Provider: {self.provider_name} | " - f"Model: {self.model_id} | Error: {str(e)}" - ) - raise RuntimeError( - f"Error generating batch response with {self.provider_name}:\n{error_trace}" - ) - - -class OpenAIProvider(LLMProvider): - """OpenAI provider using litellm.responses endpoint. - - Note: This provider uses the new responses endpoint which has different - parameter support compared to the standard completion endpoint: - - temperature, top_p, and frequency_penalty are not supported - - Uses text_format instead of response_format - - Supports reasoning parameter for controlling reasoning effort - - Does not support batch operations (will process sequentially with warning) - """ - - @property - def provider_name(self) -> str: - return "openai" - - @property - def env_key_name(self) -> str: - return "OPENAI_API_KEY" - - def __init__( - self, - model_id: str = "gpt-5-mini-2025-08-07", - api_key: str | None = None, - max_completion_tokens: int | None = None, - reasoning_effort: str = "low", - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - timeout: int | None = None, - ): - """Initialize the OpenAI provider. 
- - Args: - model_id: The model ID (defaults to gpt-5-mini) - api_key: API key (if None, will get from environment) - max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. - reasoning_effort: Reasoning effort level - "low", "medium", or "high" (defaults to "low") - temperature: DEPRECATED - Not supported by responses endpoint - top_p: DEPRECATED - Not supported by responses endpoint - frequency_penalty: DEPRECATED - Not supported by responses endpoint - timeout: Request timeout in seconds - """ - # Warn about deprecated parameters - if temperature is not None: - warnings.warn( - "temperature parameter is not supported by OpenAI responses endpoint and will be ignored", - UserWarning, - stacklevel=2 - ) - if top_p is not None: - warnings.warn( - "top_p parameter is not supported by OpenAI responses endpoint and will be ignored", - UserWarning, - stacklevel=2 - ) - if frequency_penalty is not None: - warnings.warn( - "frequency_penalty parameter is not supported by OpenAI responses endpoint and will be ignored", - UserWarning, - stacklevel=2 - ) - - # Store reasoning effort - self.reasoning_effort = reasoning_effort - - # Call parent init with None for unsupported params - super().__init__( - model_id=model_id, - api_key=api_key, - temperature=None, - max_completion_tokens=max_completion_tokens, - top_p=None, - frequency_penalty=None, - timeout=timeout, - ) - - def generate( - self, - prompt: str | list[str] | None = None, - messages: list[Messages] | Messages | None = None, - response_format: Type[T] | None = None, - metadata: dict[str, Any] | None = None, - ) -> str | list[str] | T | list[T]: - """ - Generate responses from the LLM using the responses endpoint. - - Note: Batch operations are processed sequentially as the responses endpoint - does not support native batching. 
- - Args: - prompt: Single text prompt (str) or list of text prompts for batch processing - messages: Single message list or list of message lists for batch processing - response_format: Optional Pydantic model class for structured output - metadata: Optional LiteLLM metadata for tracing / observability - - Returns: - Single string/model or list of strings/models depending on input type. - - Raises: - ValueError: If neither prompt nor messages is provided, or if both are provided. - RuntimeError: If there's an error during generation. - """ - # Validate inputs - if prompt is None and messages is None: - raise ValueError("Either prompts or messages must be provided") - if prompt is not None and messages is not None: - raise ValueError("Provide either prompts or messages, not both") - - # Determine if this is a single input or batch input - single_input = False - batch_prompts = None - batch_messages = None - - if prompt is not None: - if isinstance(prompt, str): - # Single prompt - convert to batch - batch_prompts = [prompt] - single_input = True - elif isinstance(prompt, list): - # Already a list of prompts - batch_prompts = prompt - single_input = False - else: - raise ValueError("prompt must be a string or list of strings") - - if messages is not None: - if isinstance(messages, list) and len(messages) > 0: - # Check if it's a single message list or batch - if isinstance(messages[0], dict): - # Single message list - convert to batch - batch_messages = [messages] - single_input = True - elif isinstance(messages[0], list): - # Already a batch of message lists - batch_messages = messages - single_input = False - else: - raise ValueError("Invalid messages format") - else: - raise ValueError("messages cannot be empty") - - try: - # Convert batch prompts to messages if needed - batch_to_send = [] - if batch_prompts is not None: - for one_prompt in batch_prompts: - batch_to_send.append([{"role": "user", "content": one_prompt}]) - else: - batch_to_send = batch_messages - 
- # Warn if batch processing is being used - if len(batch_to_send) > 1: - warnings.warn( - f"OpenAI responses endpoint does not support batch operations. " - f"Processing {len(batch_to_send)} requests sequentially.", - UserWarning, - stacklevel=2 - ) - - # Process each request sequentially - results = [] - for message_list in batch_to_send: - # Enforce rate limit per request - self._respect_rate_limit() - - # Prepare completion parameters - completion_params = { - "model": self._get_model_string(), - "input": message_list, - "reasoning": {"effort": self.reasoning_effort}, - "metadata": self._build_request_metadata(metadata), - } - - # Add max_output_tokens if specified - if self.max_completion_tokens is not None: - completion_params["max_output_tokens"] = self.max_completion_tokens - - # Add text_format if response_format is provided - if response_format is not None: - completion_params["text_format"] = response_format - - # Call LiteLLM responses endpoint - response = litellm.responses(**completion_params) - - # Record timestamp for rate limiting - if self.rpm_limit is not None: - self._request_timestamps.append(time.monotonic()) - - # Extract content from response - # Response structure: response.output[1].content[0].text - content = response.output[1].content[0].text - - if response_format is not None: - # Strip code fences before validation - content = self._strip_code_fences(content) - try: - results.append(response_format.model_validate_json(content)) - except Exception as validation_error: - # Show the content that failed to parse for debugging - content_preview = content[:200] + "..." 
if len(content) > 200 else content - logger.warning( - f"JSON parsing failed, skipping response | " - f"Model: {self.model_id} | " - f"Format: {response_format.__name__} | " - f"Content preview: {content_preview}" - ) - raise ValueError( - f"Failed to parse JSON response into {response_format.__name__}.\n" - f"Validation error: {validation_error}\n" - f"Content received (first 200 chars):\n{content_preview}" - ) from validation_error - else: - # Strip leading/trailing whitespace for text responses - results.append(content.strip() if content else content) - - # Return single result for backward compatibility - if single_input and len(results) == 1: - return results[0] - return results - - except Exception as e: - error_trace = traceback.format_exc() - logger.error( - f"Generation failed | Provider: {self.provider_name} | " - f"Model: {self.model_id} | Error: {str(e)}" - ) - raise RuntimeError( - f"Error generating response with {self.provider_name}:\n{error_trace}" - ) - - -class AnthropicProvider(LLMProvider): - """Anthropic provider using litellm.""" - - @property - def provider_name(self) -> str: - return "anthropic" - - @property - def env_key_name(self) -> str: - return "ANTHROPIC_API_KEY" - - def __init__( - self, - model_id: str = "claude-haiku-4-5-20251001", - api_key: str | None = None, - temperature: float | None = None, - max_completion_tokens: int | None = None, - timeout: int | None = None, - # top_p: float | None = None, # Not properly supported by anthropic models 4.5 - # frequency_penalty: float | None = None, # Not supported by anthropic models 4.5 - ): - """Initialize the Anthropic provider. 
- - Args: - model_id: The model ID (defaults to claude-haiku-4-5-20251001) - api_key: API key (if None, will get from environment) - temperature: Temperature for generation (0.0 to 1.0) - max_completion_tokens: Maximum tokens to generate - timeout: Request timeout in seconds - top_p: Nucleus sampling parameter (0.0 to 1.0) - """ - super().__init__( - model_id=model_id, - api_key=api_key, - temperature=temperature, - max_completion_tokens=max_completion_tokens, - timeout=timeout, - ) - - -class GeminiProvider(LLMProvider): - """Google Gemini provider using litellm.""" - - @property - def provider_name(self) -> str: - return "gemini" - - @property - def env_key_name(self) -> str: - return "GEMINI_API_KEY" - - def __init__( - self, - model_id: str = "gemini-2.0-flash", - api_key: str | None = None, - temperature: float | None = None, - max_completion_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - rpm_limit: int | None = None, - timeout: int | None = None, - ): - """Initialize the Gemini provider. - - Args: - model_id: The model ID (defaults to gemini-2.0-flash) - api_key: API key (if None, will get from environment) - temperature: Temperature for generation (0.0 to 1.0) - max_completion_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) - timeout: Request timeout in seconds - """ - super().__init__( - model_id=model_id, - api_key=api_key, - temperature=temperature, - max_completion_tokens=max_completion_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, - rpm_limit=rpm_limit, - timeout=timeout, - ) - - -class OllamaProvider(LLMProvider): - """Ollama provider using litellm. - - Note: Ollama typically doesn't require an API key as it's usually run locally. 
- """ - - @property - def provider_name(self) -> str: - return "ollama_chat" - - @property - def env_key_name(self) -> str: - return "OLLAMA_API_BASE" - - def _get_api_key(self) -> str: - """Override to handle Ollama not requiring an API key. - - Returns an empty string since Ollama typically doesn't need an API key. - OLLAMA_API_BASE can be used to set a custom base URL. - """ - return "" - - def __init__( - self, - model_id: str = "gemma3:4b", - temperature: float | None = None, - max_completion_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - api_base: str | None = None, - rpm_limit: int | None = None, - timeout: int | None = None, - ): - """Initialize the Ollama provider. - - Args: - model_id: The model ID (defaults to llama3) - temperature: Temperature for generation (0.0 to 1.0) - max_completion_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) - api_base: Base URL for Ollama API (e.g., "http://localhost:11434") - timeout: Request timeout in seconds - """ - # Set API base URL if provided - if api_base: - os.environ["OLLAMA_API_BASE"] = api_base - - super().__init__( - model_id=model_id, - api_key="", # Pass empty string since parent class requires this parameter - temperature=temperature, - max_completion_tokens=max_completion_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, - rpm_limit=rpm_limit, - timeout=timeout, - ) - - -class OpenRouterProvider(LLMProvider): - """OpenRouter provider using litellm""" - - @property - def provider_name(self) -> str: - return "openrouter" - - @property - def env_key_name(self) -> str: - return "OPENROUTER_API_KEY" - - def __init__( - self, - model_id: str = "openai/gpt-5-mini", # for default model - api_key: str | None = None, - temperature: float | None = None, - max_completion_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | 
None = None, - timeout: int | None = None, - ): - """Initialize the OpenRouter provider. - - Args: - model_id: The model ID (defaults to openai/gpt-5-mini) - api_key: API key (if None, will get from environment) - temperature: Temperature for generation (0.0 to 1.0) - max_completion_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) - timeout: Request timeout in seconds - """ - super().__init__( - model_id = model_id, - api_key = api_key, - temperature = temperature, - max_completion_tokens = max_completion_tokens, - top_p = top_p, - frequency_penalty = frequency_penalty, - timeout = timeout, - ) - - -class MistralProvider(LLMProvider): - """Mistral AI provider using litellm.""" - - @property - def provider_name(self) -> str: - return "mistral" - - @property - def env_key_name(self) -> str: - return "MISTRAL_API_KEY" - - def __init__( - self, - model_id: str = "mistral-small-latest", - api_key: str | None = None, - temperature: float | None = None, - max_completion_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - rpm_limit: int | None = None, - timeout: int | None = None, - ): - """Initialize the Mistral provider. 
- - Args: - model_id: The model ID (defaults to mistral-small-latest) - api_key: API key (if None, will get from MISTRAL_API_KEY env var) - temperature: Temperature for generation (0.0 to 1.0) - max_completion_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) - rpm_limit: Requests per minute limit for rate limiting - timeout: Request timeout in seconds - """ - super().__init__( - model_id=model_id, - api_key=api_key, - temperature=temperature, - max_completion_tokens=max_completion_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, - rpm_limit=rpm_limit, - timeout=timeout, - ) - - -def openai(model_id: str = "gpt-5-mini-2025-08-07", **kwargs) -> OpenAIProvider: - """Create an OpenAI provider instance.""" - return OpenAIProvider(model_id=model_id, **kwargs) - - -def anthropic( - model_id: str = "claude-haiku-4-5-20251001", - **kwargs, -) -> AnthropicProvider: - """Create an Anthropic provider instance.""" - return AnthropicProvider(model_id=model_id, **kwargs) - - -def gemini(model_id: str = "gemini-2.0-flash", **kwargs) -> GeminiProvider: - """Create a Gemini provider instance.""" - return GeminiProvider(model_id=model_id, **kwargs) - - -def ollama(model_id: str = "gemma3:4b", **kwargs) -> OllamaProvider: - """Create an Ollama provider instance.""" - return OllamaProvider(model_id=model_id, **kwargs) - - -def openrouter( - model_id: str = "openai/gpt-5-mini", - **kwargs, -) -> OpenRouterProvider: - """Create an OpenRouter provider instance.""" - return OpenRouterProvider(model_id=model_id, **kwargs) - - -def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider: - """Create a Mistral provider instance.""" - return MistralProvider(model_id=model_id, **kwargs) +import litellm __all__ = [ @@ -894,13 +30,18 @@ def mistral(model_id: str = "mistral-small-latest", **kwargs) -> MistralProvider "OpenAIProvider", "AnthropicProvider", 
"GeminiProvider", - "OllamaProvider", - "OpenRouterProvider", "MistralProvider", + "OpenRouterProvider", + "OllamaProvider", + "OpenAICompatibleProvider", "openai", "anthropic", "gemini", - "ollama", - "openrouter", "mistral", + "openrouter", + "ollama", + "openai_compatible", + "litellm", + "load_env_once", + "maybe_configure_langfuse_tracing", ] diff --git a/datafast/transforms/__init__.py b/datafast/transforms/__init__.py index 025ea3f..f7a88d2 100644 --- a/datafast/transforms/__init__.py +++ b/datafast/transforms/__init__.py @@ -1,7 +1,7 @@ """Transform steps for datafast v2.""" from datafast.transforms.sample import Sample -from datafast.transforms.data_ops import Map, FlatMap, Filter, Group, Pair, Concat, Join +from datafast.transforms.data_ops import AddUUID, Map, FlatMap, Filter, Group, Pair, Concat, Join from datafast.transforms.llm_step import LLMStep from datafast.transforms.llm_eval import Classify, Score, Compare from datafast.transforms.llm_transform import Rewrite @@ -9,7 +9,7 @@ from datafast.transforms.branch import Branch, JoinBranches __all__ = [ - "Sample", "Map", "FlatMap", "Filter", "Group", "Pair", "Concat", "Join", + "Sample", "AddUUID", "Map", "FlatMap", "Filter", "Group", "Pair", "Concat", "Join", "LLMStep", "Classify", "Score", "Compare", "Rewrite", "Extract", "Branch", "JoinBranches", ] diff --git a/datafast/transforms/data_ops.py b/datafast/transforms/data_ops.py index 3887460..fafb5cf 100644 --- a/datafast/transforms/data_ops.py +++ b/datafast/transforms/data_ops.py @@ -3,6 +3,7 @@ import itertools import random import re +import uuid from collections import defaultdict from collections.abc import Callable, Iterable from typing import Any @@ -62,6 +63,34 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]: yield from self._fn(record) +class AddUUID(Step): + """Add a UUID field to each record.""" + + def __init__(self, column: str = "id", overwrite: bool = False) -> None: + """ + Initialize an AddUUID step. 
+ + Args: + column: Field name to write the UUID into. + overwrite: If True, replace existing values in the target column. + + Examples: + >>> AddUUID() + >>> AddUUID(column="example_id", overwrite=True) + """ + super().__init__() + self._column = column + self._overwrite = overwrite + + def process(self, records: Iterable[Record]) -> Iterable[Record]: + """Add UUIDs while preserving all other fields.""" + for record in records: + if self._column in record and not self._overwrite: + yield record + else: + yield {**record, self._column: str(uuid.uuid4())} + + class Filter(Step): """Keep or drop records based on conditions.""" diff --git a/datafast/transforms/llm_eval.py b/datafast/transforms/llm_eval.py index b0ea320..6fb3e78 100644 --- a/datafast/transforms/llm_eval.py +++ b/datafast/transforms/llm_eval.py @@ -366,7 +366,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]: try: messages = self._build_messages(record) raw = model.generate( - messages, + messages=messages, metadata=build_trace_metadata( model=model, component="step.process", @@ -657,7 +657,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]: try: messages = self._build_messages(record) raw = model.generate( - messages, + messages=messages, metadata=build_trace_metadata( model=model, component="step.process", @@ -1011,7 +1011,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]: try: messages = self._build_messages(record) raw = model.generate( - messages, + messages=messages, metadata=build_trace_metadata( model=model, component="step.process", diff --git a/datafast/transforms/llm_extract.py b/datafast/transforms/llm_extract.py index aa8161d..9d3e095 100644 --- a/datafast/transforms/llm_extract.py +++ b/datafast/transforms/llm_extract.py @@ -418,7 +418,7 @@ def _process_llm(self, records: Iterable[Record]) -> Iterable[Record]: try: messages = self._build_messages(record) raw = model.generate( - messages, + messages=messages, 
metadata=build_trace_metadata( model=model, component="step.process", diff --git a/datafast/transforms/llm_step.py b/datafast/transforms/llm_step.py index d2aae42..ad1a8fb 100644 --- a/datafast/transforms/llm_step.py +++ b/datafast/transforms/llm_step.py @@ -384,7 +384,7 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]: messages = self._build_messages(prompt_template, context) raw_output = model.generate( - messages, + messages=messages, metadata=build_trace_metadata( model=model, component="step.process", diff --git a/datafast/transforms/llm_transform.py b/datafast/transforms/llm_transform.py index 8901a03..105ce65 100644 --- a/datafast/transforms/llm_transform.py +++ b/datafast/transforms/llm_transform.py @@ -298,7 +298,7 @@ def process(self, records: Iterable[Record]) -> Iterable[Record]: try: messages = self._build_messages(record) raw = model.generate( - messages, + messages=messages, metadata=build_trace_metadata( model=model, component="step.process", diff --git a/docs/api.md b/docs/api.md index edef161..45857e2 100644 --- a/docs/api.md +++ b/docs/api.md @@ -36,6 +36,7 @@ from datafast import Source, LLMStep, Sink, openrouter ## Data Operations - `Sample` +- `AddUUID` - `Map` - `FlatMap` - `Filter` diff --git a/docs/cookbook/assets/index.md b/docs/cookbook/assets/index.md new file mode 100644 index 0000000..65896be --- /dev/null +++ b/docs/cookbook/assets/index.md @@ -0,0 +1,80 @@ +# Cookbook Assets + +Prompt files and dataset details used by cookbook examples. 
+ +## Text Classification + +### Dataset + +- **Source:** seed dimensions created with `Seed.product` +- **Dimensions:** label, trail type, style, language, and model +- **Local output:** `examples/outputs/45_text_classification_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/45_text_classification_cookbook` +- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1` + +This cookbook models variation directly as seed dimensions so the label, trail +type, style, language, and model are all explicit in the +pipeline. + +### Prompt + +| File | Style | +| --- | --- | +| [text_classification_generation.txt](text_classification_generation.txt) | One short trail comment per call, with label, trail type, style, and language injected | + +## Persona Generation + +### Dataset + +- **Source:** `xsum` (Hugging Face), `validation` split +- **Fields used:** `id`, `document`, `summary` +- **Filter:** 300–500 words, first 100 matches +- **Local output:** `examples/outputs/43_persona_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/43_persona_cookbook` +- **Hub output:** set `HF_REPO_ID` and the `repo_id` in `push_records_to_hub()` to repos under your own Hugging Face username or organization + +The example keeps first-match sampling for reproducibility. For local JSONL corpora with metadata such as `document_filename`, stratified sampling is usually a better fit. + +### Prompt Variants + +Each LLM step picks one prompt at random per record. The script also assigns random `life_stage` and `related_life_stage` values before the corresponding LLM steps. Multiple variants add diversity. 
+ +#### Text-to-Persona + +| File | Style | +| --- | --- | +| [text_to_persona_v1.txt](text_to_persona_v1.txt) | Direct inference of a reader persona | +| [text_to_persona_v2.txt](text_to_persona_v2.txt) | XML-tagged source text, writer/reader framing | +| [text_to_persona_v3.txt](text_to_persona_v3.txt) | System-role preamble, search-interest angle | + +#### Persona-to-Persona + +| File | Style | +| --- | --- | +| [persona_to_persona_v1.txt](persona_to_persona_v1.txt) | Close relationship, standalone description | +| [persona_to_persona_v2.txt](persona_to_persona_v2.txt) | Rule-list format, explicit separation of description and relationship | +| [persona_to_persona_v3.txt](persona_to_persona_v3.txt) | XML-tagged input, concise vivid output | + +### Provenance + +- Text-to-Persona and Persona-to-Persona prompts are paper-aligned adaptations. The Persona Hub paper states its published prompts are simplified, not exact. +- No Persona Hub code is reused. The workflow is built with datafast primitives. + +## Space Engineering Text Generation + +### Dataset + +- **Source:** seed dimensions created with `Seed.product` +- **Dimensions:** document type, topic, expertise level, and language +- **Local output:** `examples/outputs/44_space_text_generation_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/44_space_text_generation_cookbook` +- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1` + +### Prompt + +The text-generation cookbook uses one compact prompt and relies on seed +dimensions for variation. 
+ +| File | Style | +| --- | --- | +| [space_text_generation.txt](space_text_generation.txt) | Minimal variable-driven request | diff --git a/docs/cookbook/assets/persona_to_persona_v1.txt b/docs/cookbook/assets/persona_to_persona_v1.txt new file mode 100644 index 0000000..eabb6d6 --- /dev/null +++ b/docs/cookbook/assets/persona_to_persona_v1.txt @@ -0,0 +1,11 @@ +Given the following persona, infer one other specific persona who is in a close relationship with them. + +Persona: +{persona_description} + +Requirements: +1. Use one clear relationship such as family member, colleague, friend, neighbor, coach, teacher, or married partner. +2. Choose a related persona that adds a meaningfully different life perspective but is still likely to be in close contact with the original persona. +3. Keep the related persona realistic and specific. +4. Don't talk about the original person in the description of the related persona, as it should be a self-contained description. +5. The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally. diff --git a/docs/cookbook/assets/persona_to_persona_v2.txt b/docs/cookbook/assets/persona_to_persona_v2.txt new file mode 100644 index 0000000..b4e4adf --- /dev/null +++ b/docs/cookbook/assets/persona_to_persona_v2.txt @@ -0,0 +1,14 @@ +Think of a person who regularly interacts with the following persona in a meaningful way. + +Rules: +- Do not mention the original persona in the description of the related persona. +- Do not mention the relationship between the two personas in the description, only in the relationship_type +- Pick a single, concrete relationship type such as mentor-mentee, colleague, neighbor, supervisor-report, or service provider-client +- The related person should bring a distinctly different viewpoint or expertise, and some uniqueness. +- Keep the description realistic and standalone without mentioning the original persona.
+- The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally. + +Original Persona: +{persona_description} + +Now generate a related persona. \ No newline at end of file diff --git a/docs/cookbook/assets/persona_to_persona_v3.txt b/docs/cookbook/assets/persona_to_persona_v3.txt new file mode 100644 index 0000000..9652161 --- /dev/null +++ b/docs/cookbook/assets/persona_to_persona_v3.txt @@ -0,0 +1,16 @@ +Here is the description of someone: + +{persona_description} + + +Come up with one other description of an individual who could be part of this persona's life. +We want the description to be detailed but super concise (max 2 sentences) and vivid. +But we want to have a standalone description of that new persona without mentioning the original persona or a reason in the description. + +Requirements: +1. Define a clear interpersonal link such as friend, advisor, competitor, family member, or collaborator. +2. The new persona should offer a complementary or contrasting perspective. +3. Make the related persona vivid and believable, avoid generic archetypes. +4. Describe the relation in the relationship_type field, not in the description. +5. The related persona must be {related_life_stage}. Do not state a precise age, just reflect this life stage naturally. + diff --git a/docs/cookbook/assets/space_text_generation.txt b/docs/cookbook/assets/space_text_generation.txt new file mode 100644 index 0000000..ca5af4b --- /dev/null +++ b/docs/cookbook/assets/space_text_generation.txt @@ -0,0 +1 @@ +Write one {document_type} excerpt about {topic} for {expertise_level} in {language_name}.
diff --git a/docs/cookbook/assets/text_classification_generation.txt b/docs/cookbook/assets/text_classification_generation.txt new file mode 100644 index 0000000..85dc0f1 --- /dev/null +++ b/docs/cookbook/assets/text_classification_generation.txt @@ -0,0 +1,16 @@ +Write one realistic trail comment in {language_name} that sounds like something +an actual hiker would write after being on the trail. + +Target category: {label} +Category definition: {label_description} + +Constraints: +- The comment must clearly match the target category. +- The setting must be a {trail_type}. +- The writing style must be {style}. +- Keep it to 1 or 2 sentences. +- Make it sound first-hand, natural, and slightly informal when appropriate. +- Do not sound like an official report, safety bulletin, or structured form. +- Do not mention the category name directly. +- Do not use bullets, numbering, or explanations. +- Make the comment concrete and varied. diff --git a/docs/cookbook/assets/text_to_persona_v1.txt b/docs/cookbook/assets/text_to_persona_v1.txt new file mode 100644 index 0000000..cd09909 --- /dev/null +++ b/docs/cookbook/assets/text_to_persona_v1.txt @@ -0,0 +1,17 @@ +Infer one specific persona who is likely to read text. + +Source text: +{document} + +Requirements: +1. Return a single persona, not a group. +2. Make the persona specific and fine-grained rather than generic. +3. Ground the persona in signals from the text such as domain, expertise, context, or likely motivation. +4. Do not quote the source text in the persona field. +5. Only write 1 or 2 sentences maximum. +6. The persona is not the subject of the text, but rather someone who would be reading it. +7. Do not refer to the source text, article, or its content in the persona description. The persona must be self-contained. +8. The persona must be {life_stage}. Do not mention a precise age, just reflect this life stage naturally. + +Now figure out a persona description who would be reading this text. 
+ diff --git a/docs/cookbook/assets/text_to_persona_v2.txt b/docs/cookbook/assets/text_to_persona_v2.txt new file mode 100644 index 0000000..294577d --- /dev/null +++ b/docs/cookbook/assets/text_to_persona_v2.txt @@ -0,0 +1,16 @@ +<source_text> +{document} +</source_text> + +Identify one precise individual who would naturally encounter or write the <source_text>. + +Requirements: +1. Describe exactly one person. +2. Be as specific as possible: mention plausible occupation and/or life situation. +3. Derive the persona strictly from cues in the text such as topic, jargon, tone, or implied audience as a potential writer / reader of this text. +4. Do not copy or paraphrase the source text in the persona field. +5. Only return 1 or 2 sentences maximum. +6. The described person is not the subject of the text, but rather someone who would be encountering or writing such text as part of their life. +7. Do not reference the source text, article, or its content in the persona description. The persona must stand on its own. +8. The persona must be {life_stage}. Do not state a precise age, just reflect this life stage naturally. + diff --git a/docs/cookbook/assets/text_to_persona_v3.txt b/docs/cookbook/assets/text_to_persona_v3.txt new file mode 100644 index 0000000..3ccb077 --- /dev/null +++ b/docs/cookbook/assets/text_to_persona_v3.txt @@ -0,0 +1,17 @@ +You are a persona inference assistant. + +Based on the text content below, imagine one real person who would be interested in searching about the topic from this content. + +Rules: +- Output a single, concrete persona rather than a broad demographic. +- Include details like professional background, interests, or situational context that make the persona feel authentic. +- Don't mention the person search or information retrieval action in the persona description, just describe the persona which could explain their interest in the topic. +- Keep it super short and concise.
+- Do not mention or refer to the source text, article, or its content in the persona description. The persona must be self-contained. +- The persona must be {life_stage}. Do not state a precise age, just reflect this life stage naturally. + +Source text: +{document} + + + diff --git a/docs/cookbook/index.md b/docs/cookbook/index.md new file mode 100644 index 0000000..1b745ec --- /dev/null +++ b/docs/cookbook/index.md @@ -0,0 +1,16 @@ +# Cookbook + +Cookbooks connect a runnable script to a documentation walkthrough. + +The Python script is the source of truth. Each cookbook page explains: + +- where the executable example lives +- what inputs it uses +- which prompt assets it depends on +- where it writes its output artifacts + +## Available Cookbooks + +- [Text Classification](text_classification.md): generate a multilingual trail-conditions classification dataset from explicit seed dimensions. +- [Persona Generation](persona_generation.md): infer personas from real articles and expand them through relationships using randomized prompt variants. +- [Space Engineering Text Generation](space_text_generation.md): generate a raw multilingual technical text corpus from seed dimensions. diff --git a/docs/cookbook/persona_generation.md b/docs/cookbook/persona_generation.md new file mode 100644 index 0000000..f314a39 --- /dev/null +++ b/docs/cookbook/persona_generation.md @@ -0,0 +1,89 @@ +# Persona Generation + +Build personas from real articles and expand them through relationships. Inspired by the Persona Hub paper, implemented entirely with datafast. + +## Source + +- **Script:** `examples/scripts/43_cookbook_persona_generation.py` +- **Prompt assets:** [asset index](assets/index.md) +- **Local output:** `examples/outputs/43_persona_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/43_persona_cookbook` +- **Hub output:** pushed to the Hugging Face Hub repo IDs configured in the script + +## Pipeline + +1. 
Load `xsum` articles (`validation` split), preserving the dataset `id`. +2. Filter to documents between 300 and 500 words. Keep the first 100 matches. +3. Assign a random life stage to the source persona. +4. **Text-to-Persona** — infer one persona from each article and life stage. +5. Assign a random life stage to the related persona. +6. **Persona-to-Persona** — expand that persona into a related individual. +7. Keep the final output fields, add a row UUID, write JSONL, checkpoint progress, and push results to Hugging Face Hub. + +Each LLM step randomly picks one prompt variant per record using `Sample(prompts, n=1)`. This adds diversity across generations. + +The cookbook keeps `Sample(n=100, strategy="first")` so runs are deterministic and easy to compare. For local corpora with source metadata, use stratified sampling, for example `Sample(n=250, strategy="stratified", by="document_filename")`, to avoid over-representing one source file. + +```text +xsum article + │ + ▼ +life_stage (random from configured stages) + │ + ▼ +Text-to-Persona (random prompt from 3 variants) + │ + ▼ +related_life_stage (random from configured stages) + │ + ▼ +Persona-to-Persona (random prompt from 3 variants) + │ + ▼ +Hugging Face Hub +``` + +## Run + +Prerequisites: + +- `OPENROUTER_API_KEY` set in a `.env` file +- Hugging Face authentication via `HF_TOKEN` in `.env` or a cached `huggingface_hub` login +- Base dependencies from `pyproject.toml` installed + +Before running, replace the example Hugging Face namespaces in the script with your own username or organization: + +- `HF_REPO_ID = "your-namespace/new-persona-cookbook-dataset"` controls the private pipeline sink. +- `repo_id = "your-namespace/datafast-persona-cookbook"` inside `push_records_to_hub()` controls the public publish step. + +```bash +python examples/scripts/43_cookbook_persona_generation.py +``` + +The run uses `checkpoint_dir` and `resume=True`, which is useful for paid or rate-limited LLM calls.
If a run is interrupted, re-run the same command to continue from the saved checkpoints. + +The main example reads from Hugging Face. For a local JSONL corpus, replace `Source.huggingface(...)` with `Source.file(...)` and map your text column to `document` before `add_word_count`. + +## Prompt Variants + +Each step draws from multiple prompt files stored under `docs/cookbook/assets/`. See the [asset index](assets/index.md) for the full list. + +- **Text-to-Persona:** 3 variants (`text_to_persona_v1.txt`, `v2`, `v3`) +- **Persona-to-Persona:** 3 variants (`persona_to_persona_v1.txt`, `v2`, `v3`) + +## Research Basis + +The Persona Hub paper introduces Text-to-Persona and Persona-to-Persona as scalable methods for building personas from web text. The paper states that its published prompts are simplified, not the exact experiment strings. This cookbook treats them as paper-aligned adaptations. It does not reuse any Persona Hub code. + +## Output Fields + +- `id` — generated row UUID +- `source_id` — original XSum record identifier +- `summary` — original article summary +- `document` — source article text +- `word_count` — whitespace token count +- `life_stage` — randomly selected life stage for the inferred persona +- `persona_description` — inferred persona +- `relationship_type` — link between the two personas +- `related_life_stage` — randomly selected life stage for the expanded persona +- `related_persona_description` — the expanded related persona diff --git a/docs/cookbook/space_text_generation.md b/docs/cookbook/space_text_generation.md new file mode 100644 index 0000000..92c55dc --- /dev/null +++ b/docs/cookbook/space_text_generation.md @@ -0,0 +1,103 @@ +# Space Engineering Text Generation + +Build a raw technical text corpus across document types, topics, expertise levels, +languages, and model choices. 
+ +## Source + +- **Script:** `examples/scripts/44_cookbook_space_text_generation.py` +- **Prompt assets:** [asset index](assets/index.md) +- **Local output:** `examples/outputs/44_space_text_generation_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/44_space_text_generation_cookbook` +- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1` + +## Pipeline + +1. Create a seed grid with `Seed.product`. +2. Cross document types, topics, and expertise levels explicitly. +3. Generate one section per seed and language with `LLMStep`. +4. Let the prompt variables drive the corpus variation. +5. Parse `title` and `text` from JSON mode. +6. Keep publication fields, add a row UUID, write JSONL, checkpoint progress, + and optionally push to Hugging Face Hub. + +The default model is `nvidia/nemotron-3-super-120b-a12b:nitro` through +OpenRouter. + +```text +document_type x topic x expertise_level + | + v +LLMStep language expansion: English and French + | + v +JSON fields: title, text + | + v +examples/outputs/44_space_text_generation_cookbook.jsonl +``` + +## Row Count + +The default script generates: + +```text +3 document types x 8 topics x 3 expertise levels x 2 languages +x 1 generated output x 1 model = 144 rows +``` + +To use several models, add provider IDs to `MODEL_IDS`. `LLMStep` will run each +seed-language combination through every model and the row count will multiply by +the number of models. + +## Run + +Prerequisites: + +- `OPENROUTER_API_KEY` set in a `.env` file +- Base dependencies from `pyproject.toml` installed +- Hugging Face authentication only if publishing + +```bash +python examples/scripts/44_cookbook_space_text_generation.py +``` + +To publish, replace `HF_REPO_ID` in the script with a repository under your own +Hugging Face username or organization, then run: + +```bash +DATAFAST_PUSH_TO_HUB=1 python examples/scripts/44_cookbook_space_text_generation.py +``` + +The run uses `checkpoint_dir` and `resume=True`. 
If generation is interrupted, +run the command again to continue from saved checkpoints. + +## Prompt + +The script uses one compact prompt file: + +```text +Write one {document_type} excerpt about {topic} for {expertise_level} in {language_name}. +``` + +## Generation Controls + +- `MODEL_IDS` controls which models generate each record. +- `LANGUAGES` controls language expansion and writes the emitted language code to + the `language` field. +- `NUM_OUTPUTS` controls how many generated rows are created for each + seed, language, and model combination. +- `PROMPT_PATH` controls the prompt file used for generation. +- `SEED` controls deterministic dataset splitting when publishing. +- `HF_REPO_ID` controls the optional Hugging Face Hub destination. + +## Output Fields + +- `id` - generated row UUID +- `document_type` - requested document style +- `topic` - space engineering topic +- `expertise_level` - intended reader level +- `language` - language code emitted by `LLMStep` +- `model` - model ID emitted by `LLMStep` +- `title` - generated section title +- `text` - generated corpus text diff --git a/docs/cookbook/text_classification.md b/docs/cookbook/text_classification.md new file mode 100644 index 0000000..dd36422 --- /dev/null +++ b/docs/cookbook/text_classification.md @@ -0,0 +1,118 @@ +# Text Classification + +Build a multilingual trail-conditions classification dataset with `datafast`. + +## Source + +- **Script:** `examples/scripts/45_cookbook_text_classification.py` +- **Prompt assets:** [asset index](assets/index.md) +- **Local output:** `examples/outputs/45_text_classification_cookbook.jsonl` +- **Checkpoints:** `examples/checkpoints/45_text_classification_cookbook` +- **Hub output:** optional, controlled by `DATAFAST_PUSH_TO_HUB=1` + +## Use Case + +This cookbook generates short trail comments across four trail-condition labels +so teams can monitor trail quality and surface issues quickly. 
+ +The default setup is: + +- multi-class: 4 trail-condition labels +- multi-lingual: English and French +- multi-model: two generation models by default +- publishable: optional push to Hugging Face Hub + +## Pipeline + +1. Create a seed grid from labels, trail types, and writing styles. +2. Generate one short trail comment for each seed across all configured models + and languages. +3. Keep the label and prompt-variation provenance in flat output columns. +4. Add a UUID, write JSONL locally, and optionally push to Hugging Face Hub. + +Variation is modeled explicitly through `Seed.product(...)`, which keeps the +generation axes inspectable and easy to count. + +```text +label x trail_type x style + | + v +LLMStep language expansion: English and French + | + v +LLMStep model expansion + | + v +examples/outputs/45_text_classification_cookbook.jsonl +``` + +## Row Count + +The default script generates: + +```text +4 labels x 3 trail types x 2 styles x 2 languages +x 2 models = 96 rows +``` + +Each extra model in `MODEL_IDS` multiplies the total row count. + +## Run + +Prerequisites: + +- `OPENROUTER_API_KEY` set in a `.env` file +- Base dependencies from `pyproject.toml` installed +- Hugging Face authentication only if publishing + +```bash +python examples/scripts/45_cookbook_text_classification.py +``` + +To publish, replace `HF_REPO_ID` in the script with a repository under your own +Hugging Face username or organization, then run: + +```bash +DATAFAST_PUSH_TO_HUB=1 python examples/scripts/45_cookbook_text_classification.py +``` + +The run uses `checkpoint_dir` and `resume=True`. If generation is interrupted, +run the command again to continue from saved checkpoints. + +If you want to use provider-specific clients directly, replace `MODEL_IDS` or +the `model=MODELS` argument in `LLMStep` with providers such as `openai(...)` +or `anthropic(...)`. The default setup uses multiple OpenRouter-backed models +so it works with one API key. 
+ +## Prompt + +The cookbook uses one prompt file and drives diversity through seed dimensions: + +```text +Write one realistic trail comment in {language_name}. +``` + +See [text_classification_generation.txt](assets/text_classification_generation.txt) +for the full prompt. + +## Generation Controls + +- `LABELS` defines the target classes and their prompt descriptions. +- `TRAIL_TYPES` controls the trail settings used in generation. +- `STYLES` controls the voice and format of each comment. +- `LANGUAGES` controls language expansion. +- `MODEL_IDS` controls which models generate records. +- `HF_REPO_ID` controls the optional Hugging Face Hub destination. + +If you want an extra quality-control pass, add a downstream `Classify` and +`Filter` stage to verify that generated comments match their intended label. + +## Output Fields + +- `id` - generated row UUID +- `label` - target trail-condition label +- `trail_type` - prompt expansion axis for the trail setting +- `style` - prompt expansion axis for the comment style +- `language` - language code emitted by `LLMStep` +- `model` - model ID emitted by `LLMStep` +- `text` - generated trail comment diff --git a/docs/guides/building_pipelines.md b/docs/guides/building_pipelines.md index 64aaaf2..b755410 100644 --- a/docs/guides/building_pipelines.md +++ b/docs/guides/building_pipelines.md @@ -3,11 +3,12 @@ ## Minimal Pipeline ```python -from datafast import Map, Sink, Source +from datafast import AddUUID, Map, Sink, Source pipeline = ( Source.list([{"text": "hello"}]) >> Map(lambda r: {**r, "length": len(r["text"])}) + >> AddUUID() >> Sink.list() ) @@ -38,6 +39,7 @@ seed = Seed.product( ## Core Data Operations +- `AddUUID`: add a UUID field to each record - `Map`: one record in, one record out - `FlatMap`: one record in, many records out - `Filter`: keep or drop records diff --git a/examples/providers/README.md b/examples/providers/README.md new file mode 100644 index 0000000..5c9e4c7 --- /dev/null +++ 
b/examples/providers/README.md @@ -0,0 +1,8 @@ +# Provider Examples + +This folder contains direct, provider-focused examples. + +- `openrouter/`: simple OpenRouter calls with `model.generate(...)` + +These scripts are intentionally separate from `examples/scripts/`, which focuses on +pipeline usage. diff --git a/examples/providers/openrouter/01_simple_prompt.py b/examples/providers/openrouter/01_simple_prompt.py new file mode 100644 index 0000000..fdfb279 --- /dev/null +++ b/examples/providers/openrouter/01_simple_prompt.py @@ -0,0 +1,22 @@ +"""Minimal OpenRouter example with a single prompt.""" + +from dotenv import load_dotenv + +from datafast import openrouter +from datafast.llm_utils import format_generated_responses + + +MODEL_ID = "openai/gpt-5.4-mini" +PROMPT = "Write one sentence explaining what OpenRouter is." + + +def main() -> None: + load_dotenv() + + model = openrouter(MODEL_ID, temperature=0) + response = model.generate(prompt=PROMPT) + print(format_generated_responses(PROMPT, response)) + + +if __name__ == "__main__": + main() diff --git a/examples/providers/openrouter/02_batch_prompts.py b/examples/providers/openrouter/02_batch_prompts.py new file mode 100644 index 0000000..765b219 --- /dev/null +++ b/examples/providers/openrouter/02_batch_prompts.py @@ -0,0 +1,26 @@ +"""Minimal OpenRouter example with a batch of prompts.""" + +from dotenv import load_dotenv + +from datafast import openrouter +from datafast.llm_utils import format_generated_responses + + +MODEL_ID = "openai/gpt-5.4-mini" +PROMPTS = [ + "Give a one-sentence definition of synthetic data.", + "Give a one-sentence definition of retrieval-augmented generation.", + "Give a one-sentence definition of tool calling.", +] + + +def main() -> None: + load_dotenv() + + model = openrouter(MODEL_ID, temperature=0) + responses = model.generate(prompt=PROMPTS) + print(format_generated_responses(PROMPTS, responses)) + + +if __name__ == "__main__": + main() diff --git 
a/examples/providers/openrouter/03_messages_with_system_prompt.py b/examples/providers/openrouter/03_messages_with_system_prompt.py new file mode 100644 index 0000000..546e63f --- /dev/null +++ b/examples/providers/openrouter/03_messages_with_system_prompt.py @@ -0,0 +1,31 @@ +"""OpenRouter example using explicit chat messages.""" + +from dotenv import load_dotenv + +from datafast import openrouter +from datafast.llm_utils import format_generated_responses + + +MODEL_ID = "openai/gpt-5.4-mini" +MESSAGES = [ + { + "role": "system", + "content": "You are a concise technical assistant. Answer in exactly two bullets.", + }, + { + "role": "user", + "content": "Explain why teams use an LLM router.", + }, +] + + +def main() -> None: + load_dotenv() + + model = openrouter(MODEL_ID, temperature=0) + response = model.generate(messages=MESSAGES) + print(format_generated_responses(MESSAGES[-1]["content"], response)) + + +if __name__ == "__main__": + main() diff --git a/examples/providers/openrouter/04_structured_output.py b/examples/providers/openrouter/04_structured_output.py new file mode 100644 index 0000000..1c2bd61 --- /dev/null +++ b/examples/providers/openrouter/04_structured_output.py @@ -0,0 +1,28 @@ +"""OpenRouter example with structured output validation.""" + +from dotenv import load_dotenv +from pydantic import BaseModel + +from datafast import openrouter + + +MODEL_ID = "openai/gpt-5.4-mini" +PROMPT = "Return a JSON object describing OpenRouter in two short sentences." 
+ + +class ProviderSummary(BaseModel): + name: str + summary: str + best_for: str + + +def main() -> None: + load_dotenv() + + model = openrouter(MODEL_ID, temperature=0) + response = model.generate(prompt=PROMPT, response_format=ProviderSummary) + print(response.model_dump_json(indent=2)) + + +if __name__ == "__main__": + main() diff --git a/examples/providers/openrouter/05_batch_messages.py b/examples/providers/openrouter/05_batch_messages.py new file mode 100644 index 0000000..25b42d7 --- /dev/null +++ b/examples/providers/openrouter/05_batch_messages.py @@ -0,0 +1,44 @@ +"""OpenRouter example with a batch of message lists.""" + +from dotenv import load_dotenv + +from datafast import openrouter +from datafast.llm_utils import format_generated_responses + + +MODEL_ID = "openai/gpt-5.4-mini" +BATCH_MESSAGES = [ + [ + { + "role": "system", + "content": "You answer for engineers in one sentence.", + }, + { + "role": "user", + "content": "What is prompt caching?", + }, + ], + [ + { + "role": "system", + "content": "You answer for engineers in one sentence.", + }, + { + "role": "user", + "content": "What is structured output?", + }, + ], +] + + +def main() -> None: + load_dotenv() + + model = openrouter(MODEL_ID, temperature=0) + responses = model.generate(messages=BATCH_MESSAGES) + prompts = [messages[-1]["content"] for messages in BATCH_MESSAGES] + print(format_generated_responses(prompts, responses)) + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/43_cookbook_persona_generation.py b/examples/scripts/43_cookbook_persona_generation.py new file mode 100644 index 0000000..ac4f718 --- /dev/null +++ b/examples/scripts/43_cookbook_persona_generation.py @@ -0,0 +1,137 @@ +"""Persona-generation cookbook: XSum article -> personas -> related personas. + +Demonstrates: Source.huggingface, Map, Filter, Sample, JSON-mode LLMSteps, +and prompt assets stored under docs/cookbook/assets. 
+ +Requires: +- OPENROUTER_API_KEY +- Hugging Face authentication via HF_TOKEN or a cached `huggingface_hub` login +- network access to Hugging Face and OpenRouter +""" + +import random + +from dotenv import load_dotenv + +from datafast import AddUUID, Filter, LLMStep, Map, Sample, Sink, Source, openrouter + +import litellm + +load_dotenv() + +litellm.suppress_debug_info = True + + +MODEL_ID = "nvidia/nemotron-3-super-120b-a12b:nitro" +OUTPUT_PATH = "examples/outputs/43_persona_cookbook.jsonl" +CHECKPOINT_DIR = "examples/checkpoints/43_persona_cookbook" +HF_REPO_ID = "patrickfleith/new-persona-cookbook-dataset" +TEXT_TO_PERSONA_PROMPTS = [ + "docs/cookbook/assets/text_to_persona_v1.txt", + "docs/cookbook/assets/text_to_persona_v2.txt", + "docs/cookbook/assets/text_to_persona_v3.txt", +] +PERSONA_TO_PERSONA_PROMPTS = [ + "docs/cookbook/assets/persona_to_persona_v1.txt", + "docs/cookbook/assets/persona_to_persona_v2.txt", + "docs/cookbook/assets/persona_to_persona_v3.txt", +] +LIFE_STAGES = [ + "a teenager", + "a young adult", + "an adult (30s/40s)", + "a middle-aged person (in their 50s/60s)", + "a senior person (in their 70s/80s)", +] + + +def add_word_count(record: dict) -> dict: + return {**record, "word_count": len(record["document"].split())} + + +def assign_life_stage(record: dict) -> dict: + return {**record, "life_stage": random.choice(LIFE_STAGES)} + + +def assign_related_life_stage(record: dict) -> dict: + return {**record, "related_life_stage": random.choice(LIFE_STAGES)} + + +def keep_output_fields(record: dict) -> dict: + return { + "source_id": record["id"], + "summary": record["summary"], + "document": record["document"], + "word_count": record["word_count"], + "life_stage": record["life_stage"], + "persona_description": record["persona_description"], + "relationship_type": record["relationship_type"], + "related_life_stage": record["related_life_stage"], + "related_persona_description": record["related_persona_description"], + } + + +def 
build_pipeline(): + model = openrouter(MODEL_ID, temperature=0.7) + + return ( + Source.huggingface( + "xsum", + split="validation", + columns=["id", "document", "summary"], + ) + # For a local JSONL corpus, replace the Hugging Face source with something + # like Source.file("data/articles.jsonl") and map your text field to + # "document" before add_word_count. + >> Map(add_word_count).as_step("add_word_count") + >> Filter(fn=lambda r: 300 <= r["word_count"] <= 500).as_step("filter_word_count") + >> Sample(n=10, strategy="first").as_step("take_first_10") + >> Map(assign_life_stage).as_step("assign_life_stage") + >> LLMStep( + prompt=Sample(TEXT_TO_PERSONA_PROMPTS, n=1), + input_columns=["document", "life_stage"], + output_columns=["persona_description"], + model=model, + parse_mode="json", + on_parse_error="raise", + ).as_step("text_to_persona") + >> Map(assign_related_life_stage).as_step("assign_related_life_stage") + >> LLMStep( + prompt=Sample(PERSONA_TO_PERSONA_PROMPTS, n=1), + input_columns=["persona_description", "related_life_stage"], + output_columns=["relationship_type", "related_persona_description"], + model=model, + parse_mode="json", + on_parse_error="raise", + ).as_step("persona_to_persona") + >> Map(keep_output_fields).as_step("keep_output_fields") + >> AddUUID(column="id", overwrite=True).as_step("add_uuid") + >> Sink.jsonl(OUTPUT_PATH) + >> Sink.hub(HF_REPO_ID, private=True) +) + + +def push_records_to_hub(records: list[dict]) -> None: + repo_id = "patrickfleith/datafast-persona-cookbook" + private = False + + list( + Sink.hub( + repo_id=repo_id, + private=private, + commit_message=f"Publish cookbook 43 persona dataset with {MODEL_ID}", + ).process(records) + ) + + +def main() -> None: + records = build_pipeline().run( + batch_size=1, + checkpoint_dir=CHECKPOINT_DIR, + resume=False, + ) + push_records_to_hub(records) + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/44_cookbook_space_text_generation.py
b/examples/scripts/44_cookbook_space_text_generation.py new file mode 100644 index 0000000..6c5d2cb --- /dev/null +++ b/examples/scripts/44_cookbook_space_text_generation.py @@ -0,0 +1,143 @@ +"""Space text-generation cookbook: seed grid -> technical text corpus. + +Demonstrates: Seed.product, LLMStep JSON mode, multi-language generation, +num_outputs, checkpointing, JSONL output, and optional Hub push. + +Requires: +- OPENROUTER_API_KEY +- Hugging Face authentication only if DATAFAST_PUSH_TO_HUB=1 +- network access to OpenRouter, and to Hugging Face when publishing +""" + +from __future__ import annotations + +import os + +import litellm +from dotenv import load_dotenv + +from datafast import AddUUID, LLMStep, Map, Seed, Sink, openrouter + +load_dotenv() +litellm.suppress_debug_info = True + + +SEED = 20250304 +MODEL_IDS = ["nvidia/nemotron-3-super-120b-a12b:nitro"] +OUTPUT_PATH = "examples/outputs/44_space_text_generation_cookbook.jsonl" +CHECKPOINT_DIR = "examples/checkpoints/44_space_text_generation_cookbook" +HF_REPO_ID = "patrickfleith/datafast-space-text-generation-cookbook" +NUM_OUTPUTS = 1 +PROMPT_PATH = "docs/cookbook/assets/space_text_generation.txt" + +DOCUMENT_TYPES = [ + "space engineering textbook", + "spacecraft design justification document", + "personal blog of a space engineer", +] + +TOPICS = [ + "Microgravity", + "Vacuum", + "Heavy Ions", + "Thermal Extremes", + "Atomic Oxygen", + "Debris Impact", + "Electrostatic Charging", + "Propellant Boil-off", +] + +EXPERTISE_LEVELS = [ + "executives", + "senior engineers", + "PhD candidates", +] + +LANGUAGES = { + "en": "English", + "fr": "French", +} + + +def make_models(): + return [openrouter(model_id, temperature=0.7) for model_id in MODEL_IDS] + + +def expected_row_count(model_count: int | None = None) -> int: + """Return the number of rows this configuration should generate.""" + model_total = len(MODEL_IDS) if model_count is None else model_count + return ( + len(DOCUMENT_TYPES) + * len(TOPICS) + 
* len(EXPERTISE_LEVELS) + * len(LANGUAGES) + * NUM_OUTPUTS + * model_total + ) + + +def finalize_record(record: dict) -> dict: + """Keep the columns meant for publication.""" + return { + "document_type": record["document_type"], + "topic": record["topic"], + "expertise_level": record["expertise_level"], + "language": record.get("_language", ""), + "model": record.get("_model", ""), + "title": record["title"], + "text": record["text"], + } + + +def build_pipeline(): + return ( + Seed.product( + Seed.values("document_type", DOCUMENT_TYPES), + Seed.values("topic", TOPICS), + Seed.values("expertise_level", EXPERTISE_LEVELS), + ).as_step("seed_space_text_grid") + >> LLMStep( + prompt=PROMPT_PATH, + input_columns=["document_type", "topic", "expertise_level"], + output_columns=["title", "text"], + parse_mode="json", + model=make_models(), + language=LANGUAGES, + num_outputs=NUM_OUTPUTS, + on_parse_error="raise", + ).as_step("generate_space_text") + >> Map(finalize_record).as_step("finalize_record") + >> AddUUID(column="id", overwrite=True).as_step("add_uuid") + >> Sink.jsonl(OUTPUT_PATH) + ) + + +def push_records_to_hub(records: list[dict]) -> None: + list( + Sink.hub( + repo_id=HF_REPO_ID, + private=True, + train_size=0.8, + seed=SEED, + shuffle=True, + commit_message=f"Publish cookbook 44 text dataset with {', '.join(MODEL_IDS)}", + ).process(records) + ) + + +def main() -> None: + print(f"Expected rows: {expected_row_count()}") + records = build_pipeline().run( + batch_size=4, + checkpoint_dir=CHECKPOINT_DIR, + resume=True, + ) + + if os.getenv("DATAFAST_PUSH_TO_HUB") == "1": + push_records_to_hub(records) + + print(f"Wrote {len(records)} records to {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/45_cookbook_text_classification.py b/examples/scripts/45_cookbook_text_classification.py new file mode 100644 index 0000000..9dce7af --- /dev/null +++ b/examples/scripts/45_cookbook_text_classification.py @@ -0,0 +1,158 @@ 
+"""Text-classification cookbook: seed grid -> multilingual trail comments. + +Demonstrates: Seed.product, prompt expansion via seed dimensions, multi-model +generation, multi-language generation, checkpointing, JSONL output, and +optional Hugging Face Hub publishing. + +Requires: +- OPENROUTER_API_KEY +- Hugging Face authentication only if DATAFAST_PUSH_TO_HUB=1 +- network access to OpenRouter, and to Hugging Face when publishing +""" + +from __future__ import annotations + +import os + +import litellm +from dotenv import load_dotenv + +from datafast import AddUUID, LLMStep, Map, Seed, SeedDimension, Sink, openrouter + +load_dotenv() +litellm.suppress_debug_info = True + + +SEED = 20250611 +MODEL_IDS = [ + "nvidia/nemotron-3-super-120b-a12b:nitro", + "mistralai/ministral-14b-2512", +] +OUTPUT_PATH = "examples/outputs/45_text_classification_cookbook.jsonl" +CHECKPOINT_DIR = "examples/checkpoints/45_text_classification_cookbook" +HF_REPO_ID = "patrickfleith/datafast-text-classification-cookbook" +PROMPT_PATH = "docs/cookbook/assets/text_classification_generation.txt" + +LABELS = [ + { + "label": "trail_obstruction", + "label_description": ( + "The trail is partially or fully blocked by obstacles such as " + "fallen trees, landslides, snow, flooding, erosion, or dense " + "vegetation." + ), + }, + { + "label": "infrastructure_issues", + "label_description": ( + "The report is about damaged or missing bridges, signs, stairs, " + "handrails, markers, boardwalks, or similar trail infrastructure." + ), + }, + { + "label": "hazards", + "label_description": ( + "The trail has immediate safety risks such as slippery surfaces, " + "dangerous crossings, unstable terrain, wildlife threats, or " + "other hazardous conditions." + ), + }, + { + "label": "positive_conditions", + "label_description": ( + "The report highlights clear, safe, enjoyable trail conditions " + "such as good maintenance, solid infrastructure, clear signage, " + "or scenic features." 
+ ), + }, +] + +TRAIL_TYPES = [ + "mountain trail", + "coastal path", + "forest walk", +] + +STYLES = [ + "a brief social media post", + "a hiking review", +] + +LANGUAGES = { + "en": "English", + "fr": "French", +} + +MODELS = [openrouter(model_id, temperature=0.8) for model_id in MODEL_IDS] +EXPECTED_ROWS = ( + len(LABELS) + * len(TRAIL_TYPES) + * len(STYLES) + * len(LANGUAGES) + * len(MODELS) +) + + +def keep_output_fields(record: dict) -> dict: + """Keep only the fields meant for publication.""" + return { + "label": record["label"], + "trail_type": record["trail_type"], + "style": record["style"], + "language": record.get("_language", ""), + "model": record.get("_model", ""), + "text": record["text"], + } + + +pipeline = ( + Seed.product( + SeedDimension( + columns=["label", "label_description"], + values=LABELS, + ), + Seed.values("trail_type", TRAIL_TYPES), + Seed.values("style", STYLES), + ).as_step("seed_trail_report_grid") + >> LLMStep( + prompt=PROMPT_PATH, + input_columns=["label", "label_description", "trail_type", "style"], + output_column="text", + parse_mode="text", + model=MODELS, + language=LANGUAGES, + ).as_step("generate_trail_reports") + >> Map(keep_output_fields).as_step("keep_output_fields") + >> AddUUID(column="id", overwrite=True).as_step("add_uuid") + >> Sink.jsonl(OUTPUT_PATH) +) + + +def main() -> None: + print(f"Expected rows: {EXPECTED_ROWS}") + records = pipeline.run( + batch_size=4, + checkpoint_dir=CHECKPOINT_DIR, + resume=True, + ) + + if os.getenv("DATAFAST_PUSH_TO_HUB") == "1": + list( + Sink.hub( + repo_id=HF_REPO_ID, + private=False, + train_size=0.8, + seed=SEED, + shuffle=True, + commit_message=( + "Publish cookbook 45 classification dataset with " + f"{', '.join(MODEL_IDS)}" + ), + ).process(records) + ) + + print(f"Wrote {len(records)} records to {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/llm_provider_requirements.md b/llm_provider_requirements.md new file mode 100644 index 0000000..df82303 --- 
/dev/null +++ b/llm_provider_requirements.md @@ -0,0 +1,259 @@ +# LLM Provider Requirements (Draft) + +## Goal + +Design a clean model-provider layer for `datafast/llms.py` with one stable Datafast API, while resolving actual support per target model or deployment. + +The key design rule is: + +- The public API should provide a uniform core model. +- The public API should also provide ergonomic provider-specific entry points. +- Capabilities should be resolved per target: provider + endpoint + model + optional self-hosted server behavior. + +## Core Design Principles + +- Keep a small common config surface for normal usage. +- Do not assume all models under one provider support the same parameters. +- Do not silently pass unsupported parameters unless that behavior is explicitly enabled. +- Preserve provider or server defaults when the user does not override them. +- Separate Datafast-level config from provider-specific request mapping. + +## Common Datafast Config + +Every target should support these common fields when applicable: + +- `model_id` +- `temperature` +- `rpm_limit` +- `timeout` + +Optional fields, only sent when supported: + +- `max_completion_tokens` +- `thinking` +- `reasoning_effort` +- `api_key` +- `api_base_url` +- retry limit +- `unsupported_params` + +`unsupported_params` should control how Datafast handles user-specified parameters that are known to be unsupported by the resolved target. + +- `fail`: raise a clear error before sending the request +- `warn`: omit the unsupported parameter and emit a warning +- `quiet`: omit the unsupported parameter silently + +Default: + +- `unsupported_params="warn"` + +## Public API Ergonomics + +The public API should expose provider-specific entry points such as: + +- `openai(...)` +- `anthropic(...)` +- `openrouter(...)` +- `mistral(...)` +- `ollama(...)` + +Requirements: + +- Provider-specific entry points should be the primary ergonomic API for users. 
+- They should make provider choice explicit and easy to read in pipelines. +- They should expose sensible provider-specific defaults and validation. +- They should share the same common config surface where possible. +- They may expose provider-specific options when needed, without forcing those options into every provider API. +- They should remain thin wrappers over a shared internal target/config system. +- Core execution behavior such as retries, batching, capability resolution, caching, and parsing should not live separately in each provider wrapper. + +## Capability Resolution + +Requirements should be defined around resolved target capabilities, not provider classes alone. + +That means: + +- OpenAI-compatible transport does not imply OpenAI-equivalent features. +- OpenRouter support is model-specific, not just provider-specific. +- Local servers such as Ollama, vLLM, and `llama.cpp` may expose different controls even when they look OpenAI-compatible. +- Local servers may emulate an endpoint shape without matching the full upstream semantics. +- When support is unknown, the safe default is to omit optional params rather than optimistically send them. + +The design should allow: + +- capability mapping per model or deployment +- endpoint-mode resolution per target, especially chat completions vs Responses API +- provider-specific parameter aliases +- explicit escape hatches for provider-specific params +- controlled dropping of unsupported params when intentionally enabled + +Unsupported-parameter handling should be explicit and user-configurable through `unsupported_params`. + +- The policy should apply to Datafast-known unsupported parameters for the resolved target. +- The default behavior should be `warn`. +- `quiet` should be allowed for users who intentionally want best-effort portability. +- `fail` should be available for users who want strict validation. + +Some targets may work best through `completion()` and others through `responses()`. 
The public Datafast API should not force users to care about that distinction, but the internal adapter layer should. + +Requirements should also allow target-level compatibility notes such as: + +- chat endpoint requires a compatible chat template +- a parameter is accepted but ignored +- an endpoint is available but implemented as an internal translation layer + +## Request / Response Model + +Datafast should expose one request model that supports: + +- single request +- concurrent batch requests +- prompt input +- message input +- structured output via Pydantic + +The execution layer should support both: + +- native same-target batching for many inputs to one resolved model/deployment when available +- fallback concurrency when native batching is unavailable + +If native batching is unavailable and Datafast falls back to parallel single requests, the user should be warned that a fallback execution path is being used. + +The message model should support both: + +- simple text messages +- typed multimodal content parts + +Supported content parts should include a common shape for: + +- text +- image +- audio +- video +- file +- document + +This keeps the public API compatible with multimodal-capable chat models without forcing separate provider APIs for each modality. + +Content parts should also be able to carry optional stable media IDs / UUIDs for targets that can reuse multimodal processing across requests. + +## Multimodal Requirements + +- Multimodal input support must be capability-aware per target. +- A model that supports text-only should still work with the same public call shape. +- A model that supports image, audio, video, document, or file inputs should accept typed content parts in `messages`. +- The design should also allow non-text outputs when supported, especially image-generation-capable chat models. +- Structured output and multimodal input should coexist when the target supports both. 
+- The design should support targets that expose multimodal and reasoning features primarily through the Responses API. +- The design should not assume all local backends support the same modalities. For example, support for image, audio, video, and document inputs may differ substantially between vLLM and `llama.cpp`. +- The design should allow target-specific media options when needed, without polluting the common API surface. + +## Reliability and Execution + +Every LLM call should have a standard execution policy: + +- bounded retries +- exponential backoff +- jitter +- retryable vs non-retryable error handling +- consistent timeout handling +- client-side RPM throttling + +Batch execution should: + +- preserve input order +- apply the same retry and timeout rules as single requests +- use native same-target batching when available +- fall back to controlled concurrency when native batching is unavailable +- warn the user when fallback concurrency is used instead of native batching + +## Endpoint Mode Requirements + +The design should explicitly allow multiple endpoint modes behind one public API. + +- Some targets should be called through chat completions. +- Some targets should be called through the Responses API. +- Endpoint choice should be resolved per target capability, not hardcoded per provider class. +- Responses API support matters for targets that expose reasoning, multimodal I/O, image generation, or session continuity through that endpoint. +- When the Responses API is used, the design should allow carrying forward response-session state such as `previous_response_id` when needed. +- The requirements should not assume that every Responses API implementation is native. A local backend may expose `/v1/responses` by translating it into another internal request shape. + +## Caching Requirements + +Caching should be part of the design, but not assumed to behave the same across targets. 
+ +The requirements should distinguish: + +- provider-native prompt caching +- gateway or routing-layer caching +- local server prefix / KV caching +- optional client-side result caching + +Key requirements: + +- caching must be explicit and correctness-preserving +- cache behavior must be capability-aware per target +- cache keys or cache hints must account for model, endpoint, relevant generation params, and multimodal inputs +- provider-specific caching controls should be supported through the mapping layer or escape hatch +- the public API should not promise identical cache semantics across OpenAI, Anthropic, Mistral, OpenRouter, Ollama, vLLM, and `llama.cpp` + +The requirements should also distinguish between: + +- provider-side prompt caching semantics +- prefix / KV-cache reuse for repeated prompt prefixes +- multimodal preprocessing cache reuse keyed by stable media identity + +In particular, local backends may expose caching mainly as performance-oriented KV reuse rather than provider-managed prompt caching. That should be modeled explicitly. 
+ +## What To Keep From The Current Design + +The current `llms.py` points to a few good design directions that should remain in the requirements: + +- one stable API for single and batch calls +- first-class structured output +- proactive client-side rate limiting +- standard retry behavior +- graceful fallback when a target lacks native batching +- support for local backends without requiring an API key +- tracing / metadata hooks on every request + +## Recommended Direction + +The optimal design is: + +- provider-specific public factories as thin entry points +- one common Datafast request/config model +- one target capability layer +- one shared execution layer for retries, throttling, batching, caching, and parsing +- thin internal provider adapters that only map Datafast requests into target-specific LiteLLM calls + +The capability layer should be able to describe at least: + +- supported endpoint modes +- supported modalities +- structured-output mechanism +- cache mechanism type +- chat-template or prompt-format requirements +- parameter caveats such as unsupported, ignored, translated, or model-dependent + +This keeps the user-facing API simple while allowing model-specific behavior where it actually belongs. 
+ +## References + +- LiteLLM provider-specific params: +- LiteLLM drop unsupported params: +- LiteLLM retries / fallbacks: +- LiteLLM batching: +- LiteLLM Responses API: +- LiteLLM structured output / JSON mode: +- LiteLLM reasoning content: +- LiteLLM vision: +- LiteLLM audio: +- LiteLLM document understanding: +- LiteLLM image generation in chat: +- vLLM online serving: +- vLLM structured outputs: +- vLLM automatic prefix caching: +- vLLM multimodal inputs: +- llama.cpp server: +- llama.cpp multimodal: diff --git a/llm_provider_test_plan.md b/llm_provider_test_plan.md new file mode 100644 index 0000000..71bad93 --- /dev/null +++ b/llm_provider_test_plan.md @@ -0,0 +1,285 @@ +# LLM Provider Test Plan (Draft) + +## Goal + +Test the provider redesign without exploding the matrix. + +Main idea: + +- Test shared behavior once at the common layer. +- Test only provider/model deltas at the capability layer. +- Run a meaningful live suite against selected real models. +- Keep the live suite maintainable through a small curated model catalog. +- Defer multimodal live coverage until after the first stable text-first provider test suite is in place. +- Defer caching coverage until after the first stable text-first provider test suite is in place. 
+ +## Test Layers + +| Layer | Purpose | Typical tools | +|---|---|---| +| Unit / contract | Validate request normalization, capability resolution, retry logic, batching decisions, parsing, caching decisions | mocked LiteLLM / fake adapters | +| Adapter tests | Verify mapping from Datafast request to LiteLLM request per endpoint mode | mocked `completion()`, `batch_completion()`, `responses()` | +| Live acceptance | Verify selected real models are safe for Datafast users | live API / local server | + +## Marker Strategy + +Recommended markers: + +- `live`: any test hitting a real provider endpoint +- `multimodal`: reserved for later image / audio / document / video coverage +- `ollama`: real Ollama backend +- `vllm`: real vLLM backend +- `llamacpp`: real `llama.cpp` backend + +Suggested usage: + +- default CI: mocked tests only +- provider CI / pre-release: `-m live` +- targeted local runs: `-m "live and ollama"` / `-m "live and vllm"` / `-m "live and llamacpp"` + +## Matrix Reduction Strategy + +- Do not test every feature against every provider. +- Run a compact acceptance suite against a curated list of selected models. +- Choose one representative provider/model endpoint per endpoint mode for mocked tests. +- Choose one representative provider/model endpoint per modality for deeper live tests. +- For each provider, test only what is different from the shared contract. +- Keep local-backend tests separate from hosted-provider smoke tests. + +## Selected Model Catalog + +Maintain one curated list of current supported / recommended test targets per provider. + +This catalog should not aim to include every available model. It should be a curated test surface for capability coverage and user confidence, not a registry of all provider inventory. 
+ +Each catalog entry should record at least: + +- provider +- model_id +- endpoint mode +- hosted vs local +- expected modalities +- expected structured-output support +- expected reasoning / thinking support +- expected batching behavior +- expected cache mechanism type +- test markers to apply, such as `live`, `multimodal`, `ollama`, `vllm`, `llamacpp` + +Design goal: + +- adding a new model should usually mean adding one catalog entry +- most live tests should parametrize over that catalog +- provider/model-specific regressions should be captured as capability expectations in the catalog + +Models are good candidates for the catalog when they are: + +- recommended to Datafast users +- used in docs or examples +- representative of a distinct capability shape +- newly added and worth validating before being treated as supported +- known to be tricky or historically unstable + +Models are usually not good candidates when they: + +- do not add meaningful new capability coverage +- are deprecated or not intended for ongoing support +- are only one of many near-identical variants from the same provider + +### Current Catalog Decisions + +Current agreed shortlist as of June 2026: + +- OpenAI: `gpt-5.5`, `gpt-5.4`, `gpt-5.4-mini`, `gpt-5.4-nano` +- Anthropic: `claude-sonnet-4-6`, `claude-haiku-4-5` +- Gemini: `gemini-2.5-pro`, `gemini-3.5-flash`, `gemini-3.1-flash-lite` +- Mistral hosted: `mistral-medium-3-5`, `mistral-large-2512`, `mistral-small-2603` +- Mistral local / self-hosted: `ministral-14b-2512`, `ministral-8b-2512`, `ministral-3b-2512` + +Current exclusions / constraints: + +- Exclude Anthropic `claude-fable-5` and `claude-opus-4-8` due to cost. +- Exclude Gemini `gemini-2.5-flash`. +- Keep the catalog curated for capability coverage, not exhaustive by provider inventory. +- Keep hosted Mistral and local Mistral entries separate in the catalog. 
+- Treat local-server capability expectations as backend-specific, especially for `vLLM`, `llama.cpp`, and other OpenAI-compatible servers. +- If a compact local Mistral subset is needed later, start with `ministral-8b-2512` and `ministral-3b-2512`. + +## Live Acceptance Suite + +These should run against the curated selected-model catalog. + +| ID | Test | +|---|---| +| L01 | Basic text generation works for every selected live model | +| L02 | Structured output works for every selected live model that claims support | +| L03 | Batch request works for every selected live model using the expected execution path, and emits a warning if fallback batching is used | +| L04 | Common params such as `timeout` and `temperature` are accepted or handled according to capability expectations | +| L05 | Declared unsupported params follow `unsupported_params` policy as expected for that model | +| L06 | Endpoint mode matches expectation: chat completions vs Responses API | +| L07 | Provider-specific factory entry point works for that model | +| L08 | Metadata / tracing path does not break live requests | + +For local backends, include: + +| ID | Test | +|---|---| +| L09 | `api_base_url` path works | +| L10 | no-API-key path works where expected | + +## Core Contract Tests + +These should run with mocks only. 
+ +| ID | Test | +|---|---| +| C01 | Factory functions such as `openai(...)`, `openrouter(...)`, `ollama(...)` create the expected internal target/config shape | +| C02 | Single prompt returns a single result | +| C03 | Batch prompts return ordered list results | +| C04 | `messages` input works for single request | +| C05 | Batched `messages` input works and preserves order | +| C06 | Reject `prompt=None` and `messages=None` | +| C07 | Reject providing both `prompt` and `messages` | +| C08 | Structured output with Pydantic parses successfully | +| C09 | Structured output surfaces a clear validation error on invalid JSON / schema mismatch | +| C10 | Text responses are normalized consistently | +| C11 | Metadata / tracing payload is attached to requests | + +## Capability Layer Tests + +These should validate the resolved target rules. + +| ID | Test | +|---|---| +| K01 | Supported params are forwarded for a target that allows them | +| K02 | Unsupported params are omitted by default when capability is unknown | +| K03 | `unsupported_params="warn"` omits unsupported params and emits a warning | +| K04 | `unsupported_params="fail"` raises a clear error before request dispatch | +| K05 | `unsupported_params="quiet"` omits unsupported params without warning | +| K06 | Provider-specific aliases map correctly to the internal common config | +| K07 | `thinking=False` suppresses `reasoning_effort` | +| K08 | `thinking=True` with no explicit `reasoning_effort` uses target default | +| K09 | Endpoint mode resolves correctly: chat completions vs Responses API | +| K10 | Capability notes such as "accepted but ignored" or "translated internally" are represented correctly | +| K11 | OpenAI-compatible target is not assumed to support all OpenAI features | +| K12 | Local target requiring a chat template is flagged correctly | + +## Adapter Tests + +These verify the LiteLLM call shape. 
+ +| ID | Test | +|---|---| +| A01 | Chat-completions target calls `litellm.completion()` for single input | +| A02 | Native same-target batch calls `litellm.batch_completion()` when supported | +| A03 | If native batching is unavailable, batch input is executed via bounded parallel single requests, preserves ordered batch outputs, and emits a user warning | +| A04 | Responses target calls `litellm.responses()` | +| A05 | Responses target forwards `previous_response_id` when present | +| A06 | Structured output maps to the correct LiteLLM field per endpoint mode | +| A07 | Provider-specific extra params pass only through the escape hatch | +| A08 | `api_base_url` and optional `api_key` are passed correctly for local / self-hosted targets | + +## Reliability Tests + +| ID | Test | +|---|---| +| R01 | Retryable error triggers bounded retries | +| R02 | Non-retryable error fails immediately | +| R03 | Backoff grows across retries | +| R04 | Jitter is applied within the expected range | +| R05 | Timeout is forwarded and timeout failure is surfaced clearly | +| R06 | Client-side `rpm_limit` throttles before provider error | +| R07 | Batch retry behavior preserves output ordering | + +## Multimodal Tests + +Multimodal coverage should come later. 
+ +For the first rollout: + +- keep multimodal tests out of the required live acceptance suite +- allow a small number of mocked multimodal contract tests if useful +- add real multimodal coverage only after the text-first live suite is stable + +| ID | Test | +|---|---| +| M01 | Text-only message content remains supported | +| M02 | Image content part is accepted for a target with image input support | +| M03 | Audio content part is accepted for a target with audio input support | +| M04 | Video content part is accepted for a target with video input support | +| M05 | File / document content part is accepted for a target with document support | +| M06 | Unsupported modality is rejected clearly for a text-only target | +| M07 | Mixed text + image multimodal message preserves part order | +| M08 | Stable media ID / UUID is forwarded when provided | +| M09 | Non-text output path is selected correctly for image-generation-capable chat target | + +## Caching Tests + +Caching coverage should come later. + +For the first rollout: + +- keep caching tests out of the required live acceptance suite +- allow mocked cache-resolution tests if useful +- add real cache-behavior coverage only after the text-first live suite is stable + +| ID | Test | +|---|---| +| H01 | Cache mode resolves correctly for provider-native prompt caching | +| H02 | Cache mode resolves correctly for local prefix / KV caching | +| H03 | Cache key / cache hint changes when model changes | +| H04 | Cache key / cache hint changes when relevant generation params change | +| H05 | Cache key / cache hint changes when multimodal input identity changes | +| H06 | Stable media identity enables multimodal reuse hint when supported | +| H07 | Public API does not claim cache hit semantics that the target cannot guarantee | + +## Provider / Model Delta Live Tests + +Add only when a selected model has behavior that differs meaningfully from the common suite. 
+ +| ID | Example | +|---|---| +| D01 | Responses-only reasoning model | +| D02 | OpenRouter model with provider-specific capability caveat | +| D03 | vLLM deployment with structured-output expectations | +| D04 | `llama.cpp` target with chat-template requirement | +| D05 | model with unusual unsupported-param behavior expectations | +| D06 | multimodal model with image input support | +| D07 | cache-relevant local backend behavior | + +## Extended Live Scenarios + +These are later-phase tests, not required for the initial rollout. + +| ID | Target | Test | +|---|---|---| +| E01 | Multimodal hosted model | text + image input | +| E02 | Audio or document-capable model | real multimodal request | +| E03 | Structured-output target | real Pydantic schema validation | +| E04 | Provider with prompt caching | repeated request with cache-relevant setup | +| E05 | vLLM | prefix-cache-friendly repeated prompt | +| E06 | local multimodal target | document or image input if supported | +| E07 | Responses target | `previous_response_id` continuation | +| E08 | selected-model sweep | run the full acceptance suite across the full catalog | + +## New Model Onboarding + +When a new model comes out: + +1. Add it to the selected-model catalog with expected capabilities. +2. Run the shared live acceptance suite against it. +3. Add a provider/model delta test only if it differs from the standard expectations. +4. Add an extended live scenario only if it adds meaningful new capability coverage. + +## Suggested Priorities + +- Phase 1: `C*`, `K*`, `A*`, `R*` +- Phase 2: selected-model `L*` live suite +- Phase 3: `M*`, `H*`, `D*` +- Phase 4: `E*` + +## Success Criteria + +- Shared behavior is covered mostly by fast mocked tests. +- The curated live suite gives confidence against real provider endpoints. +- Provider/model-specific logic is tested as deltas, not full re-runs of the whole matrix. +- Adding a new model is mostly a catalog update plus, if needed, one delta test. 
diff --git a/mkdocs.yml b/mkdocs.yml index 87e795a..131400c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -43,6 +43,11 @@ nav: - LLM Steps: guides/llm_steps.md - Checkpointing: guides/checkpointing.md - Langfuse Tracing: guides/langfuse_tracing.md + - Cookbook: + - cookbook/index.md + - Text Classification: cookbook/text_classification.md + - Persona Generation: cookbook/persona_generation.md + - Space Engineering Text Generation: cookbook/space_text_generation.md - Providers: llms.md - Models: models.md - API: api.md diff --git a/pytest.ini b/pytest.ini index 798f789..042626f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,11 @@ [pytest] markers = integration: marks tests that require API connectivity (deselect with '-m "not integration"') + live: marks tests that hit a real provider endpoint + multimodal: marks tests that exercise multimodal provider behavior + ollama: marks tests that require a real Ollama backend + vllm: marks tests that require a real vLLM backend + llamacpp: marks tests that require a real llama.cpp backend slow: marks tests that are slow to run # Other pytest configurations diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..961d50b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,20 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--run-live", + action="store_true", + default=False, + help="run tests marked live or integration", + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--run-live"): + return + + skip_live = pytest.mark.skip(reason="requires --run-live") + for item in items: + if "live" in item.keywords or "integration" in item.keywords: + item.add_marker(skip_live) diff --git a/tests/test_add_uuid.py b/tests/test_add_uuid.py new file mode 100644 index 0000000..e89f837 --- /dev/null +++ b/tests/test_add_uuid.py @@ -0,0 +1,78 @@ +import uuid + +from datafast import AddUUID, LLMStep, Sink, Source + + +def assert_valid_uuid(value: str) -> 
None: + parsed = uuid.UUID(value) + assert str(parsed) == value + + +def test_add_uuid_adds_id_when_missing(): + records = list(AddUUID().process([{"text": "hello"}])) + + assert records[0]["text"] == "hello" + assert_valid_uuid(records[0]["id"]) + + +def test_add_uuid_preserves_existing_id_by_default(): + records = list(AddUUID().process([{"id": "source-1", "text": "hello"}])) + + assert records == [{"id": "source-1", "text": "hello"}] + + +def test_add_uuid_overwrites_existing_id_when_requested(): + records = list( + AddUUID(overwrite=True).process([{"id": "source-1", "text": "hello"}]) + ) + + assert records[0]["text"] == "hello" + assert records[0]["id"] != "source-1" + assert_valid_uuid(records[0]["id"]) + + +def test_add_uuid_generates_distinct_ids_for_multiple_records(): + records = list(AddUUID().process([{"text": "a"}, {"text": "b"}])) + ids = [record["id"] for record in records] + + assert len(set(ids)) == 2 + for value in ids: + assert_valid_uuid(value) + + +def test_add_uuid_supports_custom_column_name(): + records = list(AddUUID(column="example_id").process([{"text": "hello"}])) + + assert "id" not in records[0] + assert_valid_uuid(records[0]["example_id"]) + + +def test_add_uuid_assigns_unique_ids_to_llm_num_outputs_pipeline(): + class FakeModel: + model_id = "fake-model" + provider_name = "fake" + + def generate(self, messages, metadata=None): + return '{"title": "Generated", "text": "Body"}' + + pipeline = ( + Source.list([{"topic": "vacuum"}]) + >> LLMStep( + prompt="Write about {topic}.", + input_columns=["topic"], + output_columns=["title", "text"], + parse_mode="json", + model=FakeModel(), + num_outputs=2, + ) + >> AddUUID() + >> Sink.list() + ) + + records = pipeline.run() + ids = [record["id"] for record in records] + + assert len(records) == 2 + assert len(set(ids)) == 2 + for value in ids: + assert_valid_uuid(value) diff --git a/tests/test_llm_provider_contract.py b/tests/test_llm_provider_contract.py new file mode 100644 index 
0000000..8928e5e --- /dev/null +++ b/tests/test_llm_provider_contract.py @@ -0,0 +1,503 @@ +import pytest +from pydantic import BaseModel + +import datafast.llm.provider as provider_module +from datafast import LLMStep, ListSink, Source +from datafast.llm import ( + ContentPart, + EndpointMode, + Modality, + OpenAIProvider, + OpenRouterProvider, + openai, + openai_compatible, +) + + +class SimpleSchema(BaseModel): + answer: str + + +class _DummyMessage: + def __init__( + self, + content, + reasoning_content=None, + thinking_blocks=None, + images=None, + audio=None, + ): + self.content = content + self.reasoning_content = reasoning_content + self.thinking_blocks = thinking_blocks + self.images = images + self.audio = audio + + +class _DummyChoice: + def __init__( + self, + content, + reasoning_content=None, + thinking_blocks=None, + images=None, + audio=None, + ): + self.message = _DummyMessage( + content, + reasoning_content, + thinking_blocks, + images, + audio, + ) + + +class _DummyChatResponse: + def __init__( + self, + content, + reasoning_content=None, + thinking_blocks=None, + images=None, + audio=None, + ): + self.choices = [ + _DummyChoice( + content, + reasoning_content, + thinking_blocks, + images, + audio, + ) + ] + + +class _DummyResponsesResponse: + def __init__(self, output_text=None, output=None, reasoning_content=None): + self.output_text = output_text + self.output = output + self.reasoning_content = reasoning_content + + +@pytest.fixture(autouse=True) +def _disable_provider_side_effects(monkeypatch): + monkeypatch.setattr(provider_module, "load_env_once", lambda: None) + monkeypatch.setattr( + provider_module, + "maybe_configure_langfuse_tracing", + lambda load_env=False: False, + ) + + +def test_factories_resolve_expected_targets(): + hosted = openai(api_key="test-key") + local = openai_compatible( + "ministral-8b-2512", + api_base_url="http://localhost:8000/v1", + ) + + assert hosted.provider_name == "openai" + assert hosted.endpoint_mode == 
EndpointMode.RESPONSES + assert hosted._get_model_string() == "openai/gpt-5.5" + + assert local.provider_name == "openai_compatible" + assert local.endpoint_mode == EndpointMode.CHAT + assert local.api_base_url == "http://localhost:8000/v1" + + +def test_openai_compatible_backend_profiles_are_distinct(): + generic = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + ) + vllm = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + backend="vllm", + ) + llamacpp = openai_compatible( + "local-model", + api_base_url="http://localhost:8080/v1", + backend="llamacpp", + ) + + assert generic.provider_name == "openai_compatible" + assert generic.capabilities.modalities == frozenset({Modality.TEXT}) + + assert vllm.provider_name == "vllm" + assert vllm.capabilities.supports_endpoint(EndpointMode.RESPONSES) + assert Modality.IMAGE in vllm.capabilities.modalities + assert Modality.VIDEO in vllm.capabilities.modalities + + assert llamacpp.provider_name == "llamacpp" + assert Modality.AUDIO in llamacpp.capabilities.modalities + assert Modality.FILE in llamacpp.capabilities.modalities + + +def test_input_validation_rejects_missing_or_ambiguous_inputs(): + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + with pytest.raises(ValueError, match="Either prompt or messages"): + provider.generate() + + with pytest.raises(ValueError, match="either prompt or messages"): + provider.generate(prompt="hello", messages=[{"role": "user", "content": "hi"}]) + + +def test_unsupported_params_warn_and_omit(monkeypatch): + captured = {} + + def fake_completion(**kwargs): + captured.update(kwargs) + return _DummyChatResponse("ok") + + monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + temperature=0.7, + ) + + with pytest.warns(UserWarning, match="temperature"): + assert 
provider.generate(prompt="ping") == "ok" + + assert "temperature" not in captured + assert captured["api_base"] == "http://localhost:8000/v1" + + +def test_unsupported_params_fail_before_dispatch(monkeypatch): + def fake_completion(**kwargs): + raise AssertionError("request should not be dispatched") + + monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + temperature=0.7, + unsupported_params="fail", + ) + + with pytest.raises(ValueError, match="temperature"): + provider.generate(prompt="ping") + + +def test_chat_endpoint_warns_and_omits_previous_response_id(monkeypatch): + captured = {} + + def fake_completion(**kwargs): + captured.update(kwargs) + return _DummyChatResponse("ok") + + monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + with pytest.warns(UserWarning, match="previous_response_id"): + assert provider.generate(prompt="ping", previous_response_id="resp_old") == "ok" + + assert "previous_response_id" not in captured + + +def test_openrouter_thinking_warns_and_omits_reasoning_param(monkeypatch): + captured = {} + + def fake_completion(**kwargs): + captured.update(kwargs) + return _DummyChatResponse("ok") + + monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = OpenRouterProvider( + model_id="nvidia/nemotron-3-super-120b-a12b:nitro", + api_key="test-key", + thinking=True, + ) + + with pytest.warns(UserWarning, match="reasoning_effort"): + assert provider.generate(prompt="ping") == "ok" + + assert "reasoning_effort" not in captured + assert "reasoning" not in captured + + +def test_provider_params_escape_hatch_is_forwarded(monkeypatch): + captured = {} + + def fake_completion(**kwargs): + captured.update(kwargs) + return _DummyChatResponse("ok") + + monkeypatch.setattr(provider_module.litellm, 
"completion", fake_completion) + + provider = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + provider_params={"extra_body": {"backend_hint": "vllm"}}, + ) + + assert provider.generate(prompt="ping") == "ok" + assert captured["extra_body"] == {"backend_hint": "vllm"} + + +def test_content_parts_normalize_multimodal_and_document_shapes(): + vllm = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + backend="vllm", + ) + prepared = vllm._prepare_messages( + [ + { + "role": "user", + "content": [ + ContentPart(type="text", text="What is in this image?"), + ContentPart( + type="image", + url="https://example.com/image.png", + media_id="img-123", + ), + ContentPart( + type="video", + url="https://example.com/video.mp4", + media_id="vid-123", + ), + ], + } + ], + response_format=None, + ) + + assert prepared[0]["content"] == [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.png"}, + "uuid": "img-123", + }, + { + "type": "video_url", + "video_url": {"url": "https://example.com/video.mp4"}, + "uuid": "vid-123", + }, + ] + + llamacpp = openai_compatible( + "local-model", + api_base_url="http://localhost:8080/v1", + backend="llamacpp", + ) + prepared = llamacpp._prepare_messages( + [ + { + "role": "user", + "content": [ + ContentPart( + type="document", + data="data:application/pdf;base64,abc", + media_type="application/pdf", + ), + ], + } + ], + response_format=None, + ) + + assert prepared[0]["content"] == [ + { + "type": "file", + "file": {"file_data": "data:application/pdf;base64,abc"}, + } + ] + + +def test_litellm_unsupported_params_can_retry_with_drop_params(monkeypatch): + unsupported_error = type("UnsupportedParamsError", (Exception,), {}) + calls = [] + + def fake_completion(**kwargs): + calls.append(kwargs) + if len(calls) == 1: + raise unsupported_error("bad param") + return _DummyChatResponse("ok") + + 
monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + with pytest.warns(UserWarning, match="drop_params=True"): + assert provider.generate(prompt="ping") == "ok" + + assert calls[0].get("drop_params") is None + assert calls[1]["drop_params"] is True + + +def test_generate_response_preserves_litellm_reasoning_metadata(monkeypatch): + monkeypatch.setattr( + provider_module.litellm, + "completion", + lambda **kwargs: _DummyChatResponse( + "final answer", + reasoning_content="internal summary", + thinking_blocks=[ + { + "type": "thinking", + "thinking": "visible thinking block", + "signature": "sig", + } + ], + images=[{"type": "image", "url": "https://example.com/out.png"}], + audio={"id": "audio-1", "expires_at": 123}, + ), + ) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + response = provider.generate_response(prompt="ping") + + assert response.text == "final answer" + assert response.reasoning_content == "internal summary" + assert response.thinking_blocks == [ + { + "type": "thinking", + "thinking": "visible thinking block", + "signature": "sig", + } + ] + assert response.images == [ + {"type": "image", "url": "https://example.com/out.png"} + ] + assert response.audio == {"id": "audio-1", "expires_at": 123} + + +def test_responses_full_response_preserves_output_items_and_media(monkeypatch): + output = [ + {"type": "reasoning", "summary": [{"text": "short rationale"}]}, + {"type": "image_generation_call", "result": "base64-image"}, + { + "type": "message", + "content": [{"type": "output_text", "text": "Here is the image."}], + }, + ] + + monkeypatch.setattr( + provider_module.litellm, + "responses", + lambda **kwargs: _DummyResponsesResponse(output=output), + ) + + provider = OpenAIProvider(model_id="gpt-5.5", api_key="test-key") + response = provider.generate_response(prompt="make an image") + + assert response.text == "Here is the 
image." + assert response.reasoning_content == "short rationale" + assert response.images == [ + {"type": "image_generation_call", "result": "base64-image"} + ] + assert response.output_items == output + + +def test_responses_endpoint_maps_reasoning_state_and_structured_output(monkeypatch): + captured = {} + + def fake_responses(**kwargs): + captured.update(kwargs) + return _DummyResponsesResponse('{"answer": "Paris"}') + + monkeypatch.setattr(provider_module.litellm, "responses", fake_responses) + + provider = OpenAIProvider( + model_id="gpt-5.5", + api_key="test-key", + thinking=True, + max_completion_tokens=64, + ) + + result = provider.generate( + messages=[{"role": "user", "content": "capital?"}], + response_format=SimpleSchema, + previous_response_id="resp_previous", + metadata={"purpose": "test"}, + ) + + assert result == SimpleSchema(answer="Paris") + assert captured["model"] == "openai/gpt-5.5" + assert captured["previous_response_id"] == "resp_previous" + assert captured["reasoning"] == {"effort": "low"} + assert captured["max_output_tokens"] == 64 + assert captured["text_format"] is SimpleSchema + assert captured["metadata"]["purpose"] == "test" + + +def test_fallback_batching_preserves_order(monkeypatch): + calls = [] + + def fake_completion(**kwargs): + calls.append(kwargs["messages"][0]["content"]) + return _DummyChatResponse(f"reply:{kwargs['messages'][0]['content']}") + + monkeypatch.setattr(provider_module.litellm, "completion", fake_completion) + + provider = openai_compatible( + "local-model", + api_base_url="http://localhost:8000/v1", + max_concurrent=1, + ) + + with pytest.warns(UserWarning, match="Falling back"): + result = provider.generate(prompt=["one", "two", "three"]) + + assert result == ["reply:one", "reply:two", "reply:three"] + assert calls == ["one", "two", "three"] + + +def test_structured_output_validation_error_is_clear(monkeypatch): + monkeypatch.setattr( + provider_module.litellm, + "completion", + lambda **kwargs: 
_DummyChatResponse("not json"), + ) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + with pytest.raises(ValueError, match="Failed to parse JSON response"): + provider.generate(prompt="answer in json", response_format=SimpleSchema) + + +def test_runner_dispatches_same_model_batches_through_generate_batch(): + class FakeBatchModel: + provider_name = "fake" + model_id = "fake-model" + + def __init__(self): + self.batches = [] + + def generate_batch(self, messages, metadata=None, response_format=None): + self.batches.append({"messages": messages, "metadata": metadata}) + return ["first", "second"] + + model = FakeBatchModel() + sink = ListSink() + pipeline = ( + Source.list([{"topic": "alpha"}, {"topic": "beta"}]) + >> LLMStep( + prompt="Write about {topic}.", + input_columns=["topic"], + output_column="result", + model=model, + ) + >> sink + ) + + output = pipeline.run(batch_size=2) + + assert output == [ + {"topic": "alpha", "result": "first", "_model": "fake-model"}, + {"topic": "beta", "result": "second", "_model": "fake-model"}, + ] + assert len(model.batches) == 1 + assert [batch[0]["content"] for batch in model.batches[0]["messages"]] == [ + "Write about alpha.", + "Write about beta.", + ] + assert len(model.batches[0]["metadata"]) == 2 diff --git a/tests/test_llms_unit.py b/tests/test_llms_unit.py new file mode 100644 index 0000000..1c67a6d --- /dev/null +++ b/tests/test_llms_unit.py @@ -0,0 +1,107 @@ +import datafast.llms as llms_module +from datafast.llms import OpenRouterProvider + + +class _DummyMessage: + def __init__(self, content: str, **extra: object) -> None: + self.content = content + for key, value in extra.items(): + setattr(self, key, value) + + +class _DummyChoice: + def __init__(self, content: str, **extra: object) -> None: + self.message = _DummyMessage(content, **extra) + + +class _DummyResponse: + def __init__(self, content: str, **extra: object) -> None: + self.choices = [_DummyChoice(content, **extra)] + + 
+def test_openrouter_single_messages_use_completion(monkeypatch): + monkeypatch.setattr(llms_module, "load_env_once", lambda: None) + monkeypatch.setattr( + llms_module, + "maybe_configure_langfuse_tracing", + lambda load_env=False: False, + ) + + calls = {"completion": 0, "batch_completion": 0} + + def fake_completion(**kwargs): + calls["completion"] += 1 + assert kwargs["messages"] == [{"role": "user", "content": "ping"}] + return _DummyResponse("ok") + + def fake_batch_completion(**kwargs): + calls["batch_completion"] += 1 + raise AssertionError("single-message requests should not use batch_completion") + + monkeypatch.setattr(llms_module.litellm, "completion", fake_completion) + monkeypatch.setattr(llms_module.litellm, "batch_completion", fake_batch_completion) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + response = provider.generate(messages=[{"role": "user", "content": "ping"}]) + + assert response == "ok" + assert calls == {"completion": 1, "batch_completion": 0} + + +def test_openrouter_batch_messages_use_batch_completion(monkeypatch): + monkeypatch.setattr(llms_module, "load_env_once", lambda: None) + monkeypatch.setattr( + llms_module, + "maybe_configure_langfuse_tracing", + lambda load_env=False: False, + ) + + calls = {"completion": 0, "batch_completion": 0} + + def fake_completion(**kwargs): + calls["completion"] += 1 + raise AssertionError("batched requests should not use completion") + + def fake_batch_completion(**kwargs): + calls["batch_completion"] += 1 + assert len(kwargs["messages"]) == 2 + return [_DummyResponse("first"), _DummyResponse("second")] + + monkeypatch.setattr(llms_module.litellm, "completion", fake_completion) + monkeypatch.setattr(llms_module.litellm, "batch_completion", fake_batch_completion) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + response = provider.generate(messages=[ + [{"role": "user", "content": "one"}], + [{"role": "user", "content": "two"}], + ]) 
+ + assert response == ["first", "second"] + assert calls == {"completion": 0, "batch_completion": 1} + + +def test_openrouter_generate_response_reads_reasoning_field(monkeypatch): + monkeypatch.setattr(llms_module, "load_env_once", lambda: None) + monkeypatch.setattr( + llms_module, + "maybe_configure_langfuse_tracing", + lambda load_env=False: False, + ) + + monkeypatch.setattr( + llms_module.litellm, + "completion", + lambda **kwargs: _DummyResponse( + "final answer", + reasoning="hidden chain of thought summary", + ), + ) + + provider = OpenRouterProvider(model_id="demo-model", api_key="test-key") + + response = provider.generate_response(prompt="solve this") + + assert response.text == "final answer" + assert response.reasoning_content == "hidden chain of thought summary" diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 7eaf787..ac56477 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -1,4 +1,5 @@ from datafast import ( + AddUUID, Branch, Classify, Compare, @@ -70,6 +71,7 @@ def test_factory_exports_are_available(monkeypatch): assert Sink is not None assert Seed is not None assert Sample is not None + assert AddUUID is not None assert Map is not None assert FlatMap is not None assert Filter is not None diff --git a/tests/test_runner_llm_messages.py b/tests/test_runner_llm_messages.py new file mode 100644 index 0000000..d870093 --- /dev/null +++ b/tests/test_runner_llm_messages.py @@ -0,0 +1,47 @@ +from datafast import LLMStep, ListSink, Source + + +def test_runner_passes_llm_messages_by_keyword(): + class FakeModel: + provider_name = "fake" + model_id = "fake-model" + + def __init__(self) -> None: + self.calls: list[dict] = [] + + def generate( + self, + prompt=None, + messages=None, + metadata=None, + response_format=None, + ): + self.calls.append({ + "prompt": prompt, + "messages": messages, + "metadata": metadata, + }) + return "done" + + model = FakeModel() + sink = ListSink() + + pipeline = ( + 
Source.list([{"topic": "robotics"}]) + >> LLMStep( + prompt="Write one short line about {topic}.", + input_columns=["topic"], + output_column="result", + model=model, + ).as_step("generate_copy") + >> sink + ) + + output = pipeline.run() + + assert output == [{"topic": "robotics", "result": "done", "_model": "fake-model"}] + assert len(model.calls) == 1 + assert model.calls[0]["prompt"] is None + assert model.calls[0]["messages"] == [ + {"role": "user", "content": "Write one short line about robotics."} + ]