Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e26d097
Add unit tests for LLM message handling and batch completion routing
patrickfleith Apr 5, 2026
2aa9ebf
Remove openspec/ from gitignore
patrickfleith Apr 5, 2026
45c83be
Fix OpenRouter batch completion routing and add persona generation co…
patrickfleith Apr 5, 2026
69c1638
adding prompts
patrickfleith Apr 5, 2026
d42d8be
Adding openspec artifacts
patrickfleith Apr 5, 2026
8a19840
Update persona cookbook to use Mistral AI, expand to 20 samples, and …
patrickfleith Apr 5, 2026
2f07906
modified prompts
patrickfleith Apr 16, 2026
c3b28ae
Update persona cookbook to use OpenRouter with Nemotron model, expand…
patrickfleith Apr 16, 2026
6912ba2
Update persona cookbook documentation to reflect prompt variant rando…
patrickfleith Apr 16, 2026
f285b40
Add .agents/ to gitignore and remove archived persona cookbook opensp…
patrickfleith Apr 17, 2026
d62a2f3
Align persona cookbook docs with script
patrickfleith May 27, 2026
7762820
Improve persona cookbook resumability
patrickfleith Jun 10, 2026
0e8ebd4
Remove user-prompt assets from persona cookbook
patrickfleith Jun 10, 2026
973a23f
Add explicit UUID support for generated records
patrickfleith Jun 11, 2026
1e49ed3
Add text classification cookbook
patrickfleith Jun 11, 2026
f826137
Simplify text classification cookbook
patrickfleith Jun 11, 2026
e59363b
change default to public dataset
patrickfleith Jun 11, 2026
1743863
Refine trail comment prompt wording
patrickfleith Jun 11, 2026
684f3ea
Adding LLM Providers requirements
patrickfleith Jun 12, 2026
3b04600
Adding LLM Providers Test Plan
patrickfleith Jun 12, 2026
20c8896
updating LLM provider test plan
patrickfleith Jun 16, 2026
8c4c983
fix llms provider
patrickfleith Jun 16, 2026
58b6765
llm tests
patrickfleith Jun 16, 2026
b518730
utility function
patrickfleith Jun 16, 2026
4507d7c
script simple prompt test
patrickfleith Jun 16, 2026
6e17cf3
example with batch prompts
patrickfleith Jun 16, 2026
b5fcbea
messages with system prompt
patrickfleith Jun 16, 2026
9804cc9
structured output example
patrickfleith Jun 16, 2026
3fa2fb0
example with batch of messages
patrickfleith Jun 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ ipython_config.py
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
Expand Down Expand Up @@ -187,4 +187,4 @@ examples/checkpoints/
examples/outputs/

.codex/
openspec/
.agents/
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Datafast is a python library for synthetic data generation using llms.
The old dataset-class API has been removed. The canonical package is `datafast`, and the primary model is:

- create records with `Source` or `Seed`
- transform them with composable steps
- transform them with composable steps such as `AddUUID`, `Map`, and `Filter`
- call LLMs with `LLMStep`, `Classify`, `Score`, `Compare`, `Rewrite`, or `Extract`
- persist results with `Sink`

Expand Down Expand Up @@ -53,7 +53,7 @@ pipeline.run(batch_size=4)

- `Source`: load records from Python lists, files, or Hugging Face datasets
- `Seed`: generate record combinations declaratively
- `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations
- `AddUUID`, `Map`, `FlatMap`, `Filter`, `Group`, `Pair`, `Concat`, `Join`: data operations
- `LLMStep`: free-form generation
- `Classify`, `Score`, `Compare`, `Rewrite`, `Extract`: higher-level LLM transforms
- `Branch` and `JoinBranches`: multi-path pipelines
Expand Down Expand Up @@ -105,6 +105,7 @@ configure_langfuse_tracing()

- `datafast/`: canonical source package
- `examples/scripts/`: runnable pipeline examples
- `examples/providers/`: direct provider usage examples
- `docs/`: pipeline-first documentation
- `datafast_new_design_document.md`: retained design reference

Expand Down
7 changes: 6 additions & 1 deletion datafast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
MistralProvider,
OpenRouterProvider,
OllamaProvider,
OpenAICompatibleProvider,
openai,
anthropic,
gemini,
mistral,
openrouter,
ollama,
openai_compatible,
)
from datafast.logger_config import configure_logger
from datafast.sinks.sink import Sink, JSONLSink, CSVSink, ListSink, ParquetSink, HubSink
Expand All @@ -31,7 +33,7 @@
is_langfuse_tracing_enabled,
)
from datafast.transforms.branch import Branch, JoinBranches
from datafast.transforms.data_ops import Map, FlatMap, Filter, Group, Pair, Concat, Join
from datafast.transforms.data_ops import AddUUID, Map, FlatMap, Filter, Group, Pair, Concat, Join
from datafast.transforms.llm_eval import Classify, Score, Compare
from datafast.transforms.llm_extract import Extract
from datafast.transforms.llm_step import LLMStep
Expand Down Expand Up @@ -64,6 +66,7 @@ def get_version() -> str:
"Seed",
"SeedDimension",
"Sample",
"AddUUID",
"Map",
"FlatMap",
"Filter",
Expand Down Expand Up @@ -92,12 +95,14 @@ def get_version() -> str:
"MistralProvider",
"OpenRouterProvider",
"OllamaProvider",
"OpenAICompatibleProvider",
"openai",
"anthropic",
"gemini",
"mistral",
"openrouter",
"ollama",
"openai_compatible",
"configure_logger",
"configure_langfuse_tracing",
"get_version",
Expand Down
231 changes: 183 additions & 48 deletions datafast/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

from loguru import logger
Expand All @@ -19,7 +20,6 @@

if TYPE_CHECKING:
from datafast.core.step import Pipeline, Step
from datafast.llm.provider import LLMProvider
from datafast.transforms.llm_step import LLMStep


Expand All @@ -29,6 +29,12 @@ def chunked(iterable: list, size: int):
yield iterable[i : i + size]


@dataclass
class _LLMBatchStats:
generated: int = 0
errors: int = 0


class Runner:
"""
Execution engine for pipelines.
Expand Down Expand Up @@ -218,77 +224,206 @@ def _execute_llm_step(
if self._checkpoint_mgr and not skip_call_ids:
self._checkpoint_mgr.clear_step_file(step_index, step_name)

completed_in_batch = 0
completed_since_checkpoint = 0
errors = 0
generated_total = len(skip_call_ids)

for batch in chunked(calls, self.config.batch_size):
batch_start = time.perf_counter()
batch_generated = 0
batch_model_id = batch[0].model_id if batch else "unknown"

for call in batch:
model = models_map[call.model_id]
batch_model_id = call.model_id

try:
result = model.generate(
call.messages,
metadata=build_trace_metadata(
model=model,
component="pipeline.step",
trace_name=f"datafast.{step_name}",
session_id=self._trace_session_id,
step_name=step_name,
step_type=step.__class__.__name__,
record_index=call.record_index,
prompt_index=call.prompt_index,
output_index=call.output_index,
language_code=call.language_code or None,
call_id=call.call_id,
),
)
output_record = step.apply_result(call, result, model)
output_records.append(output_record)
progress.completed_call_ids.append(call.call_id)
completed_in_batch += 1
batch_generated += 1
generated_total += 1

# Append record immediately to JSONL
if self._checkpoint_mgr:
self._checkpoint_mgr.append_record(
step_index, step_name, output_record
)

except Exception as e:
errors += 1
logger.warning(
f"LLM call failed | Model: {call.model_id} | "
f"Call: {call.call_id} | Error: {e}"
)
batch_model_id = (
batch[0].model_id
if len({call.model_id for call in batch}) == 1
else "mixed"
)
stats = self._execute_llm_batch(
step=step,
step_name=step_name,
step_index=step_index,
batch=batch,
models_map=models_map,
progress=progress,
output_records=output_records,
)
completed_since_checkpoint += stats.generated
errors += stats.errors
generated_total += stats.generated

batch_duration = time.perf_counter() - batch_start
logger.info(
f"Generated {batch_generated} samples (total: {generated_total}) | "
f"Generated {stats.generated} samples (total: {generated_total}) | "
f"model: {batch_model_id} | duration: {batch_duration:.2f}s"
)

if (
self._checkpoint_mgr
and manifest
and completed_in_batch >= self.config.checkpoint_every
and completed_since_checkpoint >= self.config.checkpoint_every
):
self._checkpoint_mgr.save_llm_progress(
step_index, step_name, progress, output_records
)
completed_in_batch = 0
completed_since_checkpoint = 0

logger.info(
f"LLMStep complete: {len(output_records)} outputs, {errors} errors"
)
return output_records

def _execute_llm_batch(
self,
*,
step: "LLMStep",
step_name: str,
step_index: int,
batch: list[LLMCall],
models_map: dict[str, object],
progress: LLMStepProgress,
output_records: list[Record],
) -> _LLMBatchStats:
"""Execute and apply one runner batch, preserving input order."""
batch_results, errors = self._collect_llm_batch_results(
step=step,
step_name=step_name,
batch=batch,
models_map=models_map,
)
stats = self._apply_llm_batch_results(
step=step,
step_name=step_name,
step_index=step_index,
batch=batch,
batch_results=batch_results,
models_map=models_map,
progress=progress,
output_records=output_records,
)
stats.errors += errors
return stats

def _collect_llm_batch_results(
self,
*,
step: "LLMStep",
step_name: str,
batch: list[LLMCall],
models_map: dict[str, object],
) -> tuple[list[object | None], int]:
batch_results: list[object | None] = [None] * len(batch)
errors = 0
grouped_indexes: dict[str, list[int]] = defaultdict(list)

for index, call in enumerate(batch):
grouped_indexes[call.model_id].append(index)

for model_id, indexes in grouped_indexes.items():
group_calls = [batch[index] for index in indexes]
model = models_map[model_id]
try:
group_results = self._generate_llm_group(
step=step,
step_name=step_name,
model=model,
group_calls=group_calls,
)
except Exception as e:
errors += len(group_calls)
self._log_llm_failures(group_calls, e)
continue

for result_index, result in zip(indexes, group_results):
batch_results[result_index] = result

return batch_results, errors

def _generate_llm_group(
self,
*,
step: "LLMStep",
step_name: str,
model: object,
group_calls: list[LLMCall],
) -> list[object]:
group_metadata = [
self._build_llm_call_metadata(step, step_name, call, model)
for call in group_calls
]
if hasattr(model, "generate_batch"):
return list(
model.generate_batch( # type: ignore[attr-defined]
[call.messages for call in group_calls],
metadata=group_metadata,
)
)
return [
model.generate( # type: ignore[attr-defined]
messages=call.messages,
metadata=metadata,
)
for call, metadata in zip(group_calls, group_metadata)
]

def _apply_llm_batch_results(
    self,
    *,
    step: "LLMStep",
    step_name: str,
    step_index: int,
    batch: list[LLMCall],
    batch_results: list[object | None],
    models_map: dict[str, object],
    progress: LLMStepProgress,
    output_records: list[Record],
) -> _LLMBatchStats:
    """Turn collected results into output records, one call at a time.

    For every call with a result, the step builds the output record and,
    when a checkpoint manager is configured, the record is appended to the
    checkpoint. Only after both succeed is the record added to
    ``output_records`` and the call marked as completed, so a call is
    counted either as generated or as an error, never both.

    Args:
        step: The LLM step whose ``apply_result`` builds the output record.
        step_name: Step name passed to the checkpoint manager.
        step_index: Step position passed to the checkpoint manager.
        batch: Calls of the batch, aligned with ``batch_results``.
        batch_results: Result per call; ``None`` entries are skipped.
        models_map: Model objects keyed by ``model_id``.
        progress: Progress tracker; ``completed_call_ids`` is appended to.
        output_records: List extended in place with the new records.

    Returns:
        Counters for the calls applied successfully (``generated``) and the
        calls whose result could not be applied or persisted (``errors``).
    """
    stats = _LLMBatchStats()

    for call, result in zip(batch, batch_results):
        if result is None:
            # Nothing was produced for this call, so there is nothing to apply.
            continue
        try:
            output_record = step.apply_result(
                call, result, models_map[call.model_id]
            )
            # Persist before touching in-memory state: if the append fails the
            # call must not be marked completed, otherwise it would be
            # recorded as done without ever reaching the checkpoint.
            if self._checkpoint_mgr:
                self._checkpoint_mgr.append_record(
                    step_index, step_name, output_record
                )
        except Exception as e:
            stats.errors += 1
            self._log_llm_failures([call], e)
        else:
            output_records.append(output_record)
            progress.completed_call_ids.append(call.call_id)
            stats.generated += 1

    return stats

def _build_llm_call_metadata(
    self,
    step: "LLMStep",
    step_name: str,
    call: LLMCall,
    model: object,
) -> dict[str, object]:
    """Assemble the trace metadata attached to a single LLM call.

    Args:
        step: The LLM step that owns the call; its class name is recorded.
        step_name: Step name, also used to derive the trace name.
        call: The call whose indexes, language code and id are recorded.
        model: Model object the call is sent to.

    Returns:
        Whatever ``build_trace_metadata`` returns for these fields.
    """
    trace_fields = {
        "model": model,
        "component": "pipeline.step",
        "trace_name": f"datafast.{step_name}",
        "session_id": self._trace_session_id,
        "step_name": step_name,
        "step_type": step.__class__.__name__,
        "record_index": call.record_index,
        "prompt_index": call.prompt_index,
        "output_index": call.output_index,
        # A falsy language code is passed on as None.
        "language_code": call.language_code or None,
        "call_id": call.call_id,
    }
    return build_trace_metadata(**trace_fields)

@staticmethod
def _log_llm_failures(calls: list[LLMCall], error: Exception) -> None:
    """Emit one warning per call, all attributed to the same error.

    Args:
        calls: Calls to report as failed.
        error: The exception whose text is included in every warning.
    """
    for failed_call in calls:
        message = (
            f"LLM call failed | Model: {failed_call.model_id} | "
            f"Call: {failed_call.call_id} | Error: {error}"
        )
        logger.warning(message)

def _order_calls(self, calls: list[LLMCall]) -> list[LLMCall]:
"""Order calls according to execution strategy."""
strategy = self.config.llm_strategy
Expand Down
Loading