diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index b3de25e52..cffa0c12a 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -18,20 +18,34 @@ use std::ptr::NonNull; use std::sync::Arc; -use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; -use datafusion::error::Result as DataFusionResult; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider}; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionState; use datafusion::logical_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyImportError, PyTypeError}; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyTuple, PyType}; +use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType}; use crate::context::PySessionContext; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; use crate::table::PyTable; +/// A pure-Python UDTF callable plus the metadata we discovered about it +/// at registration time. +#[derive(Debug, Clone)] +pub(crate) struct PythonTableFunctionCallable { + pub(crate) callable: Arc>, + /// When true, the calling :class:`SessionContext` is passed to the + /// callable as a ``session`` keyword argument on every invocation. + /// Opt-in at registration time via ``with_session=True`` on the + /// Python wrapper. 
+ pub(crate) inject_session_on_call: bool, +} + /// Represents a user defined table function #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")] #[derive(Debug, Clone)] @@ -40,21 +54,21 @@ pub struct PyTableFunction { pub(crate) inner: PyTableFunctionInner, } -// TODO: Implement pure python based user defined table functions #[derive(Debug, Clone)] pub(crate) enum PyTableFunctionInner { - PythonFunction(Arc>), + PythonFunction(PythonTableFunctionCallable), FFIFunction(Arc), } #[pymethods] impl PyTableFunction { #[new] - #[pyo3(signature=(name, func, session))] + #[pyo3(signature=(name, func, session, inject_session_on_call=false))] pub fn new( name: &str, func: Bound<'_, PyAny>, session: Option>, + inject_session_on_call: bool, ) -> PyResult { let inner = if func.hasattr("__datafusion_table_function__")? { let py = func.py(); @@ -80,8 +94,10 @@ impl PyTableFunction { PyTableFunctionInner::FFIFunction(foreign_func) } else { - let py_obj = Arc::new(func.unbind()); - PyTableFunctionInner::PythonFunction(py_obj) + PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable { + callable: Arc::new(func.unbind()), + inject_session_on_call, + }) }; Ok(Self { @@ -107,20 +123,66 @@ impl PyTableFunction { } } +/// Materialize a fresh :class:`PySessionContext` from the borrowed +/// ``&dyn Session`` handed in at call time. +/// +/// Upstream invokes ``call_with_args`` with a trait-object reference +/// rather than an owned context; we downcast it to the canonical +/// :class:`SessionState` impl and rebuild a :class:`SessionContext` +/// (sharing the same registries via the Arc-heavy interior of +/// :class:`SessionState`). +/// +/// The downcast is defensive. Every path that reaches a pure-Python +/// UDTF today hands us a `SessionState`: the SQL planner builds the +/// args from its own `SessionState`, and `PyTableFunction::__call__` +/// uses the global context's state. A non-`SessionState` session +/// (e.g. 
a `ForeignSession`) would only arrive if this UDTF were +/// exported across the FFI boundary to a foreign-library consumer, +/// which datafusion-python does not do. Should that change, this +/// returns an error rather than silently misbehaving. +fn py_session_from_session(session: &dyn Session) -> DataFusionResult { + let state = session + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution( + "Cannot expose this UDTF's calling session to Python: the \ + session is not a SessionState. Drop the `session` keyword \ + from the callback signature to fall back to the \ + expression-only call form." + .to_string(), + ) + })?; + Ok(PySessionContext::from(SessionContext::new_with_state( + state.clone(), + ))) +} + #[allow(clippy::result_large_err)] fn call_python_table_function( - func: &Arc>, - args: &[Expr], + func: &PythonTableFunctionCallable, + args: TableFunctionArgs, ) -> DataFusionResult> { - let args = args + let py_session = if func.inject_session_on_call { + Some(py_session_from_session(args.session())?) + } else { + None + }; + let py_exprs = args + .exprs() .iter() .map(|arg| PyExpr::from(arg.clone())) .collect::>(); - // move |args: &[ArrayRef]| -> Result { Python::attach(|py| { - let py_args = PyTuple::new(py, args)?; - let provider_obj = func.call1(py, py_args)?; + let py_args = PyTuple::new(py, py_exprs)?; + let provider_obj = if let Some(session) = py_session { + let kwargs = PyDict::new(py); + kwargs.set_item("session", session.into_pyobject(py)?)?; + func.callable.call(py, py_args, Some(&kwargs))? + } else { + func.callable.call1(py, py_args)? 
+ }; let provider = provider_obj.bind(py).clone(); Ok::, PyErr>(PyTable::new(provider, None)?.table) @@ -132,8 +194,8 @@ impl TableFunctionImpl for PyTableFunction { fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult> { match &self.inner { PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args), - PyTableFunctionInner::PythonFunction(obj) => { - call_python_table_function(obj, args.exprs()) + PyTableFunctionInner::PythonFunction(callable) => { + call_python_table_function(callable, args) } } } diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index 59c47b595..918c2e29e 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -431,3 +431,42 @@ that you wish to expose via PyO3, you need to expose it as a ``PyCapsule``. PyCapsule::new(py, provider, Some(name)) } } + +Accessing the Calling Session +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Pure-Python UDTFs can opt into receiving the calling +:py:class:`~datafusion.SessionContext` by registering with +``with_session=True``. The context is passed as a ``session`` keyword +argument on every invocation. Use it to look up registered tables, +UDFs, or session configuration from inside the callback. + +.. 
code-block:: python + + from datafusion import SessionContext, Table, udtf + from datafusion.context import TableProviderExportable + import pyarrow as pa + import pyarrow.dataset as ds + + @udtf("list_tables", with_session=True) + def list_tables(*, session: SessionContext) -> TableProviderExportable: + names = sorted(session.catalog().schema().names()) + batch = pa.RecordBatch.from_pydict({"name": names}) + return Table(ds.dataset([batch])) + + ctx = SessionContext() + ctx.register_batch("t1", pa.RecordBatch.from_pydict({"x": [1]})) + ctx.register_udtf(list_tables) + ctx.sql("SELECT * FROM list_tables()").show() + +Without ``with_session=True``, the callback receives only the positional +expression arguments. The flag is opt-in so existing UDTFs keep working +unchanged. + +The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext` +wrapper backed by the same underlying state as the caller, so registries +(tables, UDFs, catalogs) are visible. Registry mutations (e.g. registering +a new table or UDF) propagate to the live session because the registries +are reference-counted and shared. Configuration changes made through the +wrapper (e.g. setting session options) do **not** propagate — the wrapper +holds its own clone of the session config. diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eb50a094..70c199e18 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -1054,6 +1054,24 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: ) +def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]: + """Adapt the raw internal session pyo3 object back to a Python wrapper. + + The Rust call site forwards a ``datafusion._internal.SessionContext``, + but UDTF authors expect to interact with the public + :class:`datafusion.SessionContext` wrapper. This closure wraps the + internal object once per call before delegating to ``func``. 
+ """ + + @functools.wraps(func, updated=()) + def adapter(*args: Any, session: Any, **kwargs: Any) -> Any: + wrapped = SessionContext.__new__(SessionContext) + wrapped.ctx = session + return func(*args, session=wrapped, **kwargs) + + return adapter + + class TableFunction: """Class for performing user-defined table functions (UDTF). @@ -1062,14 +1080,44 @@ class TableFunction: """ def __init__( - self, name: str, func: Callable[[], any], ctx: SessionContext | None = None + self, + name: str, + func: Callable[..., Any], + ctx: SessionContext | None = None, + *, + with_session: bool = False, ) -> None: """Instantiate a user-defined table function (UDTF). + Set ``with_session=True`` to have the calling + :class:`SessionContext` passed as a ``session`` keyword argument + on each invocation. Use it inside the callback to look up + registered tables, UDFs, or session configuration. When + ``with_session`` is ``False`` (the default), ``func`` is invoked + with the positional expression arguments only. + + ``with_session=True`` is only supported for pure-Python callables. + Passing it together with an FFI-exported table function (one + exposing ``__datafusion_table_function__``) raises + :class:`TypeError`. + + Registry mutations performed through the injected session (such + as registering tables or UDFs) propagate to the caller's + :class:`SessionContext` because the registries are shared. + Configuration changes do **not** propagate; the wrapper holds + its own clone of the session config. + See :py:func:`udtf` for a convenience function and argument descriptions. """ - self._udtf = df_internal.TableFunction(name, func, ctx) + if with_session and hasattr(func, "__datafusion_table_function__"): + msg = ( + "`with_session=True` is not supported for FFI-exported table " + "functions; session injection requires a pure-Python callable." 
+ ) + raise TypeError(msg) + registered = _wrap_session_kwarg_for_udtf(func) if with_session else func + self._udtf = df_internal.TableFunction(name, registered, ctx, with_session) def __call__(self, *args: Expr) -> Any: """Execute the UDTF and return a table provider.""" @@ -1080,47 +1128,73 @@ def __call__(self, *args: Expr) -> Any: @staticmethod def udtf( name: str, + *, + with_session: bool = False, ) -> Callable[..., Any]: ... @overload @staticmethod def udtf( - func: Callable[[], Any], + func: Callable[..., Any], name: str, + *, + with_session: bool = False, ) -> TableFunction: ... @staticmethod - def udtf(*args: Any, **kwargs: Any): - """Create a new User-Defined Table Function (UDTF).""" + def udtf(*args: Any, with_session: bool = False, **kwargs: Any): + """Create a new User-Defined Table Function (UDTF). + + Pass ``with_session=True`` to have the calling + :class:`SessionContext` injected as a ``session`` keyword + argument on each invocation. + """ if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable - return TableFunction._create_table_udf(*args, **kwargs) + return TableFunction._create_table_udf( + *args, with_session=with_session, **kwargs + ) if args and hasattr(args[0], "__datafusion_table_function__"): # Case 2: We have a datafusion FFI provided function + if with_session: + msg = ( + "`with_session=True` is not supported for FFI-exported " + "table functions; session injection requires a " + "pure-Python callable." 
+ ) + raise TypeError(msg) return TableFunction(args[1], args[0]) # Case 3: Used as a decorator with parameters - return TableFunction._create_table_udf_decorator(*args, **kwargs) + return TableFunction._create_table_udf_decorator( + *args, with_session=with_session, **kwargs + ) @staticmethod def _create_table_udf( func: Callable[..., Any], name: str, + *, + with_session: bool = False, ) -> TableFunction: """Create a TableFunction instance from function arguments.""" if not callable(func): msg = "`func` must be callable." raise TypeError(msg) - return TableFunction(name, func) + return TableFunction(name, func, with_session=with_session) @staticmethod def _create_table_udf_decorator( name: str | None = None, - ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: - """Create a decorator for a WindowUDF.""" - - def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: - return TableFunction._create_table_udf(func, name) + *, + with_session: bool = False, + ) -> Callable[[Callable[..., Any]], TableFunction]: + """Create a decorator for a TableFunction.""" + + def decorator(func: Callable[..., Any]) -> TableFunction: + return TableFunction._create_table_udf( + func, name, with_session=with_session + ) return decorator diff --git a/python/tests/test_udtf.py b/python/tests/test_udtf.py index 925a8ba01..dcb2bacc3 100644 --- a/python/tests/test_udtf.py +++ b/python/tests/test_udtf.py @@ -17,8 +17,10 @@ import pyarrow as pa import pyarrow.dataset as ds +import pytest from datafusion import Expr, SessionContext, Table, udtf from datafusion.context import TableProviderExportable +from datafusion.user_defined import TableFunction def python_table_function_inner( @@ -134,3 +136,100 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable: result = ctx.sql("SELECT * FROM string_arg_func('test')").collect() assert len(result) == 1 assert result[0].schema.names == ["test_a", "test_b"] + + +def test_python_table_function_receives_session() 
-> None: + """A UDTF registered ``with_session=True`` gets the calling ctx.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + @udtf("session_aware_func", with_session=True) + def session_aware_func(*, session: SessionContext) -> TableProviderExportable: + captured.append(session) + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(session_aware_func) + result = ctx.sql("SELECT * FROM session_aware_func()").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + # Weak check: both sets are empty here; the metadata test below registers a table. + assert captured[0].catalog().schema().names() == ctx.catalog().schema().names() + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_python_table_function_session_used_for_metadata() -> None: + """The UDTF can inspect session state through the passed-in context.""" + ctx = SessionContext() + base_batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_batch("base_tbl", base_batch) + + seen_tables: list[set[str]] = [] + + @udtf("table_inventory", with_session=True) + def table_inventory(*, session: SessionContext) -> TableProviderExportable: + # Stash the visible tables to verify the session wired through. 
+ seen_tables.append(session.catalog().schema().names()) + batch = pa.RecordBatch.from_pydict({"name": ["base_tbl"]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(table_inventory) + result = ctx.sql("SELECT * FROM table_inventory()").collect() + + assert seen_tables == [{"base_tbl"}] + assert result[0].column(0).to_pylist() == ["base_tbl"] + + +def test_python_table_function_class_callable_with_session() -> None: + """Class-based UDTFs opt in via ``with_session=True``.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + class SessionAware: + def __call__( + self, n: Expr, *, session: SessionContext + ) -> TableProviderExportable: + captured.append(session) + count = n.to_variant().value_i64() + batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(udtf(SessionAware(), "session_class_func", with_session=True)) + result = ctx.sql("SELECT * FROM session_class_func(3)").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + assert result[0].column(0).to_pylist() == [0, 1, 2] + + +def test_python_table_function_without_session_flag_no_injection() -> None: + """Default registration (no ``with_session``) calls func positionally.""" + ctx = SessionContext() + + @udtf("plain_func") + def plain_func(n: Expr) -> TableProviderExportable: + count = n.to_variant().value_i64() + batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(plain_func) + result = ctx.sql("SELECT * FROM plain_func(4)").collect() + + assert result[0].column(0).to_pylist() == [0, 1, 2, 3] + + +def test_with_session_rejected_for_ffi_table_function() -> None: + """`with_session=True` is incompatible with FFI-exported table functions.""" + + class FakeFFITableFunction: + # Presence of this attribute is what marks a function as FFI-exported. 
+ __datafusion_table_function__ = "stub" + + fake = FakeFFITableFunction() + + with pytest.raises(TypeError, match="FFI-exported table functions"): + udtf(fake, "fake_ffi", with_session=True) + + with pytest.raises(TypeError, match="FFI-exported table functions"): + TableFunction("fake_ffi", fake, with_session=True)