Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions examples/api_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Example usage of the ergonomic `substrait.api` facade.

Compare this with `builder_example.py`, which builds the same kinds of plans
with the lower-level `substrait.builders.*` functions.
"""

import substrait.api as sub
from substrait.utils.display import pretty_print_plan


def filter_select_example():
"""read -> filter -> with_columns -> select, with operator expressions."""
plan = (
sub.read_named_table(
"people", {"id": sub.i64, "age": sub.i64, "name": sub.string}
)
.filter((sub.col("age") > 25) & sub.col("name").is_not_null())
.with_columns(next_year=sub.col("age") + 1)
.select("id", "name", "next_year")
.to_plan()
)
pretty_print_plan(plan, use_colors=True)


def aggregate_example():
"""group_by().agg() with the named-function namespace `f`.

Note the explicit nullability: ``region`` is required, ``amount`` nullable.
``amount > 0`` also shows literal coercion -- the int ``0`` is typed to match
the fp64 column so the comparison overload resolves.
"""
plan = (
sub.read_named_table(
"sales", {"region": sub.string.non_null, "amount": sub.fp64}
)
.filter(sub.col("amount") > 0)
.group_by("region")
.agg(
sub.f.sum(sub.col("amount")).alias("total"),
sub.f.count(sub.col("amount")).alias("n"),
)
.to_plan()
)
pretty_print_plan(plan, use_colors=True)


def join_example():
"""Join two tables and project across the combined schema."""
customers = sub.read_named_table(
"customers", {"cust_id": sub.i64, "name": sub.string}
)
orders = sub.read_named_table(
"orders", {"order_id": sub.i64, "cust_ref": sub.i64, "amount": sub.fp64}
)
plan = (
customers.join(
orders, on=sub.col("cust_id") == sub.col("cust_ref"), how="inner"
)
.select("name", "amount")
.sort("amount", descending=True)
.limit(10)
.to_plan()
)
pretty_print_plan(plan, use_colors=True)


if __name__ == "__main__":
print("=== filter / with_columns / select ===")
filter_select_example()
print("\n=== group_by / agg ===")
aggregate_example()
print("\n=== join / sort / limit ===")
join_example()
17 changes: 0 additions & 17 deletions examples/dataframe_example.py

This file was deleted.

20 changes: 11 additions & 9 deletions examples/narwhals_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Install duckdb and pyarrow before running this example
# Example of the `substrait.narwhals` integration layer: drive Substrait plan
# construction through Narwhals (`nw.from_native`), so backend-agnostic Narwhals
# code compiles to a Substrait plan.
#
# For building plans directly (without Narwhals), see `api_example.py`, which
# uses the Substrait-native DataFrame in `substrait.api` / `substrait.frame`.
#
# /// script
# dependencies = [
# "narwhals==2.9.0",
Expand All @@ -9,7 +15,7 @@
import narwhals as nw
from narwhals.typing import FrameT

import substrait.dataframe as sdf
import substrait.narwhals as sn
from substrait.builders.plan import read_named_table
from substrait.builders.type import boolean, i64, named_struct, struct
from substrait.extension_registry import ExtensionRegistry
Expand All @@ -21,15 +27,11 @@
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
)

table = read_named_table("example_table", ns)


lazy_frame: FrameT = nw.from_native(
sdf.DataFrame(read_named_table("example_table", ns))
)
# Wrap the Substrait Narwhals backend and drive it with the Narwhals API.
lazy_frame: FrameT = nw.from_native(sn.DataFrame(read_named_table("example_table", ns)))

lazy_frame = lazy_frame.select(nw.col("id").abs(), new_id=nw.col("id"))

df: sdf.DataFrame = lazy_frame.to_native()
df: sn.DataFrame = lazy_frame.to_native()

print(df.to_substrait(registry))
106 changes: 106 additions & 0 deletions src/substrait/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Ergonomic front door for substrait-python.

A single, shallow import that gets you productive::

import substrait.api as sub

plan = (
sub.read_named_table("people", {"id": sub.i64, "age": sub.i64, "name": sub.string})
.filter(sub.col("age") > 25)
.with_columns(adult=sub.col("age") >= 18)
.select("id", "name")
.to_plan()
)

This lives as a *submodule* (``substrait.api``) rather than the package root on
purpose: ``substrait`` is a PEP 420 namespace package shared with the
``substrait-protobuf`` distribution, so adding ``substrait/__init__.py`` would
shadow ``substrait.algebra_pb2`` and friends.

Everything here is an additive facade over the existing ``substrait.builders``,
``substrait.extension_registry`` and ``substrait.proto`` layers, which remain
available and unchanged.
"""

from __future__ import annotations

# Parametrized type builders (need arguments; kept as plain builder functions).
from substrait.builders.type import (
decimal,
fixed_binary,
fixed_char,
interval_compound,
interval_day,
named_struct,
precision_time,
precision_timestamp,
precision_timestamp_tz,
struct,
)
from substrait.builders.type import list as list_ # `list`/`map` shadow builtins
from substrait.builders.type import map as map_
from substrait.builders.type import var_char as varchar # spec spelling

# Primitive / no-argument type shortcuts as nullability-aware DataType objects
# (sub.i64 -> nullable; sub.i64.non_null -> required; sub.i64() still callable).
from substrait.dtypes import (
DataType,
binary,
boolean,
date,
fp32,
fp64,
i8,
i16,
i32,
i64,
interval_year,
string,
uuid,
)
from substrait.expr import Expr, col, infer_literal_type, lit
from substrait.extension_registry import ExtensionRegistry
from substrait.frame import DataFrame, default_registry, read_named_table
from substrait.functions import f, functions_for

__all__ = [
# entry points
"read_named_table",
"DataFrame",
"col",
"lit",
"f",
"functions_for",
"Expr",
# registry
"ExtensionRegistry",
"default_registry",
# types
"boolean",
"i8",
"i16",
"i32",
"i64",
"fp32",
"fp64",
"string",
"binary",
"date",
"uuid",
"interval_year",
"interval_day",
"interval_compound",
"fixed_char",
"varchar",
"fixed_binary",
"decimal",
"precision_time",
"precision_timestamp",
"precision_timestamp_tz",
"struct",
"named_struct",
"list_",
"map_",
"DataType",
"infer_literal_type",
]
Loading
Loading