From b57325fc357957b6f97c0df22b4a35fee7e2f7c1 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 17 May 2026 12:36:46 -0400 Subject: [PATCH 1/4] feat: pass calling SessionContext to Python UDTF callbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DataFusion 53 added `TableFunctionImpl::call_with_args(TableFunctionArgs)` where `TableFunctionArgs` carries both the positional expression arguments and the calling `&dyn Session`. The pure-Python UDTF path previously discarded everything but the exprs. Thread the session through when the user callback's signature opts in by declaring a `session` keyword parameter (or `**kwargs`). At call time we downcast the `&dyn Session` to its canonical `SessionState` impl and build a fresh `SessionContext` over the same Arc-shared state, exposed to Python as a `datafusion.SessionContext` wrapper. Existing callbacks whose signatures do not declare `session` continue to be called with the positional expression arguments only — no behavior change for current users. Note: a UDTF body cannot drive a fresh `ctx.sql(...).collect()` on the passed-in session because the outer SQL execution already holds the tokio runtime. Use the session for metadata access (catalogs, UDF lookups, config) rather than nested DataFrame collection. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/udtf.rs | 86 +++++++++++++++++++++++++------ python/datafusion/user_defined.py | 52 ++++++++++++++++++- python/tests/test_udtf.py | 65 +++++++++++++++++++++++ 3 files changed, 186 insertions(+), 17 deletions(-) diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index b3de25e52..3a244a417 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -18,20 +18,33 @@ use std::ptr::NonNull; use std::sync::Arc; -use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; -use datafusion::error::Result as DataFusionResult; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider}; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionState; use datafusion::logical_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyImportError, PyTypeError}; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyTuple, PyType}; +use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType}; use crate::context::PySessionContext; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; use crate::table::PyTable; +/// A pure-Python UDTF callable plus the metadata we discovered about it +/// at registration time. +#[derive(Debug, Clone)] +pub(crate) struct PythonTableFunctionCallable { + pub(crate) callable: Arc>, + /// Whether the callable's signature accepts a ``session`` keyword + /// argument (or ``**kwargs``). When true the calling + /// :class:`SessionContext` is threaded through on each invocation. + pub(crate) accepts_session: bool, +} + /// Represents a user defined table function #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")] #[derive(Debug, Clone)] @@ -40,21 +53,21 @@ pub struct PyTableFunction { pub(crate) inner: PyTableFunctionInner, } -// TODO: Implement pure python based user defined table functions #[derive(Debug, Clone)] pub(crate) enum PyTableFunctionInner { - PythonFunction(Arc>), + PythonFunction(PythonTableFunctionCallable), FFIFunction(Arc), } #[pymethods] impl PyTableFunction { #[new] - #[pyo3(signature=(name, func, session))] + #[pyo3(signature=(name, func, session, accepts_session=false))] pub fn new( name: &str, func: Bound<'_, PyAny>, session: Option>, + accepts_session: bool, ) -> PyResult { let inner = if func.hasattr("__datafusion_table_function__")? { let py = func.py(); @@ -80,8 +93,10 @@ impl PyTableFunction { PyTableFunctionInner::FFIFunction(foreign_func) } else { - let py_obj = Arc::new(func.unbind()); - PyTableFunctionInner::PythonFunction(py_obj) + PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable { + callable: Arc::new(func.unbind()), + accepts_session, + }) }; Ok(Self { @@ -107,20 +122,59 @@ impl PyTableFunction { } } +/// Materialize a fresh :class:`PySessionContext` from the borrowed +/// ``&dyn Session`` handed in at call time. +/// +/// Upstream invokes ``call_with_args`` with a trait-object reference +/// rather than an owned context; we downcast it to the canonical +/// :class:`SessionState` impl and rebuild a :class:`SessionContext` +/// (sharing the same registries via the Arc-heavy interior of +/// :class:`SessionState`). Returns an error if the trait object is a +/// non-:class:`SessionState` implementation (e.g. a foreign FFI +/// session) — those are not exposed to Python today. +fn py_session_from_session(session: &dyn Session) -> DataFusionResult { + let state = session + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution( + "Cannot expose this UDTF's calling session to Python: \ + the session is not a SessionState. Drop the `session` \ + keyword from the callback signature to fall back to the \ + expression-only call form." + .to_string(), + ) + })?; + Ok(PySessionContext::from(SessionContext::new_with_state( + state.clone(), + ))) +} + #[allow(clippy::result_large_err)] fn call_python_table_function( - func: &Arc>, - args: &[Expr], + func: &PythonTableFunctionCallable, + args: TableFunctionArgs, ) -> DataFusionResult> { - let args = args + let py_session = if func.accepts_session { + Some(py_session_from_session(args.session())?) + } else { + None + }; + let py_exprs = args + .exprs() .iter() .map(|arg| PyExpr::from(arg.clone())) .collect::>(); - // move |args: &[ArrayRef]| -> Result { Python::attach(|py| { - let py_args = PyTuple::new(py, args)?; - let provider_obj = func.call1(py, py_args)?; + let py_args = PyTuple::new(py, py_exprs)?; + let provider_obj = if let Some(session) = py_session { + let kwargs = PyDict::new(py); + kwargs.set_item("session", session.into_pyobject(py)?)?; + func.callable.call(py, py_args, Some(&kwargs))? + } else { + func.callable.call1(py, py_args)? + }; let provider = provider_obj.bind(py).clone(); Ok::, PyErr>(PyTable::new(provider, None)?.table) @@ -132,8 +186,8 @@ impl TableFunctionImpl for PyTableFunction { fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult> { match &self.inner { PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args), - PyTableFunctionInner::PythonFunction(obj) => { - call_python_table_function(obj, args.exprs()) + PyTableFunctionInner::PythonFunction(callable) => { + call_python_table_function(callable, args) } } } diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eb50a094..c524ac4e1 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -1054,6 +1054,47 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: ) +def _callable_accepts_session_kwarg(func: object) -> bool: + """Return True if ``func`` accepts a ``session`` keyword argument. + + Used to opt a Python UDTF callback into receiving the calling + :class:`SessionContext` at invocation time. ``**kwargs`` callables + are treated as accepting it; built-ins and objects without an + introspectable signature fall back to ``False``. + """ + import inspect # noqa: PLC0415 + + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + return False + + for parameter in signature.parameters.values(): + if parameter.name == "session": + return True + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + return True + return False + + +def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]: + """Adapt the raw internal session pyo3 object back to a Python wrapper. + + The Rust call site forwards a ``datafusion._internal.SessionContext``, + but UDTF authors expect to interact with the public + :class:`datafusion.SessionContext` wrapper. This closure wraps the + internal object once per call before delegating to ``func``. + """ + + @functools.wraps(func, updated=()) + def adapter(*args: Any, session: Any, **kwargs: Any) -> Any: + wrapped = SessionContext.__new__(SessionContext) + wrapped.ctx = session + return func(*args, session=wrapped, **kwargs) + + return adapter + + class TableFunction: """Class for performing user-defined table functions (UDTF). @@ -1066,10 +1107,19 @@ def __init__( ) -> None: """Instantiate a user-defined table function (UDTF). + If ``func``'s signature accepts a ``session`` keyword (or + ``**kwargs``), the calling :class:`SessionContext` is threaded + through to it on each invocation. Use it inside the body to look + up registered tables, UDFs, or session configuration. Callables + whose signatures do not declare ``session`` are invoked with the + positional expression arguments only. + See :py:func:`udtf` for a convenience function and argument descriptions. """ - self._udtf = df_internal.TableFunction(name, func, ctx) + accepts_session = _callable_accepts_session_kwarg(func) + registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func + self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session) def __call__(self, *args: Expr) -> Any: """Execute the UDTF and return a table provider.""" diff --git a/python/tests/test_udtf.py b/python/tests/test_udtf.py index 925a8ba01..7a1b128bf 100644 --- a/python/tests/test_udtf.py +++ b/python/tests/test_udtf.py @@ -134,3 +134,68 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable: result = ctx.sql("SELECT * FROM string_arg_func('test')").collect() assert len(result) == 1 assert result[0].schema.names == ["test_a", "test_b"] + + +def test_python_table_function_receives_session() -> None: + """A UDTF whose signature declares ``session`` gets the calling ctx.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + @udtf("session_aware_func") + def session_aware_func(*, session: SessionContext) -> TableProviderExportable: + captured.append(session) + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(session_aware_func) + result = ctx.sql("SELECT * FROM session_aware_func()").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + # Sharing the same catalog confirms the wrapper points at the caller's state. + assert captured[0].catalog().schema().names() == ctx.catalog().schema().names() + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_python_table_function_session_used_for_metadata() -> None: + """The UDTF can inspect session state through the passed-in context.""" + ctx = SessionContext() + base_batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_batch("base_tbl", base_batch) + + seen_tables: list[set[str]] = [] + + @udtf("table_inventory") + def table_inventory(*, session: SessionContext) -> TableProviderExportable: + # Stash the visible tables to verify the session wired through. + seen_tables.append(session.catalog().schema().names()) + batch = pa.RecordBatch.from_pydict({"name": ["base_tbl"]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(table_inventory) + result = ctx.sql("SELECT * FROM table_inventory()").collect() + + assert seen_tables == [{"base_tbl"}] + assert result[0].column(0).to_pylist() == ["base_tbl"] + + +def test_python_table_function_class_callable_session_kwarg() -> None: + """Class-based UDTFs whose __call__ accepts ``session`` get it too.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + class SessionAware: + def __call__( + self, n: Expr, *, session: SessionContext + ) -> TableProviderExportable: + captured.append(session) + count = n.to_variant().value_i64() + batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(udtf(SessionAware(), "session_class_func")) + result = ctx.sql("SELECT * FROM session_class_func(3)").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + assert result[0].column(0).to_pylist() == [0, 1, 2] From 67715c61ec2117769cad16b4749351d2cbb7022a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 27 May 2026 19:14:41 -0400 Subject: [PATCH 2/4] docs: clarify py_session_from_session downcast is defensive The doc comment implied a foreign FFI session was a real input. No current path reaches a pure-Python UDTF with a non-SessionState session: the SQL planner and __call__ both hand a SessionState, and a ForeignSession would only arrive via FFI-export of the UDTF, which datafusion-python does not do. Reword to state the guard is defensive and rewrap the error string. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/udtf.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index 3a244a417..ce88c18c5 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -129,19 +129,26 @@ impl PyTableFunction { /// rather than an owned context; we downcast it to the canonical /// :class:`SessionState` impl and rebuild a :class:`SessionContext` /// (sharing the same registries via the Arc-heavy interior of -/// :class:`SessionState`). Returns an error if the trait object is a -/// non-:class:`SessionState` implementation (e.g. a foreign FFI -/// session) — those are not exposed to Python today. +/// :class:`SessionState`). +/// +/// The downcast is defensive. Every path that reaches a pure-Python +/// UDTF today hands us a `SessionState`: the SQL planner builds the +/// args from its own `SessionState`, and `PyTableFunction::__call__` +/// uses the global context's state. A non-`SessionState` session +/// (e.g. a `ForeignSession`) would only arrive if this UDTF were +/// exported across the FFI boundary to a foreign-library consumer, +/// which datafusion-python does not do. Should that change, this +/// returns an error rather than silently misbehaving. fn py_session_from_session(session: &dyn Session) -> DataFusionResult { let state = session .as_any() .downcast_ref::() .ok_or_else(|| { DataFusionError::Execution( - "Cannot expose this UDTF's calling session to Python: \ - the session is not a SessionState. Drop the `session` \ - keyword from the callback signature to fall back to the \ - expression-only call form." + "Cannot expose this UDTF's calling session to Python: the \ + session is not a SessionState. Drop the `session` keyword \ + from the callback signature to fall back to the \ + expression-only call form." .to_string(), ) })?; From 12cf674c83bcbe71273592e3863b96451d6d3a69 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 28 May 2026 17:27:36 -0400 Subject: [PATCH 3/4] refactor: opt-in UDTF session injection via with_session flag Replaces signature sniffing with an explicit ``with_session=True`` kwarg on ``TableFunction`` / ``udtf``. Avoids name-based detection footguns (positional-only ``session`` params, accidental ``**kwargs`` opt-in, shadowing by unrelated params) and makes author intent visible at registration. Also documents the feature in the UDTF user guide. Rust field renamed ``accepts_session`` -> ``inject_session_on_call`` to match the Python-side opt-in semantics. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/udtf.rs | 17 ++-- .../common-operations/udf-and-udfa.rst | 36 ++++++++ python/datafusion/user_defined.py | 88 +++++++++---------- python/tests/test_udtf.py | 28 ++++-- 4 files changed, 111 insertions(+), 58 deletions(-) diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index ce88c18c5..cffa0c12a 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -39,10 +39,11 @@ use crate::table::PyTable; #[derive(Debug, Clone)] pub(crate) struct PythonTableFunctionCallable { pub(crate) callable: Arc>, - /// Whether the callable's signature accepts a ``session`` keyword - /// argument (or ``**kwargs``). When true the calling - /// :class:`SessionContext` is threaded through on each invocation. - pub(crate) accepts_session: bool, + /// When true, the calling :class:`SessionContext` is passed to the + /// callable as a ``session`` keyword argument on every invocation. + /// Opt-in at registration time via ``with_session=True`` on the + /// Python wrapper. + pub(crate) inject_session_on_call: bool, } /// Represents a user defined table function @@ -62,12 +63,12 @@ pub(crate) enum PyTableFunctionInner { #[pymethods] impl PyTableFunction { #[new] - #[pyo3(signature=(name, func, session, accepts_session=false))] + #[pyo3(signature=(name, func, session, inject_session_on_call=false))] pub fn new( name: &str, func: Bound<'_, PyAny>, session: Option>, - accepts_session: bool, + inject_session_on_call: bool, ) -> PyResult { let inner = if func.hasattr("__datafusion_table_function__")? { let py = func.py(); @@ -95,7 +96,7 @@ impl PyTableFunction { } else { PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable { callable: Arc::new(func.unbind()), - accepts_session, + inject_session_on_call, }) }; @@ -162,7 +163,7 @@ fn call_python_table_function( func: &PythonTableFunctionCallable, args: TableFunctionArgs, ) -> DataFusionResult> { - let py_session = if func.accepts_session { + let py_session = if func.inject_session_on_call { Some(py_session_from_session(args.session())?) } else { None diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index 59c47b595..2135bb2dc 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -431,3 +431,39 @@ that you wish to expose via PyO3, you need to expose it as a ``PyCapsule``. PyCapsule::new(py, provider, Some(name)) } } + +Accessing the Calling Session +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Pure-Python UDTFs can opt into receiving the calling +:py:class:`~datafusion.SessionContext` by registering with +``with_session=True``. The context is passed as a ``session`` keyword +argument on every invocation. Use it to look up registered tables, +UDFs, or session configuration from inside the callback. + +.. code-block:: python + + from datafusion import SessionContext, Table, udtf + from datafusion.context import TableProviderExportable + import pyarrow as pa + import pyarrow.dataset as ds + + @udtf("list_tables", with_session=True) + def list_tables(*, session: SessionContext) -> TableProviderExportable: + names = sorted(session.catalog().schema().names()) + batch = pa.RecordBatch.from_pydict({"name": names}) + return Table(ds.dataset([batch])) + + ctx = SessionContext() + ctx.register_batch("t1", pa.RecordBatch.from_pydict({"x": [1]})) + ctx.register_udtf(list_tables) + ctx.sql("SELECT * FROM list_tables()").show() + +Without ``with_session=True``, the callback receives only the positional +expression arguments. The flag is opt-in so existing UDTFs keep working +unchanged. + +The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext` +wrapper backed by the same underlying state as the caller, so registries +(tables, UDFs, catalogs) are visible. Mutations made through it affect +the live session. diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index c524ac4e1..c42bffd0a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -1054,29 +1054,6 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: ) -def _callable_accepts_session_kwarg(func: object) -> bool: - """Return True if ``func`` accepts a ``session`` keyword argument. - - Used to opt a Python UDTF callback into receiving the calling - :class:`SessionContext` at invocation time. ``**kwargs`` callables - are treated as accepting it; built-ins and objects without an - introspectable signature fall back to ``False``. - """ - import inspect # noqa: PLC0415 - - try: - signature = inspect.signature(func) - except (TypeError, ValueError): - return False - - for parameter in signature.parameters.values(): - if parameter.name == "session": - return True - if parameter.kind is inspect.Parameter.VAR_KEYWORD: - return True - return False - - def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]: """Adapt the raw internal session pyo3 object back to a Python wrapper. @@ -1103,23 +1080,27 @@ class TableFunction: """ def __init__( - self, name: str, func: Callable[[], any], ctx: SessionContext | None = None + self, + name: str, + func: Callable[..., Any], + ctx: SessionContext | None = None, + *, + with_session: bool = False, ) -> None: """Instantiate a user-defined table function (UDTF). - If ``func``'s signature accepts a ``session`` keyword (or - ``**kwargs``), the calling :class:`SessionContext` is threaded - through to it on each invocation. Use it inside the body to look - up registered tables, UDFs, or session configuration. Callables - whose signatures do not declare ``session`` are invoked with the - positional expression arguments only. + Set ``with_session=True`` to have the calling + :class:`SessionContext` passed as a ``session`` keyword argument + on each invocation. Use it inside the callback to look up + registered tables, UDFs, or session configuration. When + ``with_session`` is ``False`` (the default), ``func`` is invoked + with the positional expression arguments only. See :py:func:`udtf` for a convenience function and argument descriptions. """ - accepts_session = _callable_accepts_session_kwarg(func) - registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func - self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session) + registered = _wrap_session_kwarg_for_udtf(func) if with_session else func + self._udtf = df_internal.TableFunction(name, registered, ctx, with_session) def __call__(self, *args: Expr) -> Any: """Execute the UDTF and return a table provider.""" @@ -1130,47 +1111,66 @@ def __call__(self, *args: Expr) -> Any: @staticmethod def udtf( name: str, + *, + with_session: bool = False, ) -> Callable[..., Any]: ... @overload @staticmethod def udtf( - func: Callable[[], Any], + func: Callable[..., Any], name: str, + *, + with_session: bool = False, ) -> TableFunction: ... @staticmethod - def udtf(*args: Any, **kwargs: Any): - """Create a new User-Defined Table Function (UDTF).""" + def udtf(*args: Any, with_session: bool = False, **kwargs: Any): + """Create a new User-Defined Table Function (UDTF). + + Pass ``with_session=True`` to have the calling + :class:`SessionContext` injected as a ``session`` keyword + argument on each invocation. + """ if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable - return TableFunction._create_table_udf(*args, **kwargs) + return TableFunction._create_table_udf( + *args, with_session=with_session, **kwargs + ) if args and hasattr(args[0], "__datafusion_table_function__"): # Case 2: We have a datafusion FFI provided function return TableFunction(args[1], args[0]) # Case 3: Used as a decorator with parameters - return TableFunction._create_table_udf_decorator(*args, **kwargs) + return TableFunction._create_table_udf_decorator( + *args, with_session=with_session, **kwargs + ) @staticmethod def _create_table_udf( func: Callable[..., Any], name: str, + *, + with_session: bool = False, ) -> TableFunction: """Create a TableFunction instance from function arguments.""" if not callable(func): msg = "`func` must be callable." raise TypeError(msg) - return TableFunction(name, func) + return TableFunction(name, func, with_session=with_session) @staticmethod def _create_table_udf_decorator( name: str | None = None, - ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: - """Create a decorator for a WindowUDF.""" - - def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: - return TableFunction._create_table_udf(func, name) + *, + with_session: bool = False, + ) -> Callable[[Callable[..., Any]], TableFunction]: + """Create a decorator for a TableFunction.""" + + def decorator(func: Callable[..., Any]) -> TableFunction: + return TableFunction._create_table_udf( + func, name, with_session=with_session + ) return decorator diff --git a/python/tests/test_udtf.py b/python/tests/test_udtf.py index 7a1b128bf..1d35bfa14 100644 --- a/python/tests/test_udtf.py +++ b/python/tests/test_udtf.py @@ -137,11 +137,11 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable: def test_python_table_function_receives_session() -> None: - """A UDTF whose signature declares ``session`` gets the calling ctx.""" + """A UDTF registered ``with_session=True`` gets the calling ctx.""" ctx = SessionContext() captured: list[SessionContext] = [] - @udtf("session_aware_func") + @udtf("session_aware_func", with_session=True) def session_aware_func(*, session: SessionContext) -> TableProviderExportable: captured.append(session) batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) @@ -165,7 +165,7 @@ def test_python_table_function_session_used_for_metadata() -> None: seen_tables: list[set[str]] = [] - @udtf("table_inventory") + @udtf("table_inventory", with_session=True) def table_inventory(*, session: SessionContext) -> TableProviderExportable: # Stash the visible tables to verify the session wired through. seen_tables.append(session.catalog().schema().names()) @@ -179,8 +179,8 @@ def table_inventory(*, session: SessionContext) -> TableProviderExportable: assert result[0].column(0).to_pylist() == ["base_tbl"] -def test_python_table_function_class_callable_session_kwarg() -> None: - """Class-based UDTFs whose __call__ accepts ``session`` get it too.""" +def test_python_table_function_class_callable_with_session() -> None: + """Class-based UDTFs opt in via ``with_session=True``.""" ctx = SessionContext() captured: list[SessionContext] = [] @@ -193,9 +193,25 @@ def __call__( batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) return Table(ds.dataset([batch])) - ctx.register_udtf(udtf(SessionAware(), "session_class_func")) + ctx.register_udtf(udtf(SessionAware(), "session_class_func", with_session=True)) result = ctx.sql("SELECT * FROM session_class_func(3)").collect() assert len(captured) == 1 assert isinstance(captured[0], SessionContext) assert result[0].column(0).to_pylist() == [0, 1, 2] + + +def test_python_table_function_without_session_flag_no_injection() -> None: + """Default registration (no ``with_session``) calls func positionally.""" + ctx = SessionContext() + + @udtf("plain_func") + def plain_func(n: Expr) -> TableProviderExportable: + count = n.to_variant().value_i64() + batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(plain_func) + result = ctx.sql("SELECT * FROM plain_func(4)").collect() + + assert result[0].column(0).to_pylist() == [0, 1, 2, 3] From ed40c3746531fcc1c52fb2c698a43e760b5245e7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 28 May 2026 17:47:58 -0400 Subject: [PATCH 4/4] fix: reject with_session=True for FFI UDTFs and qualify mutation docs Raise TypeError when with_session=True is combined with an FFI-exported table function (one exposing __datafusion_table_function__). The Rust FFI branch does not consult the flag, so it would silently be dropped; guard both TableFunction.__init__ and the udtf() convenience entry. Qualify the doc claim that mutations through the injected session propagate to the caller: registry mutations do (shared Arc registries), but config changes do not (SessionConfig is cloned). Mirror the caveat in TableFunction.__init__ per the user-guide caveats convention. Co-Authored-By: Claude Opus 4.7 --- .../common-operations/udf-and-udfa.rst | 7 ++++-- python/datafusion/user_defined.py | 24 +++++++++++++++++++ python/tests/test_udtf.py | 18 ++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index 2135bb2dc..918c2e29e 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -465,5 +465,8 @@ unchanged. The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext` wrapper backed by the same underlying state as the caller, so registries -(tables, UDFs, catalogs) are visible. Mutations made through it affect -the live session. +(tables, UDFs, catalogs) are visible. Registry mutations (e.g. registering +a new table or UDF) propagate to the live session because the registries +are reference-counted and shared. Configuration changes made through the +wrapper (e.g. setting session options) do **not** propagate — the wrapper +holds its own clone of the session config. diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index c42bffd0a..70c199e18 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -1096,9 +1096,26 @@ def __init__( ``with_session`` is ``False`` (the default), ``func`` is invoked with the positional expression arguments only. + ``with_session=True`` is only supported for pure-Python callables. + Passing it together with an FFI-exported table function (one + exposing ``__datafusion_table_function__``) raises + :class:`TypeError`. + + Registry mutations performed through the injected session (such + as registering tables or UDFs) propagate to the caller's + :class:`SessionContext` because the registries are shared. + Configuration changes do **not** propagate; the wrapper holds + its own clone of the session config. + See :py:func:`udtf` for a convenience function and argument descriptions. """ + if with_session and hasattr(func, "__datafusion_table_function__"): + msg = ( + "`with_session=True` is not supported for FFI-exported table " + "functions; session injection requires a pure-Python callable." + ) + raise TypeError(msg) registered = _wrap_session_kwarg_for_udtf(func) if with_session else func self._udtf = df_internal.TableFunction(name, registered, ctx, with_session) @@ -1139,6 +1156,13 @@ def udtf(*args: Any, with_session: bool = False, **kwargs: Any): ) if args and hasattr(args[0], "__datafusion_table_function__"): # Case 2: We have a datafusion FFI provided function + if with_session: + msg = ( + "`with_session=True` is not supported for FFI-exported " + "table functions; session injection requires a " + "pure-Python callable." + ) + raise TypeError(msg) return TableFunction(args[1], args[0]) # Case 3: Used as a decorator with parameters return TableFunction._create_table_udf_decorator( diff --git a/python/tests/test_udtf.py b/python/tests/test_udtf.py index 1d35bfa14..dcb2bacc3 100644 --- a/python/tests/test_udtf.py +++ b/python/tests/test_udtf.py @@ -17,8 +17,10 @@ import pyarrow as pa import pyarrow.dataset as ds +import pytest from datafusion import Expr, SessionContext, Table, udtf from datafusion.context import TableProviderExportable +from datafusion.user_defined import TableFunction def python_table_function_inner( @@ -215,3 +217,19 @@ def plain_func(n: Expr) -> TableProviderExportable: result = ctx.sql("SELECT * FROM plain_func(4)").collect() assert result[0].column(0).to_pylist() == [0, 1, 2, 3] + + +def test_with_session_rejected_for_ffi_table_function() -> None: + """`with_session=True` is incompatible with FFI-exported table functions.""" + + class FakeFFITableFunction: + # Presence of this attribute is what marks a function as FFI-exported. + __datafusion_table_function__ = "stub" + + fake = FakeFFITableFunction() + + with pytest.raises(TypeError, match="FFI-exported table functions"): + udtf(fake, "fake_ffi", with_session=True) + + with pytest.raises(TypeError, match="FFI-exported table functions"): + TableFunction("fake_ffi", fake, with_session=True)