From 888da2c88d48744054f68587068c7c11fd636348 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Thu, 2 Jul 2026 17:21:57 +0200 Subject: [PATCH 1/4] feat(builders): support literals for all Substrait types + registry.iter_functions Extend the builder literal() to construct a literal for every Substrait type (decimal, uuid, precision time/timestamp[_tz], all interval kinds, struct, list, map with empty-list/empty-map handling, and typed nulls via value=None) through a new recursive _make_literal helper. Existing kinds remain byte-identical. Add the missing precision_time case to type_inference.infer_literal_type so every kind round-trips, and add ExtensionRegistry.iter_functions() to enumerate every registered (urn, name, function_type). --- src/substrait/builders/extended_expression.py | 324 +++++++++++++----- src/substrait/extension_registry/registry.py | 12 + src/substrait/type_inference.py | 6 + 3 files changed, 249 insertions(+), 93 deletions(-) diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index 26c7f10..7d35bf6 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -1,5 +1,8 @@ +import calendar import itertools -from datetime import date +import uuid as uuid_module +from datetime import date, datetime, time, timedelta, timezone +from decimal import ROUND_HALF_EVEN, Decimal from typing import Any, Callable, Iterable, Union import substrait.algebra_pb2 as stalg @@ -44,108 +47,243 @@ def resolve_expression( ) +_EPOCH_DATE = date(1970, 1, 1) + + +def _scale_subseconds(microseconds: int, precision: int) -> int: + """Convert a microsecond count to ``precision`` sub-second units.""" + if precision >= 6: + return microseconds * 10 ** (precision - 6) + return microseconds // 10 ** (6 - precision) + + +def _encode_decimal(value: Any, scale: int) -> bytes: + """Encode a decimal as the 16-byte little-endian two's-complement unscaled value.""" + dec = 
value if isinstance(value, Decimal) else Decimal(str(value)) + unscaled = int((dec * (Decimal(10) ** scale)).to_integral_value(ROUND_HALF_EVEN)) + return unscaled.to_bytes(16, byteorder="little", signed=True) + + +def _encode_uuid(value: Any) -> bytes: + if isinstance(value, uuid_module.UUID): + return value.bytes + if isinstance(value, str): + return uuid_module.UUID(value).bytes + if isinstance(value, (bytes, bytearray)): + if len(value) != 16: + raise ValueError("uuid literal must be exactly 16 bytes") + return bytes(value) + raise TypeError(f"cannot build a uuid literal from {type(value).__name__}") + + +def _timestamp_units(value: Any, precision: int) -> int: + """Sub-second units since the Unix epoch for an int or datetime value.""" + if isinstance(value, datetime): + if value.tzinfo is not None: + value = value.astimezone(timezone.utc) + micros = calendar.timegm(value.timetuple()) * 1_000_000 + value.microsecond + return _scale_subseconds(micros, precision) + return value + + +def _time_units(value: Any, precision: int) -> int: + """Sub-second units since midnight for an int or datetime.time value.""" + if isinstance(value, time): + micros = ( + value.hour * 3600 + value.minute * 60 + value.second + ) * 1_000_000 + value.microsecond + return _scale_subseconds(micros, precision) + return value + + +def _interval_day_to_second(value: Any, precision: int): + """Build an IntervalDayToSecond from a timedelta or a (days, seconds[, subseconds]) tuple.""" + if isinstance(value, timedelta): + days, seconds, subseconds = ( + value.days, + value.seconds, + _scale_subseconds(value.microseconds, precision), + ) + else: + days, seconds, *rest = value + subseconds = rest[0] if rest else 0 + return stalg.Expression.Literal.IntervalDayToSecond( + days=days, seconds=seconds, subseconds=subseconds, precision=precision + ) + + +def _interval_year_to_month(value: Any): + """Build an IntervalYearToMonth from an int (years) or a (years, months) tuple.""" + if isinstance(value, 
(tuple, list)): + years, months = value + else: + years, months = value, 0 + return stalg.Expression.Literal.IntervalYearToMonth(years=years, months=months) + + +def _make_literal(value: Any, type: stp.Type) -> stalg.Expression.Literal: + """Recursively build an ``Expression.Literal`` for ``value`` of ``type``. + + A ``value`` of ``None`` produces a typed null literal of ``type``. Nested + types (struct/list/map) recurse into their element types. Supported value + representations for the less-obvious kinds: + + - decimal: ``decimal.Decimal`` / ``int`` / ``float`` / ``str`` + - uuid: ``uuid.UUID`` / 16 ``bytes`` / hex ``str`` + - precision_timestamp[_tz]: ``int`` sub-second units, or ``datetime`` + - precision_time: ``int`` sub-second units, or ``datetime.time`` + - interval_year: ``int`` years or ``(years, months)`` + - interval_day: ``datetime.timedelta`` or ``(days, seconds[, subseconds])`` + - interval_compound: ``((years, months), (days, seconds[, subseconds]))`` + - struct: sequence of field values; list: sequence; map: ``dict`` or pairs + """ + Literal = stalg.Expression.Literal + + if value is None: + return Literal(null=type, nullable=True) + + kind = type.WhichOneof("kind") + nullable = getattr(type, kind).nullability == stp.Type.NULLABILITY_NULLABLE + + if kind == "bool": + return Literal(boolean=value, nullable=nullable) + elif kind == "i8": + return Literal(i8=value, nullable=nullable) + elif kind == "i16": + return Literal(i16=value, nullable=nullable) + elif kind == "i32": + return Literal(i32=value, nullable=nullable) + elif kind == "i64": + return Literal(i64=value, nullable=nullable) + elif kind == "fp32": + return Literal(fp32=value, nullable=nullable) + elif kind == "fp64": + return Literal(fp64=value, nullable=nullable) + elif kind == "string": + return Literal(string=value, nullable=nullable) + elif kind == "binary": + return Literal(binary=value, nullable=nullable) + elif kind == "date": + date_value = (value - _EPOCH_DATE).days if 
isinstance(value, date) else value + return Literal(date=date_value, nullable=nullable) + elif kind == "interval_year": + return Literal( + interval_year_to_month=_interval_year_to_month(value), nullable=nullable + ) + elif kind == "interval_day": + return Literal( + interval_day_to_second=_interval_day_to_second( + value, type.interval_day.precision + ), + nullable=nullable, + ) + elif kind == "interval_compound": + precision = type.interval_compound.precision + ym, ds = value + return Literal( + interval_compound=stalg.Expression.Literal.IntervalCompound( + interval_year_to_month=_interval_year_to_month(ym), + interval_day_to_second=_interval_day_to_second(ds, precision), + ), + nullable=nullable, + ) + elif kind == "fixed_char": + return Literal(fixed_char=value, nullable=nullable) + elif kind == "varchar": + return Literal( + var_char=Literal.VarChar(value=value, length=type.varchar.length), + nullable=nullable, + ) + elif kind == "fixed_binary": + return Literal(fixed_binary=value, nullable=nullable) + elif kind == "decimal": + return Literal( + decimal=Literal.Decimal( + value=_encode_decimal(value, type.decimal.scale), + precision=type.decimal.precision, + scale=type.decimal.scale, + ), + nullable=nullable, + ) + elif kind == "precision_time": + precision = type.precision_time.precision + return Literal( + precision_time=Literal.PrecisionTime( + precision=precision, value=_time_units(value, precision) + ), + nullable=nullable, + ) + elif kind == "precision_timestamp": + precision = type.precision_timestamp.precision + return Literal( + precision_timestamp=Literal.PrecisionTimestamp( + precision=precision, value=_timestamp_units(value, precision) + ), + nullable=nullable, + ) + elif kind == "precision_timestamp_tz": + precision = type.precision_timestamp_tz.precision + return Literal( + precision_timestamp_tz=Literal.PrecisionTimestamp( + precision=precision, value=_timestamp_units(value, precision) + ), + nullable=nullable, + ) + elif kind == "uuid": + 
return Literal(uuid=_encode_uuid(value), nullable=nullable) + elif kind == "struct": + return Literal( + struct=Literal.Struct( + fields=[_make_literal(v, t) for v, t in zip(value, type.struct.types)] + ), + nullable=nullable, + ) + elif kind == "list": + values = list(value) + if not values: + return Literal(empty_list=type.list, nullable=nullable) + return Literal( + list=Literal.List( + values=[_make_literal(v, type.list.type) for v in values] + ), + nullable=nullable, + ) + elif kind == "map": + items = list(value.items() if isinstance(value, dict) else value) + if not items: + return Literal(empty_map=type.map, nullable=nullable) + return Literal( + map=Literal.Map( + key_values=[ + Literal.Map.KeyValue( + key=_make_literal(k, type.map.key), + value=_make_literal(v, type.map.value), + ) + for k, v in items + ] + ), + nullable=nullable, + ) + else: + raise Exception(f"Unknown literal type - {type}") + + def literal( value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None ) -> UnboundExtendedExpression: - """Builds a resolver for ExtendedExpression containing a literal expression""" + """Builds a resolver for ExtendedExpression containing a literal expression. + + ``value`` of ``None`` yields a typed null literal. See :func:`_make_literal` + for the accepted value representations of each type kind. 
+ """ def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: - kind = type.WhichOneof("kind") - - if kind == "bool": - literal = stalg.Expression.Literal( - boolean=value, - nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "i8": - literal = stalg.Expression.Literal( - i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE - ) - elif kind == "i16": - literal = stalg.Expression.Literal( - i16=value, - nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "i32": - literal = stalg.Expression.Literal( - i32=value, - nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "i64": - literal = stalg.Expression.Literal( - i64=value, - nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "fp32": - literal = stalg.Expression.Literal( - fp32=value, - nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "fp64": - literal = stalg.Expression.Literal( - fp64=value, - nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "string": - literal = stalg.Expression.Literal( - string=value, - nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "binary": - literal = stalg.Expression.Literal( - binary=value, - nullable=type.binary.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "date": - date_value = ( - (value - date(1970, 1, 1)).days if isinstance(value, date) else value - ) - literal = stalg.Expression.Literal( - date=date_value, - nullable=type.date.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - # TODO - # IntervalYearToMonth interval_year_to_month = 19; - # IntervalDayToSecond interval_day_to_second = 20; - # IntervalCompound interval_compound = 36; - elif kind == "fixed_char": - literal = stalg.Expression.Literal( - fixed_char=value, - nullable=type.fixed_char.nullability == 
stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "varchar": - literal = stalg.Expression.Literal( - var_char=stalg.Expression.Literal.VarChar( - value=value, length=type.varchar.length - ), - nullable=type.varchar.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - elif kind == "fixed_binary": - literal = stalg.Expression.Literal( - fixed_binary=value, - nullable=type.fixed_binary.nullability == stp.Type.NULLABILITY_NULLABLE, - ) - # TODO - # Decimal decimal = 24; - # PrecisionTime precision_time = 37; // Time in precision units past midnight. - # PrecisionTimestamp precision_timestamp = 34; - # PrecisionTimestamp precision_timestamp_tz = 35; - # Struct struct = 25; - # Map map = 26; - # bytes uuid = 28; - # Type null = 29; // a typed null literal - # List list = 30; - # Type.List empty_list = 31; - # Type.Map empty_map = 32; - else: - raise Exception(f"Unknown literal type - {type}") - return stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( - expression=stalg.Expression(literal=literal), + expression=stalg.Expression(literal=_make_literal(value, type)), output_names=_alias_or_inferred(alias, "Literal", [str(value)]), ) ], diff --git a/src/substrait/extension_registry/registry.py b/src/substrait/extension_registry/registry.py index aa0b872..76d18a4 100644 --- a/src/substrait/extension_registry/registry.py +++ b/src/substrait/extension_registry/registry.py @@ -138,6 +138,18 @@ def list_functions_across_urns( def lookup_urn(self, urn: str) -> Optional[int]: return self._urn_mapping.get(urn, None) + def iter_functions(self): + """Yield ``(urn, name, function_type)`` for every registered function. + + One tuple per ``(urn, name)`` group (overloads are collapsed). Useful for + discovering the full set of available functions, e.g. to build a + function-helper namespace. 
+ """ + for urn, names in self._function_mapping.items(): + for name, entries in names.items(): + if entries: + yield urn, name, entries[0].function_type + def validate_urn_format(urn: str) -> str: """Validate that a URN follows the expected format. diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py index 5331965..9e07047 100644 --- a/src/substrait/type_inference.py +++ b/src/substrait/type_inference.py @@ -79,6 +79,12 @@ def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type: nullability=nullability, ) ) + elif literal_type == "precision_time": + return stt.Type( + precision_time=stt.Type.PrecisionTime( + precision=literal.precision_time.precision, nullability=nullability + ) + ) elif literal_type == "precision_timestamp": return stt.Type( precision_timestamp=stt.Type.PrecisionTimestamp( From ec452ab75ce0906bd99231e96e28c289135ca4a5 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Thu, 2 Jul 2026 17:22:13 +0200 Subject: [PATCH 2/4] refactor(narwhals)!: rename substrait.dataframe module to substrait.narwhals The module is the Narwhals integration layer, not a general DataFrame; rename it to reflect that role and free the "DataFrame" name for the native frame. BREAKING CHANGE: import substrait.narwhals instead of substrait.dataframe. The module was a minimal, experimental Narwhals wrapper, so impact is expected low. 
--- src/substrait/dataframe/dataframe.py | 37 ------------ .../{dataframe => narwhals}/__init__.py | 8 +-- src/substrait/narwhals/dataframe.py | 58 +++++++++++++++++++ .../{dataframe => narwhals}/expression.py | 0 .../test_df_project.py | 2 +- 5 files changed, 63 insertions(+), 42 deletions(-) delete mode 100644 src/substrait/dataframe/dataframe.py rename src/substrait/{dataframe => narwhals}/__init__.py (59%) create mode 100644 src/substrait/narwhals/dataframe.py rename src/substrait/{dataframe => narwhals}/expression.py (100%) rename tests/{dataframe => narwhals}/test_df_project.py (98%) diff --git a/src/substrait/dataframe/dataframe.py b/src/substrait/dataframe/dataframe.py deleted file mode 100644 index 57f0da3..0000000 --- a/src/substrait/dataframe/dataframe.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Iterable, Union - -import substrait.dataframe -from substrait.builders.plan import select -from substrait.dataframe.expression import Expression - - -class DataFrame: - def __init__(self, plan): - self.plan = plan - self._native_frame = self - - def to_substrait(self, registry): - return self.plan(registry) - - def __narwhals_lazyframe__(self) -> "DataFrame": - """Return object implementing CompliantDataFrame protocol.""" - return self - - def __narwhals_namespace__(self): - """ - Return the namespace object that contains functions like col, lit, etc. - This is how Narwhals knows which backend's functions to use. 
- """ - return substrait.dataframe - - def select( - self, *exprs: Union[Expression, Iterable[Expression]], **named_exprs: Expression - ) -> "DataFrame": - expressions = [e.expr for e in exprs] + [ - expr.alias(alias).expr for alias, expr in named_exprs.items() - ] - return DataFrame(select(self.plan, expressions=expressions)) - - # TODO handle version - def _with_version(self, version): - return self diff --git a/src/substrait/dataframe/__init__.py b/src/substrait/narwhals/__init__.py similarity index 59% rename from src/substrait/dataframe/__init__.py rename to src/substrait/narwhals/__init__.py index 0a37d04..4e35d95 100644 --- a/src/substrait/dataframe/__init__.py +++ b/src/substrait/narwhals/__init__.py @@ -1,7 +1,7 @@ -import substrait.dataframe +import substrait.narwhals from substrait.builders.extended_expression import column -from substrait.dataframe.dataframe import DataFrame -from substrait.dataframe.expression import Expression +from substrait.narwhals.dataframe import DataFrame +from substrait.narwhals.expression import Expression __all__ = [DataFrame, Expression] @@ -13,4 +13,4 @@ def col(name: str) -> Expression: # TODO handle str_as_lit argument def parse_into_expr(expr, str_as_lit: bool): - return expr._to_compliant_expr(substrait.dataframe) + return expr._to_compliant_expr(substrait.narwhals) diff --git a/src/substrait/narwhals/dataframe.py b/src/substrait/narwhals/dataframe.py new file mode 100644 index 0000000..5476950 --- /dev/null +++ b/src/substrait/narwhals/dataframe.py @@ -0,0 +1,58 @@ +"""The Narwhals integration layer for Substrait. + +This module is the **Narwhals-compliant wrapper**: it lets ``narwhals`` drive +Substrait plan construction via ``nw.from_native(...)`` by exposing the backend +hooks (``__narwhals_lazyframe__`` / ``__narwhals_namespace__``) and translating +Narwhals calls into Substrait plan builders. 
+ +It is distinct from :mod:`substrait.frame`, which is the Substrait-*native* +fluent DataFrame you can call directly without Narwhals. This layer sits on top +of that native machinery; the two compose rather than compete. + +Status: experimental / minimal -- it currently implements only a subset of the +Narwhals compliant protocol, to be built out on top of :mod:`substrait.frame`. +""" + +from typing import Iterable, Union + +import substrait.narwhals +from substrait.builders.plan import select +from substrait.narwhals.expression import Expression + + +class DataFrame: + """Narwhals-compliant wrapper around a Substrait plan. + + Presents as a Narwhals ``LazyFrame`` backend. For direct, non-Narwhals plan + building use :class:`substrait.frame.DataFrame` instead. + """ + + def __init__(self, plan): + self.plan = plan + self._native_frame = self + + def to_substrait(self, registry): + return self.plan(registry) + + def __narwhals_lazyframe__(self) -> "DataFrame": + """Return object implementing CompliantDataFrame protocol.""" + return self + + def __narwhals_namespace__(self): + """ + Return the namespace object that contains functions like col, lit, etc. + This is how Narwhals knows which backend's functions to use. 
+ """ + return substrait.narwhals + + def select( + self, *exprs: Union[Expression, Iterable[Expression]], **named_exprs: Expression + ) -> "DataFrame": + expressions = [e.expr for e in exprs] + [ + expr.alias(alias).expr for alias, expr in named_exprs.items() + ] + return DataFrame(select(self.plan, expressions=expressions)) + + # TODO handle version + def _with_version(self, version): + return self diff --git a/src/substrait/dataframe/expression.py b/src/substrait/narwhals/expression.py similarity index 100% rename from src/substrait/dataframe/expression.py rename to src/substrait/narwhals/expression.py diff --git a/tests/dataframe/test_df_project.py b/tests/narwhals/test_df_project.py similarity index 98% rename from tests/dataframe/test_df_project.py rename to tests/narwhals/test_df_project.py index acfe5b5..bc70b3d 100644 --- a/tests/dataframe/test_df_project.py +++ b/tests/narwhals/test_df_project.py @@ -2,7 +2,7 @@ import substrait.plan_pb2 as stp import substrait.type_pb2 as stt -import substrait.dataframe as sdf +import substrait.narwhals as sdf from substrait.builders.plan import default_version, read_named_table from substrait.builders.type import boolean, i64 from substrait.extension_registry import ExtensionRegistry From 13eb715f44a191a8c676a571153b27e72d5493d9 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Thu, 2 Jul 2026 17:22:34 +0200 Subject: [PATCH 3/4] feat(api): ergonomic native DataFrame/Expr API with full function and type coverage Add substrait.api, a shallow front door over the existing builders: - expr.Expr: operator overloading (comparison/arithmetic/boolean), literal auto-wrap with peer-type coercion, and .cast()/.alias()/.is_null() - frame.DataFrame: chainable filter/select/with_columns/sort/limit/join/ group_by().agg(), carrying an ExtensionRegistry so it is not threaded through every call - functions.f: every scalar/aggregate/window function, generated lazily from the registry; multi-extension names resolved by argument type; 
functions_for() and DataFrame.f expose custom-registry functions - dtypes: nullability-aware type shortcuts (sub.i64 / sub.i64.non_null) covering every concrete Substrait type The facade is faithful: it emits byte-identical protobuf to the equivalent builder calls. Adds tests/api covering expressions, frame verbs, function coverage, type coverage and literal construction. --- src/substrait/api.py | 106 ++++++++++++ src/substrait/dtypes.py | 59 +++++++ src/substrait/expr.py | 312 ++++++++++++++++++++++++++++++++++++ src/substrait/frame.py | 229 ++++++++++++++++++++++++++ src/substrait/functions.py | 175 ++++++++++++++++++++ tests/api/test_dtypes.py | 233 +++++++++++++++++++++++++++ tests/api/test_expr.py | 176 ++++++++++++++++++++ tests/api/test_frame.py | 179 +++++++++++++++++++++ tests/api/test_functions.py | 194 ++++++++++++++++++++++ tests/api/test_literals.py | 197 +++++++++++++++++++++++ 10 files changed, 1860 insertions(+) create mode 100644 src/substrait/api.py create mode 100644 src/substrait/dtypes.py create mode 100644 src/substrait/expr.py create mode 100644 src/substrait/frame.py create mode 100644 src/substrait/functions.py create mode 100644 tests/api/test_dtypes.py create mode 100644 tests/api/test_expr.py create mode 100644 tests/api/test_frame.py create mode 100644 tests/api/test_functions.py create mode 100644 tests/api/test_literals.py diff --git a/src/substrait/api.py b/src/substrait/api.py new file mode 100644 index 0000000..c2ab349 --- /dev/null +++ b/src/substrait/api.py @@ -0,0 +1,106 @@ +"""Ergonomic front door for substrait-python. 
+ +A single, shallow import that gets you productive:: + + import substrait.api as sub + + plan = ( + sub.read_named_table("people", {"id": sub.i64, "age": sub.i64, "name": sub.string}) + .filter(sub.col("age") > 25) + .with_columns(adult=sub.col("age") >= 18) + .select("id", "name") + .to_plan() + ) + +This lives as a *submodule* (``substrait.api``) rather than the package root on +purpose: ``substrait`` is a PEP 420 namespace package shared with the +``substrait-protobuf`` distribution, so adding ``substrait/__init__.py`` would +shadow ``substrait.algebra_pb2`` and friends. + +Everything here is an additive facade over the existing ``substrait.builders``, +``substrait.extension_registry`` and ``substrait.proto`` layers, which remain +available and unchanged. +""" + +from __future__ import annotations + +# Parametrized type builders (need arguments; kept as plain builder functions). +from substrait.builders.type import ( + decimal, + fixed_binary, + fixed_char, + interval_compound, + interval_day, + named_struct, + precision_time, + precision_timestamp, + precision_timestamp_tz, + struct, +) +from substrait.builders.type import list as list_ # `list`/`map` shadow builtins +from substrait.builders.type import map as map_ +from substrait.builders.type import var_char as varchar # spec spelling + +# Primitive / no-argument type shortcuts as nullability-aware DataType objects +# (sub.i64 -> nullable; sub.i64.non_null -> required; sub.i64() still callable). 
+from substrait.dtypes import ( + DataType, + binary, + boolean, + date, + fp32, + fp64, + i8, + i16, + i32, + i64, + interval_year, + string, + uuid, +) +from substrait.expr import Expr, col, infer_literal_type, lit +from substrait.extension_registry import ExtensionRegistry +from substrait.frame import DataFrame, default_registry, read_named_table +from substrait.functions import f, functions_for + +__all__ = [ + # entry points + "read_named_table", + "DataFrame", + "col", + "lit", + "f", + "functions_for", + "Expr", + # registry + "ExtensionRegistry", + "default_registry", + # types + "boolean", + "i8", + "i16", + "i32", + "i64", + "fp32", + "fp64", + "string", + "binary", + "date", + "uuid", + "interval_year", + "interval_day", + "interval_compound", + "fixed_char", + "varchar", + "fixed_binary", + "decimal", + "precision_time", + "precision_timestamp", + "precision_timestamp_tz", + "struct", + "named_struct", + "list_", + "map_", + "DataType", + "infer_literal_type", +] diff --git a/src/substrait/dtypes.py b/src/substrait/dtypes.py new file mode 100644 index 0000000..867544f --- /dev/null +++ b/src/substrait/dtypes.py @@ -0,0 +1,59 @@ +"""Nullability-aware type shortcuts for the ergonomic API. + +The lower-level ``substrait.builders.type`` builders take a ``nullable`` keyword +that defaults to ``True``, which is easy to apply silently. 
``DataType`` wraps a +primitive builder so nullability is explicit and reads well -- inspired by +substrait-java's ``N`` (nullable) / ``R`` (required) ``TypeCreator`` constants:: + + sub.i64 # bare: nullable (the safe default) when used in a schema + sub.i64.nullable # explicitly nullable + sub.i64.non_null # required / non-nullable + sub.i64() # still callable, for parity with the builder layer + sub.i64(nullable=False) + +A ``DataType`` is callable, so anywhere a zero-arg type builder is accepted +(schema dicts, ``lit``) a bare ``sub.i64`` keeps working and yields a nullable +type; ``sub.i64.non_null`` yields a ready-made non-nullable ``proto.Type``. +""" + +from __future__ import annotations + +import substrait.type_pb2 as stp + +from substrait.builders import type as _t + + +class DataType: + __slots__ = ("_name", "_builder") + + def __init__(self, name: str, builder): + self._name = name + self._builder = builder + + def __call__(self, nullable: bool = True) -> stp.Type: + return self._builder(nullable) + + @property + def nullable(self) -> stp.Type: + return self._builder(True) + + @property + def non_null(self) -> stp.Type: + return self._builder(False) + + def __repr__(self) -> str: # pragma: no cover - debugging aid + return f"<DataType {self._name}>" + + +boolean = DataType("boolean", _t.boolean) +i8 = DataType("i8", _t.i8) +i16 = DataType("i16", _t.i16) +i32 = DataType("i32", _t.i32) +i64 = DataType("i64", _t.i64) +fp32 = DataType("fp32", _t.fp32) +fp64 = DataType("fp64", _t.fp64) +string = DataType("string", _t.string) +binary = DataType("binary", _t.binary) +date = DataType("date", _t.date) +uuid = DataType("uuid", _t.uuid) +interval_year = DataType("interval_year", _t.interval_year) diff --git a/src/substrait/expr.py b/src/substrait/expr.py new file mode 100644 index 0000000..0273760 --- /dev/null +++ b/src/substrait/expr.py @@ -0,0 +1,312 @@ +"""Ergonomic expression wrapper. 
+ +``Expr`` wraps the existing "unbound" expression callables produced by +``substrait.builders.extended_expression`` and adds Python operator overloading +so that expressions can be written the way users of pandas / Polars / PySpark / +Ibis expect:: + + col("age") > 25 + (col("x") + col("y")) * 2 + col("a").is_null() & col("b") + +Each operator maps to a fixed standard function-extension URN + signature name +and defers to the existing ``scalar_function`` builder, which already resolves +the concrete overload lazily against an ``ExtensionRegistry``. Nothing here +reimplements resolution or type inference -- it is a thin, additive facade. +""" + +from __future__ import annotations + +import uuid as _uuid +from datetime import date as _date +from datetime import datetime as _datetime +from datetime import time as _time +from decimal import Decimal as _Decimal +from typing import Any, Union + +import substrait.type_pb2 as stp + +from substrait.builders import type as _t +from substrait.builders.extended_expression import ( + UnboundExtendedExpression, + cast, + column, + literal, + scalar_function, +) +from substrait.type_inference import infer_extended_expression_schema + +# Standard Substrait function-extension URNs used by the operators below. 
+FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison" +FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic" +FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean" +FUNCTIONS_STRING = "extension:io.substrait:functions_string" +FUNCTIONS_AGGREGATE_GENERIC = "extension:io.substrait:functions_aggregate_generic" + + +def _decimal_type(value: _Decimal) -> stp.Type: + exponent = value.as_tuple().exponent + if not isinstance(exponent, int): # NaN / Infinity have symbolic exponents + raise TypeError("cannot infer a decimal literal type from a non-finite Decimal") + scale = -exponent if exponent < 0 else 0 + precision = max(len(value.as_tuple().digits), scale, 1) + return _t.decimal(scale, precision) + + +def infer_literal_type(value: Any) -> stp.Type: + """Best-effort mapping from a Python scalar to a Substrait type. + + Used to auto-wrap bare Python values on the right-hand side of an operator, + e.g. the ``25`` in ``col("age") > 25``. ``bool`` is checked before ``int`` + (``isinstance(True, int)`` is ``True``) and ``datetime`` before ``date`` + (``datetime`` subclasses ``date``). + """ + if isinstance(value, bool): + return _t.boolean() + if isinstance(value, int): + return _t.i64() + if isinstance(value, float): + return _t.fp64() + if isinstance(value, _Decimal): + return _decimal_type(value) + if isinstance(value, str): + return _t.string() + if isinstance(value, (bytes, bytearray)): + return _t.binary() + if isinstance(value, _datetime): + # microsecond precision; tz-aware values map to the *_tz variant. + return ( + _t.precision_timestamp_tz(6) + if value.tzinfo is not None + else _t.precision_timestamp(6) + ) + if isinstance(value, _date): + return _t.date() + if isinstance(value, _time): + return _t.precision_time(6) + if isinstance(value, _uuid.UUID): + return _t.uuid() + raise TypeError( + f"Cannot infer a Substrait literal type for {value!r} " + f"({type(value).__name__}); wrap it with lit(value, <type>) instead." 
# Bit widths of the signed integer kinds a bare Python int may be narrowed to.
_INT_KIND_BITS = {"i8": 8, "i16": 16, "i32": 32, "i64": 64}


def _match_numeric_type(peer_type: stp.Type, value: Any) -> stp.Type:
    """Pick a literal type for ``value`` that matches a numeric ``peer_type``.

    Substrait does not implicitly coerce mixed numeric operands, so
    ``col("price_fp64") * 2`` needs the ``2`` typed as ``fp64`` rather than the
    default ``i64`` for the ``multiply`` overload to resolve.

    Narrowing is only applied when it is lossless:

    * a ``float`` value always stays floating point (``fp32`` only when the
      peer is ``fp32``, otherwise ``fp64``);
    * an ``int`` value is typed as the peer's integer kind only when it fits
      that kind's signed range; otherwise it keeps the default ``i64`` instead
      of producing an out-of-range ``i8``/``i16``/``i32`` literal.

    Args:
        peer_type: Type of the expression operand the literal is paired with.
        value: The bare Python ``int`` or ``float`` operand.

    Returns:
        The type to build the literal with; ``i64`` when the peer is not one of
        the kinds in ``_NUMERIC_BUILDERS``.
    """
    kind = peer_type.WhichOneof("kind")
    if isinstance(value, float):
        return _t.fp32() if kind == "fp32" else _t.fp64()
    bits = _INT_KIND_BITS.get(kind)
    if bits is not None and isinstance(value, int):
        limit = 1 << (bits - 1)
        if not -limit <= value < limit:
            # Does not fit the peer's width: do not silently truncate.
            return _t.i64()
    builder = _NUMERIC_BUILDERS.get(kind)
    return builder() if builder else _t.i64()


def _numeric_binary(
    self_expr: "Expr", other: Any, urn: str, fn: str, *, swap: bool = False
) -> "Expr":
    """Build a binary comparison/arithmetic expression with literal coercion.

    A bare Python number is typed to match the *other* (expression) operand at
    resolve time, so mixed-width numeric comparisons and arithmetic resolve
    against the standard extension overloads.

    Args:
        self_expr: The expression the operator was invoked on.
        other: The second operand; an ``Expr`` or a plain Python value.
        urn: Extension URN that defines ``fn``.
        fn: Function name, e.g. ``"add"`` or ``"lt"``.
        swap: ``True`` for reflected operators (e.g. ``100 - col("a")``), which
            places ``other`` on the left so operand order is kept intact.

    Returns:
        A lazily-resolved ``Expr`` wrapping the scalar function call.
    """
    left_operand = other if swap else self_expr
    right_operand = self_expr if swap else other

    def resolve(base_schema, registry):
        def bind(operand):
            # Expr operands are bound now so their type can be inspected;
            # plain values are passed through and typed further below.
            if isinstance(operand, Expr):
                return operand._unbound(base_schema, registry), True
            return operand, False

        left_val, left_is_expr = bind(left_operand)
        right_val, right_is_expr = bind(right_operand)

        # The peer is the first expression operand (left preferred).
        peer = None
        if left_is_expr:
            peer = infer_extended_expression_schema(left_val).types[0]
        elif right_is_expr:
            peer = infer_extended_expression_schema(right_val).types[0]

        def as_bound(value, is_expr):
            if is_expr:
                return value
            is_number = not isinstance(value, bool) and isinstance(
                value, (int, float)
            )
            if is_number and peer is not None:
                lit_type = _match_numeric_type(peer, value)
                return literal(value, lit_type)(base_schema, registry)
            # bool / str / other values keep their inferred literal type.
            return Expr._coerce(value)._unbound(base_schema, registry)

        left_bound = as_bound(left_val, left_is_expr)
        right_bound = as_bound(right_val, right_is_expr)
        return scalar_function(urn, fn, expressions=[left_bound, right_bound])(
            base_schema, registry
        )

    return Expr(resolve)


class Expr:
    """A composable, unbound Substrait expression.

    Wraps a builder callable (invoked as ``unbound(base_schema, registry)``)
    and overloads Python operators so expressions can be combined naturally.
    """

    __slots__ = ("_unbound",)

    def __init__(self, unbound: UnboundExtendedExpression):
        self._unbound = unbound

    @property
    def unbound(self) -> UnboundExtendedExpression:
        """The underlying builder callable, for interop with the builder layer."""
        return self._unbound

    @staticmethod
    def _coerce(value: Union["Expr", Any]) -> "Expr":
        """Return ``value`` unchanged if it is an ``Expr``, else wrap it as a literal."""
        if isinstance(value, Expr):
            return value
        return Expr(literal(value, infer_literal_type(value)))

    def _scalar(self, urn: str, fn: str, *others: Any) -> "Expr":
        """Call scalar function ``fn`` with ``self`` followed by ``others``."""
        args = [self._unbound] + [Expr._coerce(o)._unbound for o in others]
        return Expr(scalar_function(urn, fn, expressions=args))

    # -- comparison -------------------------------------------------------
    def __lt__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "lt")

    def __le__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "lte")

    def __gt__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "gt")

    def __ge__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "gte")

    def __eq__(self, other: Any) -> "Expr":  # type: ignore[override]
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "equal")

    def __ne__(self, other: Any) -> "Expr":  # type: ignore[override]
        return _numeric_binary(self, other, FUNCTIONS_COMPARISON, "not_equal")

    # Operator-overloaded ``__eq__`` means an Expr is not a normal value; like
    # pandas/Polars expressions it is intentionally not hashable.
    __hash__ = None  # type: ignore[assignment]

    # -- arithmetic -------------------------------------------------------
    def __add__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "add")

    def __sub__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "subtract")

    def __mul__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "multiply")

    def __truediv__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "divide")

    def __radd__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "add", swap=True)

    def __rsub__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "subtract", swap=True)

    def __rmul__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "multiply", swap=True)

    def __rtruediv__(self, other: Any) -> "Expr":
        return _numeric_binary(self, other, FUNCTIONS_ARITHMETIC, "divide", swap=True)

    def __neg__(self) -> "Expr":
        return Expr(
            scalar_function(FUNCTIONS_ARITHMETIC, "negate", expressions=[self._unbound])
        )

    # -- boolean logic ----------------------------------------------------
    def __and__(self, other: Any) -> "Expr":
        return self._scalar(FUNCTIONS_BOOLEAN, "and", other)

    def __or__(self, other: Any) -> "Expr":
        return self._scalar(FUNCTIONS_BOOLEAN, "or", other)

    # Reflected forms mirror the reflected arithmetic operators so a plain
    # value on the left (e.g. ``True & col("flag")``) keeps its position.
    def __rand__(self, other: Any) -> "Expr":
        return Expr._coerce(other)._scalar(FUNCTIONS_BOOLEAN, "and", self)

    def __ror__(self, other: Any) -> "Expr":
        return Expr._coerce(other)._scalar(FUNCTIONS_BOOLEAN, "or", self)

    def __invert__(self) -> "Expr":
        return Expr(
            scalar_function(FUNCTIONS_BOOLEAN, "not", expressions=[self._unbound])
        )

    # -- helpers ----------------------------------------------------------
    def is_null(self) -> "Expr":
        """Build ``is_null(self)`` from the comparison extension."""
        return Expr(
            scalar_function(
                FUNCTIONS_COMPARISON, "is_null", expressions=[self._unbound]
            )
        )

    def is_not_null(self) -> "Expr":
        """Build ``is_not_null(self)`` from the comparison extension."""
        return Expr(
            scalar_function(
                FUNCTIONS_COMPARISON, "is_not_null", expressions=[self._unbound]
            )
        )

    def cast(self, type: Any) -> "Expr":
        """Cast this expression to ``type`` (a proto.Type or a type builder).

        The explicit escape hatch when automatic literal coercion is not enough,
        e.g. between two columns of different numeric types::

            col("small_i32").cast(sub.i64) + col("big_i64")
        """
        if callable(type):  # allow a bare builder / DataType, e.g. sub.i64
            type = type()
        return Expr(cast(self._unbound, type))

    def alias(self, name: str) -> "Expr":
        """Return a copy of this expression with its output name set to ``name``.

        The first output name of the first referred expression is replaced; if
        the bound expression carries no output name yet, ``name`` is added
        instead of failing with ``IndexError``.
        """
        inner = self._unbound

        def resolve(base_schema, registry):
            bound = inner(base_schema, registry)
            names = bound.referred_expr[0].output_names
            if len(names):
                names[0] = name
            else:
                names.append(name)
            return bound

        return Expr(resolve)

    def __repr__(self) -> str:  # pragma: no cover - debugging aid
        return "Expr()"


def col(name: Union[str, int]) -> Expr:
    """Reference an input column by name or index."""
    return Expr(column(name))


def lit(value: Any, type: Union[stp.Type, None] = None) -> Expr:
    """A literal expression. The Substrait type is inferred when omitted.

    Pass ``value=None`` to build a typed null; a ``type`` is required in that
    case since there is nothing to infer from.

    Raises:
        TypeError: If both ``value`` and ``type`` are ``None``.
    """
    if type is None:
        if value is None:
            raise TypeError("lit(None) needs an explicit type, e.g. lit(None, sub.i64)")
        type = infer_literal_type(value)
    elif callable(type):  # allow passing a bare type builder, e.g. sub.i64
        type = type()
    return Expr(literal(value, type))
def default_registry() -> ExtensionRegistry:
    """Return the shared registry preloaded with the standard extensions.

    The registry is created on first use and cached in the module-level
    ``_default_registry``.
    """
    global _default_registry
    if _default_registry is None:
        _default_registry = ExtensionRegistry(load_default_extensions=True)
    return _default_registry


def _to_named_struct(schema: Any) -> stp.NamedStruct:
    """Normalize a schema argument to a ``NamedStruct``.

    Args:
        schema: Either a ``NamedStruct`` (returned as is) or a
            ``{column_name: type}`` dict whose values are types or callables
            that build one (callables are invoked with no arguments).

    Raises:
        TypeError: If ``schema`` is neither of the accepted forms.
    """
    if isinstance(schema, stp.NamedStruct):
        return schema
    if isinstance(schema, dict):
        names = list(schema)
        types = [t() if callable(t) else t for t in schema.values()]
        return _type.named_struct(
            names=names, struct=_type.struct(types=types, nullable=False)
        )
    raise TypeError(
        "schema must be a NamedStruct or a {name: type} dict, "
        f"got {type(schema).__name__}"
    )


def _unbound(value: Any):
    """Accept an Expr, a bare column name, or an existing unbound callable."""
    if isinstance(value, Expr):
        return value.unbound
    if isinstance(value, str):
        return col(value).unbound
    return value  # assume already an unbound expression callable


class DataFrame:
    """The Substrait-native fluent DataFrame.

    Build plans directly (``df.filter(...).select(...).to_plan()``). For the
    Narwhals-driven equivalent, see :class:`substrait.narwhals.DataFrame`, which
    wraps this frame to satisfy the Narwhals backend protocol.
    """

    def __init__(self, plan, registry: Optional[ExtensionRegistry] = None):
        self._plan = plan
        # ``is None`` rather than truthiness: an explicitly supplied registry
        # must be honoured even if the object happens to evaluate as falsy.
        self._registry = registry if registry is not None else default_registry()

    def _next(self, plan) -> "DataFrame":
        """Wrap ``plan`` in a new frame that shares this frame's registry."""
        return DataFrame(plan, self._registry)

    @property
    def f(self):
        """Function namespace bound to this DataFrame's registry.

        Use this instead of the global ``sub.f`` when the DataFrame was built
        with a registry carrying custom extensions, so those functions are
        reachable by name (e.g. ``df.f.my_double(df_col)``).
        """
        cached = getattr(self, "_functions_ns", None)
        if cached is None:
            # Imported lazily, at first use, rather than at module import time.
            from substrait.functions import functions_for

            cached = functions_for(self._registry)
            self._functions_ns = cached
        return cached

    def filter(self, predicate: Union[Expr, Any]) -> "DataFrame":
        """Keep only the rows for which ``predicate`` holds."""
        return self._next(_plan.filter(self._plan, expression=_unbound(predicate)))

    def select(self, *columns: Union[str, Expr]) -> "DataFrame":
        """Replace the projection with ``columns``."""
        return self._next(
            _plan.select(self._plan, expressions=[_unbound(c) for c in columns])
        )

    def with_columns(
        self, *exprs: Union[str, Expr], **named: Union[Expr, Any]
    ) -> "DataFrame":
        """Append expressions; keyword arguments are aliased to their key."""
        expressions = [_unbound(e) for e in exprs]
        expressions += [Expr._coerce(v).alias(k).unbound for k, v in named.items()]
        return self._next(_plan.project(self._plan, expressions=expressions))

    def sort(self, *columns: Union[str, Expr], descending: bool = False) -> "DataFrame":
        """Sort by ``columns``, all in one direction, with nulls last."""
        direction = (
            stalg.SortField.SORT_DIRECTION_DESC_NULLS_LAST
            if descending
            else stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST
        )
        expressions = [(_unbound(c), direction) for c in columns]
        return self._next(_plan.sort(self._plan, expressions=expressions))

    def limit(self, n: int, offset: int = 0) -> "DataFrame":
        """Fetch ``n`` rows after skipping ``offset`` (both sent as i64 literals)."""
        return self._next(
            _plan.fetch(
                self._plan,
                offset=lit(offset, _type.i64()).unbound,
                count=lit(n, _type.i64()).unbound,
            )
        )

    def join(
        self,
        other: "DataFrame",
        on: Union[Expr, Any],
        how: str = "inner",
    ) -> "DataFrame":
        """Join with another DataFrame.

        ``on`` is an expression evaluated against the concatenation of the left
        and right schemas (columns are referenced by name across both inputs).
        ``how`` is one of ``inner``, ``left``, ``right``, ``outer``,
        ``left_semi`` or ``left_anti``.

        Raises:
            ValueError: If ``how`` is not a known join type.
        """
        try:
            join_type = _JOIN_TYPES[how]
        except KeyError:
            raise ValueError(
                f"unknown join type {how!r}; expected one of {sorted(_JOIN_TYPES)}"
            ) from None
        return self._next(
            _plan.join(self._plan, other._plan, expression=_unbound(on), type=join_type)
        )

    def group_by(self, *keys: Union[str, Expr]) -> "GroupBy":
        """Begin an aggregation; follow with ``.agg(...)``."""
        return GroupBy(self, keys)

    def aggregate(
        self,
        group_by: Union[str, Expr, Iterable[Union[str, Expr]]] = (),
        *measures: Expr,
    ) -> "DataFrame":
        """One-shot aggregation. See also the fluent ``group_by().agg()``.

        ``group_by`` may be a single key or an iterable of keys.
        """
        if isinstance(group_by, (str, Expr)):
            group_by = [group_by]
        return self._next(
            _plan.aggregate(
                self._plan,
                grouping_expressions=[_unbound(g) for g in group_by],
                measures=[_unbound(m) for m in measures],
            )
        )

    def to_plan(self):
        """Materialize to a ``substrait.proto.Plan``."""
        return self._plan(self._registry)

    # Kept for parity with the substrait.narwhals (Narwhals) wrapper's API.
    def to_substrait(self, registry: Optional[ExtensionRegistry] = None):
        """Materialize the plan, optionally against a different ``registry``."""
        return self._plan(registry if registry is not None else self._registry)


class GroupBy:
    """Intermediate returned by ``DataFrame.group_by``; call ``.agg(...)``."""

    def __init__(self, df: DataFrame, keys: Iterable[Union[str, Expr]]):
        self._df = df
        self._keys = list(keys)

    def agg(self, *measures: Expr) -> DataFrame:
        """Aggregate ``measures`` over the stored grouping keys."""
        # Single source of truth: same plan as DataFrame.aggregate(keys, ...).
        return self._df.aggregate(self._keys, *measures)


def read_named_table(
    name: Union[str, Iterable[str]],
    schema: Any,
    registry: Optional[ExtensionRegistry] = None,
) -> DataFrame:
    """Start a DataFrame from a named table and its schema.

    ``name`` is a single table name or an iterable of name parts. ``schema``
    may be a ``NamedStruct`` or a ``{column_name: type}`` dict, where each type
    is a type builder (``sub.i64``) or a ``proto.Type``.
    """
    names = [name] if isinstance(name, str) else list(name)
    return DataFrame(_plan.read_named_table(names, _to_named_struct(schema)), registry)
def _safe_name(name: str) -> str:
    """Return ``name`` with a trailing underscore if it is a Python keyword."""
    return f"{name}_" if keyword.iskeyword(name) else name


def _urn_priority(urn: str) -> int:
    """Rank base extensions ahead of their decimal/approx variants.

    Only the part after the last ``:`` is inspected. ``decimal`` adds 1 and
    ``approx`` adds 2, so a lower value means higher preference.
    """
    tail = urn.rsplit(":", 1)[-1]
    return (2 if "approx" in tail else 0) + (1 if "decimal" in tail else 0)


def _single_urn_helper(builder, urn: str, name: str):
    """Make a helper for a function that lives in exactly one extension."""

    def helper(*args: Any, alias: str | None = None) -> Expr:
        exprs = [Expr._coerce(a).unbound for a in args]
        return Expr(builder(urn, name, expressions=exprs, alias=alias))

    return helper


def _multi_urn_helper(builder, urns: list[str], name: str):
    """Make a helper for a function name shared by several extensions.

    The extension is chosen at resolve time: arguments are bound first, their
    types form the signature, and the first URN in ``urns`` whose registry
    lookup succeeds is used.
    """

    def helper(*args: Any, alias: str | None = None) -> Expr:
        exprs = [Expr._coerce(a).unbound for a in args]

        def resolve(base_schema, registry):
            bound = [resolve_expression(e, base_schema, registry) for e in exprs]
            signature = [
                typ for b in bound for typ in infer_extended_expression_schema(b).types
            ]
            for urn in urns:
                if registry.lookup_function(urn, name, signature):
                    return builder(urn, name, expressions=bound, alias=alias)(
                        base_schema, registry
                    )
            kinds = [t.WhichOneof("kind") for t in signature]
            raise Exception(
                f"No matching overload for '{name}' across {urns} "
                f"with signature {kinds}"
            )

        return Expr(resolve)

    return helper


def _build_functions(registry) -> dict:
    """Build the ``{attribute_name: helper}`` table for ``registry``.

    Functions are grouped by name across extensions. URNs are de-duplicated
    (keeping first-seen order) and then ordered by ``_urn_priority``; the sort
    is stable, so registration order breaks ties. Names that are Python
    keywords are stored under both ``name_`` and the raw name.
    """
    # name -> [function_type, {urn: None}]; the inner dict is an ordered set.
    grouped: dict = {}
    for urn, name, ftype in registry.iter_functions():
        entry = grouped.setdefault(name, [ftype, {}])
        entry[0] = ftype  # the most recently seen function type wins
        entry[1].setdefault(urn, None)

    fns: dict = {}
    for name, (ftype, seen_urns) in grouped.items():
        builder = _BUILDERS[ftype]
        urns = sorted(seen_urns, key=_urn_priority)
        if len(urns) == 1:
            helper = _single_urn_helper(builder, urns[0], name)
        else:
            helper = _multi_urn_helper(builder, urns, name)
        key = _safe_name(name)
        helper.__name__ = key
        helper.__doc__ = (
            f"Substrait {ftype.value} function '{name}' "
            f"(extensions: {', '.join(urns)})."
        )
        fns[key] = helper
        if key != name:  # keep the raw keyword name reachable via getattr
            fns[name] = helper
    return fns


class _FunctionNamespace:
    """Lazily-populated namespace of a registry's functions.

    With no registry it enumerates the default extensions; pass a registry (see
    :func:`functions_for`) to expose custom extensions registered on it too.
    """

    # Instance state; these names are never treated as function lookups.
    _STATE_ATTRS = ("_registry", "_fns")

    def __init__(self, registry=None):
        object.__setattr__(self, "_registry", registry)
        object.__setattr__(self, "_fns", None)

    def _ensure(self) -> dict:
        """Return the function table, building it on first use."""
        if self._fns is None:
            registry = self._registry
            if registry is None:
                # Imported lazily, when the default registry is first needed.
                from substrait.frame import default_registry

                registry = default_registry()
            object.__setattr__(self, "_fns", _build_functions(registry))
        return self._fns

    def __getattr__(self, item: str):
        # Only reached when normal attribute lookup fails. Dunder probes are
        # rejected without building the table. A miss on the state attributes
        # means __init__ has not run on this instance; answering it through
        # _ensure() would re-enter __getattr__ without end.
        if item in self._STATE_ATTRS or (
            item.startswith("__") and item.endswith("__")
        ):
            raise AttributeError(item)
        fns = self._ensure()
        try:
            return fns[item]
        except KeyError:
            raise AttributeError(f"no Substrait function named {item!r}") from None

    def __contains__(self, item: str) -> bool:
        return item in self._ensure()

    def __dir__(self):
        return sorted(self._ensure())


def functions_for(registry) -> _FunctionNamespace:
    """A function namespace bound to ``registry``.

    Unlike the global ``f`` (which only knows the default extensions), this
    surfaces every function on ``registry`` -- including custom extensions
    registered via ``register_extension_yaml`` / ``register_extension_dict``::

        reg = ExtensionRegistry(load_default_extensions=True)
        reg.register_extension_yaml("my_functions.yaml")
        myf = sub.functions_for(reg)
        myf.my_double(sub.col("x"))
    """
    return _FunctionNamespace(registry)


f = _FunctionNamespace()
# ---------------------------------------------------------------------------
# shared helpers
# ---------------------------------------------------------------------------


def _read_schema(plan):
    """Base schema of the read relation at the root of ``plan``."""
    return plan.relations[-1].root.input.read.base_schema


def _one_column(name, typ):
    """NamedStruct with a single column ``name`` of type ``typ``."""
    return named_struct(names=[name], struct=struct(types=[typ], nullable=False))


def _first_function(plan):
    """Scalar function of the first projected expression in ``plan``."""
    return plan.relations[-1].root.input.project.expressions[0].scalar_function


def _literal_kind(fn, position):
    """Literal kind of the argument at ``position`` of scalar function ``fn``."""
    return fn.arguments[position].value.literal.WhichOneof("literal_type")


# ---------------------------------------------------------------------------
# #1 nullability control
# ---------------------------------------------------------------------------


def test_datatype_is_callable_and_defaults_nullable():
    default_built = sub.i64()
    required_built = sub.i64(nullable=False)
    assert default_built.i64.nullability == NULLABLE
    assert required_built.i64.nullability == REQUIRED


def test_datatype_nullable_and_non_null_properties():
    assert sub.i64.nullable.i64.nullability == NULLABLE
    assert sub.i64.non_null.i64.nullability == REQUIRED


def test_bare_datatype_in_schema_is_nullable():
    built = sub.read_named_table("t", {"id": sub.i64}).to_plan()
    first_type = _read_schema(built).struct.types[0]
    assert first_type.i64.nullability == NULLABLE


def test_non_null_datatype_in_schema_is_required():
    built = sub.read_named_table("t", {"id": sub.i64.non_null}).to_plan()
    first_type = _read_schema(built).struct.types[0]
    assert first_type.i64.nullability == REQUIRED


def test_mixed_nullability_schema_matches_explicit_builder():
    fluent = sub.read_named_table(
        "t", {"id": sub.i64.non_null, "name": sub.string}
    ).to_plan()
    explicit_schema = named_struct(
        names=["id", "name"],
        struct=struct(types=[i64(nullable=False), string()], nullable=False),
    )
    explicit = b_read("t", explicit_schema)(registry)
    assert fluent.SerializeToString() == explicit.SerializeToString()


def test_datatype_repr():
    assert repr(sub.i64) == ""


def test_datatype_exported():
    assert isinstance(sub.i64, DataType)


# ---------------------------------------------------------------------------
# Full Substrait type-system coverage
# ---------------------------------------------------------------------------

# proto Type kinds intentionally NOT surfaced on the ergonomic facade:
# kinds deprecated in favor of the precision_* variants, and kinds that are
# not concrete data types / are advanced extension machinery.
_EXCLUDED_KINDS = {
    "timestamp",
    "time",
    "timestamp_tz",
    "func",
    "user_defined",
    "user_defined_type_reference",
    "alias",
}
# proto kind -> name exported on substrait.api
_KIND_TO_API = {"bool": "boolean", "varchar": "varchar", "list": "list_", "map": "map_"}


def _proto_type_kinds():
    """Names of all fields declared on the ``Type`` proto message."""
    return [field.name for field in stt.Type.DESCRIPTOR.fields]


def test_every_concrete_type_is_reachable_on_api():
    missing = [
        kind
        for kind in _proto_type_kinds()
        if kind not in _EXCLUDED_KINDS
        and _KIND_TO_API.get(kind, kind) not in sub.__all__
    ]
    assert missing == [], f"Substrait types not exposed on substrait.api: {missing}"


def test_no_arg_types_are_datatypes_with_nullability():
    accepted = ("uuid", "interval_year")
    for dt in (sub.uuid, sub.interval_year):
        assert isinstance(dt, DataType)
        assert dt.non_null.WhichOneof("kind") in accepted


@pytest.mark.parametrize(
    "typ, expected_kind",
    [
        (sub.uuid.non_null, "uuid"),
        (sub.interval_year.non_null, "interval_year"),
        (sub.interval_day(6), "interval_day"),
        (sub.interval_compound(6), "interval_compound"),
        (sub.fixed_char(10), "fixed_char"),
        (sub.varchar(10), "varchar"),
        (sub.fixed_binary(16), "fixed_binary"),
        (sub.decimal(38, 10), "decimal"),
        (sub.precision_time(6), "precision_time"),
        (sub.precision_timestamp(6), "precision_timestamp"),
        (sub.precision_timestamp_tz(6), "precision_timestamp_tz"),
    ],
)
def test_parametrized_types_build_expected_kind(typ, expected_kind):
    assert typ.WhichOneof("kind") == expected_kind


def test_parametrized_type_usable_in_schema():
    # A decimal + varchar schema round-trips through read_named_table.
    columns = {"price": sub.decimal(38, 10), "code": sub.varchar(8)}
    built = sub.read_named_table("t", columns).to_plan()
    kinds = [t.WhichOneof("kind") for t in _read_schema(built).struct.types]
    assert kinds == ["decimal", "varchar"]


# ---------------------------------------------------------------------------
# #2 literal coercion + cast
# ---------------------------------------------------------------------------


def _project_expr(ns, expr):
    """Project ``expr`` over a read of table ``t`` with schema ``ns``."""
    return select(b_read("t", ns), expressions=[expr.unbound])(registry)


def test_fp64_times_int_literal_resolves_to_fp64():
    ns = _one_column("price", fp64(nullable=False))
    fn = _first_function(_project_expr(ns, col("price") * 2))
    # The int literal was coerced to fp64 so multiply:fp64_fp64 resolves.
    assert fn.output_type.WhichOneof("kind") == "fp64"
    assert _literal_kind(fn, 1) == "fp64"


def test_i32_compared_to_int_literal_resolves():
    ns = _one_column("n", i32(nullable=False))
    # Without coercion this would try gt:i32_i64 and fail to resolve.
    fn = _first_function(_project_expr(ns, col("n") > 25))
    assert _literal_kind(fn, 1) == "i32"


def test_float_literal_not_narrowed_to_int_column():
    ns = _one_column("n", i64(nullable=False))
    # A float literal must NOT be narrowed to the integer column type; because
    # Substrait has no multiply:i64_fp64 this raises rather than silently
    # losing the fractional part. The user casts the column to bridge it.
    with pytest.raises(Exception, match="fp64"):
        _project_expr(ns, col("n") * 1.5)
    # Casting the column resolves it as fp64_fp64.
    fn = _first_function(_project_expr(ns, col("n").cast(sub.fp64) * 1.5))
    assert fn.output_type.WhichOneof("kind") == "fp64"


def test_i64_column_gt_int_literal_unchanged():
    # Backwards-compatible: the common i64 > int case is still i64_i64.
    ns = _one_column("age", i64(nullable=False))
    fn = _first_function(_project_expr(ns, col("age") > 25))
    assert _literal_kind(fn, 1) == "i64"


def test_reflected_operator_coerces_literal():
    ns = _one_column("price", fp64(nullable=False))
    fn = _first_function(_project_expr(ns, 100 - col("price")))
    # literal on the left, coerced to fp64, operand order preserved.
    assert _literal_kind(fn, 0) == "fp64"
    assert fn.arguments[1].value.HasField("selection")


def test_cast_bridges_two_column_types():
    ns = named_struct(
        names=["a", "b"],
        struct=struct(types=[i32(nullable=False), i64(nullable=False)], nullable=False),
    )
    # i32 + i64 does not resolve directly; cast makes it i64 + i64.
    add_fn = _first_function(_project_expr(ns, col("a").cast(sub.i64) + col("b")))
    assert add_fn.arguments[0].value.HasField("cast")


def test_cast_accepts_proto_type_and_builder():
    ns = _one_column("a", i32(nullable=False))
    from_builder = _project_expr(ns, col("a").cast(sub.i64))
    from_proto = _project_expr(ns, col("a").cast(i64()))
    assert from_builder.SerializeToString() == from_proto.SerializeToString()


def test_two_column_mismatch_still_raises_without_cast():
    ns = named_struct(
        names=["a", "b"],
        struct=struct(types=[i32(nullable=False), i64(nullable=False)], nullable=False),
    )
    # Coercion only applies to literals, not between two columns.
    with pytest.raises(Exception):
        _project_expr(ns, col("a") + col("b"))
def _plan_from(unbound_expr):
    """Materialize a single expression by projecting it over `schema`."""
    source = read_named_table("t", schema)
    return select(source, expressions=[unbound_expr])(registry)


def _same(lhs_unbound, rhs_unbound):
    """True when both expressions produce byte-identical plans."""
    lhs_bytes = _plan_from(lhs_unbound).SerializeToString()
    rhs_bytes = _plan_from(rhs_unbound).SerializeToString()
    return lhs_bytes == rhs_bytes


def _compare(fn, left, right):
    """Hand-written comparison call, used as the expected value."""
    return scalar_function(FUNCTIONS_COMPARISON, fn, expressions=[left, right])


@pytest.mark.parametrize(
    "op_result, urn, fn",
    [
        (col("a") < col("b"), FUNCTIONS_COMPARISON, "lt"),
        (col("a") <= col("b"), FUNCTIONS_COMPARISON, "lte"),
        (col("a") > col("b"), FUNCTIONS_COMPARISON, "gt"),
        (col("a") >= col("b"), FUNCTIONS_COMPARISON, "gte"),
        (col("a") == col("b"), FUNCTIONS_COMPARISON, "equal"),
        (col("a") != col("b"), FUNCTIONS_COMPARISON, "not_equal"),
        (col("a") + col("b"), FUNCTIONS_ARITHMETIC, "add"),
        (col("a") - col("b"), FUNCTIONS_ARITHMETIC, "subtract"),
        (col("a") * col("b"), FUNCTIONS_ARITHMETIC, "multiply"),
        (col("a") / col("b"), FUNCTIONS_ARITHMETIC, "divide"),
    ],
)
def test_binary_operator_matches_builder(op_result, urn, fn):
    expected = scalar_function(urn, fn, expressions=[column("a"), column("b")])
    assert _same(op_result.unbound, expected)


def test_boolean_operators_match_builder():
    combined = (col("a") < col("b")) & (col("a") > col("flag"))
    expected = scalar_function(
        FUNCTIONS_BOOLEAN,
        "and",
        expressions=[
            _compare("lt", column("a"), column("b")),
            _compare("gt", column("a"), column("flag")),
        ],
    )
    assert _same(combined.unbound, expected)


def test_invert_matches_not():
    negated = ~(col("a") == col("b"))
    expected = scalar_function(
        FUNCTIONS_BOOLEAN,
        "not",
        expressions=[_compare("equal", column("a"), column("b"))],
    )
    assert _same(negated.unbound, expected)


def test_literal_autowrap_on_rhs():
    # `col("a") > 25` should wrap 25 as an i64 literal.
    expected = _compare("gt", column("a"), literal(25, i64()))
    assert _same((col("a") > 25).unbound, expected)


def test_reflected_operator_puts_literal_on_left():
    expected = scalar_function(
        FUNCTIONS_ARITHMETIC, "subtract", expressions=[literal(100, i64()), column("a")]
    )
    assert _same((100 - col("a")).unbound, expected)


@pytest.mark.parametrize(
    "value, kind",
    [
        (True, "bool"),
        (5, "i64"),
        (1.5, "fp64"),
        ("x", "string"),
    ],
)
def test_infer_literal_type(value, kind):
    inferred = infer_literal_type(value)
    assert inferred.WhichOneof("kind") == kind


def test_bool_inferred_before_int():
    # isinstance(True, int) is True -- make sure bool wins.
    assert infer_literal_type(True).WhichOneof("kind") == "bool"


def test_infer_literal_type_rejects_unknown():
    with pytest.raises(TypeError):
        infer_literal_type(object())


def test_lit_accepts_bare_type_builder():
    # i64 is a callable builder; lit should call it.
    assert _same(lit(5, i64).unbound, literal(5, i64()))


def test_expr_is_not_hashable():
    # __eq__ is overloaded to build an expression, so Expr is unhashable.
    with pytest.raises(TypeError):
        hash(col("a"))


def test_alias_sets_output_name():
    aliased = (col("a") + col("b")).alias("total")
    bound = aliased.unbound(schema, registry)
    assert bound.referred_expr[0].output_names[0] == "total"


def test_is_null_builds_comparison_function():
    ns = named_struct(names=["x"], struct=struct(types=[string()], nullable=False))
    projection = select(
        read_named_table("t", ns), expressions=[col("x").is_null().unbound]
    )
    plan = projection(registry)
    # is_null resolves against functions_comparison and yields a boolean output.
    fn = plan.relations[-1].root.input.project.expressions[0].scalar_function
    assert fn.output_type.WhichOneof("kind") == "bool"


def test_arithmetic_overload_resolves_and_types_output():
    ns = named_struct(
        names=["price"], struct=struct(types=[fp64(nullable=False)], nullable=False)
    )
    # fp64 column * fp64 literal -> multiply overload resolves to an fp64 output.
    product = (col("price") * 2.0).unbound
    plan = select(read_named_table("t", ns), expressions=[product])(registry)
    fn = plan.relations[-1].root.input.project.expressions[0].scalar_function
    assert fn.output_type.WhichOneof("kind") == "fp64"
+""" + +import pytest +import substrait.algebra_pb2 as stalg + +import substrait.api as sub +from substrait.builders.extended_expression import ( + aggregate_function, + column, + literal, + scalar_function, +) +from substrait.builders.plan import aggregate as b_aggregate +from substrait.builders.plan import fetch as b_fetch +from substrait.builders.plan import filter as b_filter +from substrait.builders.plan import join as b_join +from substrait.builders.plan import read_named_table as b_read +from substrait.builders.plan import select as b_select +from substrait.builders.plan import sort as b_sort +from substrait.builders.type import fp64, i64, named_struct, string, struct +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=True) + +COMPARISON = "extension:io.substrait:functions_comparison" +ARITHMETIC = "extension:io.substrait:functions_arithmetic" + + +def people_ns(): + # Matches the {name: sub.TYPE} dict form, whose columns default to nullable. + return named_struct( + names=["id", "age", "name"], + struct=struct(types=[i64(), i64(), string()], nullable=False), + ) + + +def people_df(): + return sub.read_named_table( + "people", {"id": sub.i64, "age": sub.i64, "name": sub.string} + ) + + +def test_schema_dict_matches_named_struct(): + # A {name: type} dict must build the same NamedStruct as the explicit form.
+ from_dict = sub.read_named_table( + "people", {"id": sub.i64, "age": sub.i64, "name": sub.string} + ).to_plan() + explicit = b_read("people", people_ns())(registry) + assert from_dict.SerializeToString() == explicit.SerializeToString() + + +def test_filter_select_matches_builder(): + fluent = people_df().filter(sub.col("age") > 25).select("id").to_plan() + + raw = b_select( + b_filter( + b_read("people", people_ns()), + expression=scalar_function( + COMPARISON, "gt", expressions=[column("age"), literal(25, i64())] + ), + ), + expressions=[column("id")], + )(registry) + + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_with_columns_named_appends_projection(): + fluent = people_df().with_columns(bonus=sub.col("age") + 1).to_plan() + # ProjectRel appends: output has original columns + the new one. + root = fluent.relations[-1].root.input + assert root.HasField("project") + assert len(root.project.expressions) == 1 # the appended bonus expression + + +def test_sort_descending_matches_builder(): + fluent = people_df().sort("age", descending=True).to_plan() + raw = b_sort( + b_read("people", people_ns()), + expressions=[(column("age"), stalg.SortField.SORT_DIRECTION_DESC_NULLS_LAST)], + )(registry) + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_limit_matches_builder_fetch(): + fluent = people_df().limit(5).to_plan() + raw = b_fetch( + b_read("people", people_ns()), + offset=literal(0, i64()), + count=literal(5, i64()), + )(registry) + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_join_matches_builder(): + left_ns = named_struct( + names=["cust_id", "name"], + struct=struct(types=[i64(), string()], nullable=False), + ) + right_ns = named_struct( + names=["order_id", "cust_ref", "amount"], + struct=struct(types=[i64(), i64(), fp64()], nullable=False), + ) + + left = sub.read_named_table("customers", {"cust_id": sub.i64, "name": sub.string}) + right = sub.read_named_table( + "orders", 
{"order_id": sub.i64, "cust_ref": sub.i64, "amount": sub.fp64} + ) + fluent = left.join( + right, on=sub.col("cust_id") == sub.col("cust_ref"), how="inner" + ).to_plan() + + raw = b_join( + b_read("customers", left_ns), + b_read("orders", right_ns), + expression=scalar_function( + COMPARISON, "equal", expressions=[column("cust_id"), column("cust_ref")] + ), + type=stalg.JoinRel.JOIN_TYPE_INNER, + )(registry) + + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_join_unknown_type_raises(): + left = sub.read_named_table("a", {"x": sub.i64}) + right = sub.read_named_table("b", {"x": sub.i64}) + with pytest.raises(ValueError, match="unknown join type"): + left.join(right, on=sub.col("x") == sub.col("x"), how="banana") + + +def test_group_by_agg_matches_builder(): + ns = named_struct( + names=["region", "amount"], + struct=struct(types=[string(), fp64()], nullable=False), + ) + fluent = ( + sub.read_named_table("sales", {"region": sub.string, "amount": sub.fp64}) + .group_by("region") + .agg(sub.f.sum(sub.col("amount")).alias("total")) + .to_plan() + ) + raw = b_aggregate( + b_read("sales", ns), + grouping_expressions=[column("region")], + measures=[ + aggregate_function( + ARITHMETIC, "sum", expressions=[column("amount")], alias="total" + ) + ], + )(registry) + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_group_by_agg_equals_aggregate_oneshot(): + df = sub.read_named_table("sales", {"region": sub.string, "amount": sub.fp64}) + via_groupby = ( + df.group_by("region").agg(sub.f.sum(sub.col("amount")).alias("t")).to_plan() + ) + via_aggregate = df.aggregate( + "region", sub.f.sum(sub.col("amount")).alias("t") + ).to_plan() + assert via_groupby.SerializeToString() == via_aggregate.SerializeToString() + + +def test_default_registry_is_reused(): + assert sub.default_registry() is sub.default_registry() + + +def test_to_substrait_registry_override(): + df = people_df() + custom = 
ExtensionRegistry(load_default_extensions=True) + # Should not raise and should honor the explicit registry. + plan = df.filter(sub.col("age") > 25).to_substrait(registry=custom) + assert plan.relations diff --git a/tests/api/test_functions.py b/tests/api/test_functions.py new file mode 100644 index 0000000..e0847d1 --- /dev/null +++ b/tests/api/test_functions.py @@ -0,0 +1,194 @@ +"""Tests for the generated function namespace (substrait.functions).""" + +import keyword + +import pytest + +import substrait.api as sub +from substrait.builders.plan import consistent_partition_window +from substrait.builders.plan import read_named_table as b_read +from substrait.builders.type import fp64, i64, named_struct, string, struct +from substrait.extension_registry import ExtensionRegistry +from substrait.functions import _safe_name + +registry = ExtensionRegistry(load_default_extensions=True) + + +def _all_registry_names(): + return {name for _, name, _ in registry.iter_functions()} + + +def test_covers_every_default_function(): + # Every function the default extensions define must be reachable on f. + missing = [n for n in _all_registry_names() if _safe_name(n) not in dir(sub.f)] + assert missing == [], f"functions not exposed on f: {missing}" + + +def test_coverage_is_substantial(): + # Guard against a regression that silently drops most functions. 
+ assert len(_all_registry_names()) > 150 + assert len(dir(sub.f)) >= len(_all_registry_names()) + + +def test_every_helper_is_callable(): + for name in _all_registry_names(): + assert callable(getattr(sub.f, _safe_name(name))) + + +def test_keyword_names_are_suffixed_and_raw_reachable(): + for kw in ("and", "or", "not"): + assert keyword.iskeyword(kw) + suffixed = getattr(sub.f, kw + "_") + assert callable(suffixed) + # raw keyword name still reachable via getattr + assert getattr(sub.f, kw) is suffixed + + +def test_unknown_function_raises_attributeerror(): + with pytest.raises(AttributeError, match="no Substrait function"): + sub.f.definitely_not_a_function + + +def test_dunder_access_does_not_trigger_build(): + with pytest.raises(AttributeError): + sub.f.__wrapped__ + + +# --------------------------------------------------------------------------- +# Building / resolution +# --------------------------------------------------------------------------- + + +def _urns(plan): + return {u.urn.rsplit(":", 1)[-1] for u in plan.extension_urns} + + +def test_scalar_function_builds(): + df = sub.read_named_table("t", {"name": sub.string}) + plan = df.with_columns(u=sub.f.upper(sub.col("name"))).to_plan() + assert "functions_string" in _urns(plan) + + +def test_aggregate_function_builds(): + df = sub.read_named_table("s", {"region": sub.string, "amount": sub.fp64}) + plan = df.group_by("region").agg(sub.f.avg(sub.col("amount")).alias("a")).to_plan() + assert "functions_arithmetic" in _urns(plan) + + +def test_window_function_builds(): + ns = named_struct( + names=["x"], struct=struct(types=[i64(nullable=False)], nullable=False) + ) + plan = consistent_partition_window( + b_read("t", ns), window_functions=[sub.f.row_number().unbound] + )(registry) + assert plan.relations + + +def test_collision_int_add_uses_base_arithmetic(): + df = sub.read_named_table("t", {"a": sub.i64.non_null, "b": sub.i64.non_null}) + plan = df.with_columns(s=sub.f.add(sub.col("a"), 
sub.col("b"))).to_plan() + assert _urns(plan) == {"functions_arithmetic"} + + +def test_collision_count_uses_generic_not_decimal(): + df = sub.read_named_table("t", {"a": sub.i64.non_null}) + plan = df.group_by().agg(sub.f.count(sub.col("a")).alias("n")).to_plan() + assert _urns(plan) == {"functions_aggregate_generic"} + + +def test_multi_urn_no_matching_overload_raises(): + df = sub.read_named_table("t", {"flag": sub.boolean}) + # add is a multi-URN function but has no boolean overload anywhere. + with pytest.raises(Exception, match="No matching overload"): + df.with_columns(s=sub.f.add(sub.col("flag"), sub.col("flag"))).to_plan() + + +def test_generated_sum_matches_raw_builder(): + from substrait.builders.extended_expression import aggregate_function, column + from substrait.builders.plan import aggregate as b_aggregate + + ns = named_struct( + names=["region", "amount"], + struct=struct(types=[string(), fp64()], nullable=False), + ) + fluent = ( + sub.read_named_table("sales", {"region": sub.string, "amount": sub.fp64}) + .group_by("region") + .agg(sub.f.sum(sub.col("amount")).alias("total")) + .to_plan() + ) + raw = b_aggregate( + b_read("sales", ns), + grouping_expressions=[column("region")], + measures=[ + aggregate_function( + "extension:io.substrait:functions_arithmetic", + "sum", + expressions=[column("amount")], + alias="total", + ) + ], + )(registry) + assert fluent.SerializeToString() == raw.SerializeToString() + + +def test_helper_has_docstring_naming_extensions(): + assert "functions_string" in sub.f.upper.__doc__ + + +# --------------------------------------------------------------------------- +# Custom / user-defined extensions +# --------------------------------------------------------------------------- + +_CUSTOM_YAML = """%YAML 1.2 +--- +urn: extension:com.acme:my_functions +scalar_functions: + - name: "my_double" + description: Double an integer + impls: + - args: + - name: x + value: i64 + return: i64 +""" + + +def _custom_registry(): + reg = 
ExtensionRegistry(load_default_extensions=True) + reg.register_extension_dict(__import__("yaml").safe_load(_CUSTOM_YAML)) + return reg + + +def test_functions_for_exposes_custom_extension(): + myf = sub.functions_for(_custom_registry()) + assert "my_double" in dir(myf) + assert callable(myf.my_double) + + +def test_global_f_does_not_see_custom_extension(): + # The global f is bound to the default registry only. + assert "my_double" not in dir(sub.f) + + +def test_functions_for_builds_custom_function_plan(): + reg = _custom_registry() + myf = sub.functions_for(reg) + df = sub.read_named_table("t", {"x": sub.i64.non_null}, registry=reg) + plan = df.with_columns(d=myf.my_double(sub.col("x"))).to_plan() + assert any(u.urn == "extension:com.acme:my_functions" for u in plan.extension_urns) + + +def test_dataframe_f_is_bound_to_its_registry(): + reg = _custom_registry() + df = sub.read_named_table("t", {"x": sub.i64.non_null}, registry=reg) + # df.f is the ergonomic accessor: reachable and composable with operators. + plan = df.filter(df.f.my_double(sub.col("x")) > 10).to_plan() + assert plan.relations + assert any(u.urn == "extension:com.acme:my_functions" for u in plan.extension_urns) + + +def test_dataframe_f_is_cached(): + df = sub.read_named_table("t", {"x": sub.i64.non_null}) + assert df.f is df.f diff --git a/tests/api/test_literals.py b/tests/api/test_literals.py new file mode 100644 index 0000000..3825f51 --- /dev/null +++ b/tests/api/test_literals.py @@ -0,0 +1,197 @@ +"""Tests for literal construction across every Substrait literal kind. + +The builder ``literal()`` must be able to construct a literal for every type, +and each built literal must round-trip through ``infer_literal_type`` back to the +requested type kind. 
+""" + +import datetime as dt +import uuid +from decimal import Decimal + +import pytest +import substrait.type_pb2 as stt + +import substrait.api as sub +from substrait.builders import type as t +from substrait.builders.extended_expression import _make_literal +from substrait.type_inference import infer_literal_type + + +def _built(value, typ): + return _make_literal(value, typ) + + +# --------------------------------------------------------------------------- +# Round-trip: built literal -> inferred type kind matches the requested kind +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value, typ", + [ + (True, t.boolean()), + (5, t.i64()), + (1.5, t.fp64()), + ("hi", t.string()), + (b"\x00\x01", t.binary()), + (dt.date(2021, 1, 1), t.date()), + (Decimal("12.34"), t.decimal(2, 10)), + (uuid.uuid4(), t.uuid()), + (dt.datetime(2021, 1, 1, 12), t.precision_timestamp(6)), + (1_600_000_000_000_000, t.precision_timestamp(6)), + (dt.datetime(2021, 1, 1, tzinfo=dt.timezone.utc), t.precision_timestamp_tz(6)), + (dt.time(12, 30), t.precision_time(6)), + ("fixedchars", t.fixed_char(10)), + ("abc", t.var_char(8)), + (b"1234", t.fixed_binary(4)), + ((2, 6), t.interval_year()), + (dt.timedelta(days=1, seconds=30), t.interval_day(6)), + ((1, 30, 500), t.interval_day(6)), + (((1, 2), (3, 4, 5)), t.interval_compound(6)), + ([1, 2, 3], t.list(t.i64(nullable=False))), + ({"a": 1}, t.map(t.string(nullable=False), t.i64(nullable=False))), + ([1, "x"], t.struct([t.i64(nullable=False), t.string(nullable=False)])), + ], +) +def test_literal_kind_round_trips(value, typ): + lit = _built(value, typ) + assert infer_literal_type(lit).WhichOneof("kind") == typ.WhichOneof("kind") + + +def test_every_concrete_type_kind_can_build_a_literal(): + # A guard that no concrete type is left unsupported by literal(). 
+ samples = { + "bool": (True, t.boolean()), + "i8": (1, t.i8()), + "i16": (1, t.i16()), + "i32": (1, t.i32()), + "i64": (1, t.i64()), + "fp32": (1.0, t.fp32()), + "fp64": (1.0, t.fp64()), + "string": ("x", t.string()), + "binary": (b"x", t.binary()), + "date": (dt.date(2021, 1, 1), t.date()), + "interval_year": ((1, 0), t.interval_year()), + "interval_day": ((1, 0), t.interval_day(6)), + "interval_compound": (((1, 0), (1, 0)), t.interval_compound(6)), + "fixed_char": ("x", t.fixed_char(1)), + "varchar": ("x", t.var_char(1)), + "fixed_binary": (b"x", t.fixed_binary(1)), + "decimal": (Decimal("1"), t.decimal(0, 10)), + "precision_time": (0, t.precision_time(6)), + "precision_timestamp": (0, t.precision_timestamp(6)), + "precision_timestamp_tz": (0, t.precision_timestamp_tz(6)), + "uuid": (uuid.uuid4(), t.uuid()), + "struct": ([1], t.struct([t.i64(nullable=False)])), + "list": ([1], t.list(t.i64(nullable=False))), + "map": ({"a": 1}, t.map(t.string(nullable=False), t.i64(nullable=False))), + } + for kind, (value, typ) in samples.items(): + assert typ.WhichOneof("kind") == kind + lit = _built(value, typ) + assert lit.WhichOneof("literal_type") is not None + + +# --------------------------------------------------------------------------- +# Value encodings +# --------------------------------------------------------------------------- + + +def test_decimal_encoding_is_16_byte_little_endian_unscaled(): + lit = _built(Decimal("-12.34"), t.decimal(2, 10)) + assert len(lit.decimal.value) == 16 + assert int.from_bytes(lit.decimal.value, "little", signed=True) == -1234 + assert lit.decimal.scale == 2 + assert lit.decimal.precision == 10 + + +def test_uuid_encoding_16_bytes(): + u = uuid.uuid4() + assert _built(u, t.uuid()).uuid == u.bytes + # hex string and raw bytes accepted too + assert _built(str(u), t.uuid()).uuid == u.bytes + assert _built(u.bytes, t.uuid()).uuid == u.bytes + + +def test_precision_timestamp_from_datetime_microseconds(): + lit = _built(dt.datetime(1970, 
1, 1, 0, 0, 1), t.precision_timestamp(6)) + assert lit.precision_timestamp.value == 1_000_000 # 1s in microseconds + assert lit.precision_timestamp.precision == 6 + + +def test_precision_timestamp_tz_normalizes_to_utc(): + naive_utc = _built(dt.datetime(2021, 6, 1, 12, 0), t.precision_timestamp_tz(6)) + aware = _built( + dt.datetime(2021, 6, 1, 12, 0, tzinfo=dt.timezone.utc), + t.precision_timestamp_tz(6), + ) + assert naive_utc.precision_timestamp_tz.value == aware.precision_timestamp_tz.value + + +def test_interval_year_tuple_and_int(): + assert _built((2, 6), t.interval_year()).interval_year_to_month.months == 6 + assert _built(3, t.interval_year()).interval_year_to_month.years == 3 + + +def test_empty_list_and_map_use_empty_variants(): + assert _built([], t.list(t.i64())).WhichOneof("literal_type") == "empty_list" + assert ( + _built({}, t.map(t.string(), t.i64())).WhichOneof("literal_type") == "empty_map" + ) + + +def test_nested_struct_recurses(): + lit = _built( + [1, [2, 3]], + t.struct( + [t.i64(nullable=False), t.list(t.i64(nullable=False))], + ), + ) + assert lit.struct.fields[0].i64 == 1 + assert [v.i64 for v in lit.struct.fields[1].list.values] == [2, 3] + + +def test_typed_null(): + lit = _built(None, t.i64()) + assert lit.WhichOneof("literal_type") == "null" + assert lit.null.WhichOneof("kind") == "i64" + assert lit.nullable is True + + +# --------------------------------------------------------------------------- +# Ergonomic lit() inference +# --------------------------------------------------------------------------- + + +def _lit_kind(expr): + ee = expr.unbound(stt.NamedStruct(), sub.default_registry()) + return ee.referred_expr[0].expression.literal.WhichOneof("literal_type") + + +@pytest.mark.parametrize( + "value, expected", + [ + (Decimal("12.34"), "decimal"), + (uuid.uuid4(), "uuid"), + (dt.datetime(2021, 1, 1), "precision_timestamp"), + (dt.datetime(2021, 1, 1, tzinfo=dt.timezone.utc), "precision_timestamp_tz"), + (dt.date(2021, 1, 1), 
"date"), + (dt.time(9, 30), "precision_time"), + (b"\x00", "binary"), + ], +) +def test_lit_infers_rich_python_types(value, expected): + assert _lit_kind(sub.lit(value)) == expected + + +def test_lit_none_requires_type(): + with pytest.raises(TypeError, match="explicit type"): + sub.lit(None) + assert _lit_kind(sub.lit(None, sub.i64)) == "null" + + +def test_lit_decimal_infers_scale_and_precision(): + ee = sub.lit(Decimal("1.250")).unbound(stt.NamedStruct(), sub.default_registry()) + dec_type = infer_literal_type(ee.referred_expr[0].expression.literal).decimal + assert dec_type.scale == 3 # "1.250" has 3 fractional digits From 9cebc71932621d42aca5cecc179c1419e0ca5546 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Thu, 2 Jul 2026 17:22:51 +0200 Subject: [PATCH 4/4] docs(examples): add api_example, consolidate Narwhals example Add examples/api_example.py demonstrating the native substrait.api. Update narwhals_example.py to the renamed substrait.narwhals and label it as the Narwhals integration example. Remove dataframe_example.py, whose direct wrapper usage is superseded by api_example.py (native) and narwhals_example.py. --- examples/api_example.py | 73 +++++++++++++++++++++++++++++++++++ examples/dataframe_example.py | 17 -------- examples/narwhals_example.py | 20 +++++----- 3 files changed, 84 insertions(+), 26 deletions(-) create mode 100644 examples/api_example.py delete mode 100644 examples/dataframe_example.py diff --git a/examples/api_example.py b/examples/api_example.py new file mode 100644 index 0000000..0d76916 --- /dev/null +++ b/examples/api_example.py @@ -0,0 +1,73 @@ +"""Example usage of the ergonomic `substrait.api` facade. + +Compare this with `builder_example.py`, which builds the same kinds of plans +with the lower-level `substrait.builders.*` functions. 
+""" + +import substrait.api as sub +from substrait.utils.display import pretty_print_plan + + +def filter_select_example(): + """read -> filter -> with_columns -> select, with operator expressions.""" + plan = ( + sub.read_named_table( + "people", {"id": sub.i64, "age": sub.i64, "name": sub.string} + ) + .filter((sub.col("age") > 25) & sub.col("name").is_not_null()) + .with_columns(next_year=sub.col("age") + 1) + .select("id", "name", "next_year") + .to_plan() + ) + pretty_print_plan(plan, use_colors=True) + + +def aggregate_example(): + """group_by().agg() with the named-function namespace `f`. + + Note the explicit nullability: ``region`` is required, ``amount`` nullable. + ``amount > 0`` also shows literal coercion -- the int ``0`` is typed to match + the fp64 column so the comparison overload resolves. + """ + plan = ( + sub.read_named_table( + "sales", {"region": sub.string.non_null, "amount": sub.fp64} + ) + .filter(sub.col("amount") > 0) + .group_by("region") + .agg( + sub.f.sum(sub.col("amount")).alias("total"), + sub.f.count(sub.col("amount")).alias("n"), + ) + .to_plan() + ) + pretty_print_plan(plan, use_colors=True) + + +def join_example(): + """Join two tables and project across the combined schema.""" + customers = sub.read_named_table( + "customers", {"cust_id": sub.i64, "name": sub.string} + ) + orders = sub.read_named_table( + "orders", {"order_id": sub.i64, "cust_ref": sub.i64, "amount": sub.fp64} + ) + plan = ( + customers.join( + orders, on=sub.col("cust_id") == sub.col("cust_ref"), how="inner" + ) + .select("name", "amount") + .sort("amount", descending=True) + .limit(10) + .to_plan() + ) + pretty_print_plan(plan, use_colors=True) + + +if __name__ == "__main__": + print("=== filter / with_columns / select ===") + filter_select_example() + print("\n=== group_by / agg ===") + aggregate_example() + print("\n=== join / sort / limit ===") + join_example() diff --git a/examples/dataframe_example.py b/examples/dataframe_example.py deleted file mode 
100644 index ff3d0bb..0000000 --- a/examples/dataframe_example.py +++ /dev/null @@ -1,17 +0,0 @@ -import substrait.dataframe as sdf -from substrait.builders.plan import read_named_table -from substrait.builders.type import boolean, i64, named_struct, struct -from substrait.extension_registry import ExtensionRegistry - -registry = ExtensionRegistry(load_default_extensions=True) - -ns = named_struct( - names=["id", "is_applicable"], - struct=struct(types=[i64(nullable=False), boolean()], nullable=False), -) - -table = read_named_table("example_table", ns) - -frame = sdf.DataFrame(read_named_table("example_table", ns)) -frame = frame.select(sdf.col("id")) -print(frame.to_substrait(registry)) diff --git a/examples/narwhals_example.py b/examples/narwhals_example.py index 0819404..cae4e24 100644 --- a/examples/narwhals_example.py +++ b/examples/narwhals_example.py @@ -1,4 +1,10 @@ -# Install duckdb and pyarrow before running this example +# Example of the `substrait.narwhals` integration layer: drive Substrait plan +# construction through Narwhals (`nw.from_native`), so backend-agnostic Narwhals +# code compiles to a Substrait plan. +# +# For building plans directly (without Narwhals), see `api_example.py`, which +# uses the Substrait-native DataFrame in `substrait.api` / `substrait.frame`. 
+# # /// script # dependencies = [ # "narwhals==2.9.0", @@ -9,7 +15,7 @@ import narwhals as nw from narwhals.typing import FrameT -import substrait.dataframe as sdf +import substrait.narwhals as sn from substrait.builders.plan import read_named_table from substrait.builders.type import boolean, i64, named_struct, struct from substrait.extension_registry import ExtensionRegistry @@ -21,15 +27,11 @@ struct=struct(types=[i64(nullable=False), boolean()], nullable=False), ) -table = read_named_table("example_table", ns) - - -lazy_frame: FrameT = nw.from_native( - sdf.DataFrame(read_named_table("example_table", ns)) -) +# Wrap the Substrait Narwhals backend and drive it with the Narwhals API. +lazy_frame: FrameT = nw.from_native(sn.DataFrame(read_named_table("example_table", ns))) lazy_frame = lazy_frame.select(nw.col("id").abs(), new_id=nw.col("id")) -df: sdf.DataFrame = lazy_frame.to_native() +df: sn.DataFrame = lazy_frame.to_native() print(df.to_substrait(registry))