Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 78 additions & 16 deletions crates/core/src/udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,34 @@
use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion::error::Result as DataFusionResult;
use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionState;
use datafusion::logical_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};
use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};

use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, to_datafusion_err};
use crate::expr::PyExpr;
use crate::table::PyTable;

/// A pure-Python UDTF callable together with the options that were
/// supplied for it at registration time.
///
/// This is the payload of `PyTableFunctionInner::PythonFunction`; the
/// FFI-backed variant does not use it.
#[derive(Debug, Clone)]
pub(crate) struct PythonTableFunctionCallable {
/// The Python object that is invoked with the UDTF's expression
/// arguments (as a positional tuple) to produce a table provider.
pub(crate) callable: Arc<Py<PyAny>>,
/// When true, the calling :class:`SessionContext` is passed to the
/// callable as a ``session`` keyword argument on every invocation.
/// When false, the callable receives only the positional expression
/// arguments. Opt-in at registration time via ``with_session=True``
/// on the Python wrapper.
pub(crate) inject_session_on_call: bool,
}

/// Represents a user defined table function
#[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
#[derive(Debug, Clone)]
Expand All @@ -40,21 +54,21 @@ pub struct PyTableFunction {
pub(crate) inner: PyTableFunctionInner,
}

// TODO: Implement pure python based user defined table functions
#[derive(Debug, Clone)]
pub(crate) enum PyTableFunctionInner {
PythonFunction(Arc<Py<PyAny>>),
PythonFunction(PythonTableFunctionCallable),
FFIFunction(Arc<dyn TableFunctionImpl>),
}

#[pymethods]
impl PyTableFunction {
#[new]
#[pyo3(signature=(name, func, session))]
#[pyo3(signature=(name, func, session, inject_session_on_call=false))]
pub fn new(
name: &str,
func: Bound<'_, PyAny>,
session: Option<Bound<PyAny>>,
inject_session_on_call: bool,
) -> PyResult<Self> {
let inner = if func.hasattr("__datafusion_table_function__")? {
let py = func.py();
Expand All @@ -80,8 +94,10 @@ impl PyTableFunction {

PyTableFunctionInner::FFIFunction(foreign_func)
} else {
let py_obj = Arc::new(func.unbind());
PyTableFunctionInner::PythonFunction(py_obj)
PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable {
callable: Arc::new(func.unbind()),
inject_session_on_call,
})
};

Ok(Self {
Expand All @@ -107,20 +123,66 @@ impl PyTableFunction {
}
}

/// Build a new `PySessionContext` from the borrowed `&dyn Session` that is
/// supplied when the table function is called.
///
/// `call_with_args` only hands us a trait-object reference, never an owned
/// context. We therefore downcast that reference to the canonical
/// `SessionState` implementation, clone it, and wrap the clone in a fresh
/// `SessionContext`. The clone shares its registries with the original
/// through the `Arc`-heavy interior of `SessionState`.
///
/// # Errors
///
/// Returns `DataFusionError::Execution` when `session` is not a
/// `SessionState`. This check is defensive: every path that reaches a
/// pure-Python UDTF today supplies a `SessionState` (the SQL planner builds
/// the arguments from its own state, and `PyTableFunction::__call__` uses
/// the global context's state). Another implementation, such as a
/// `ForeignSession`, would only show up if this UDTF were exported across
/// the FFI boundary to a foreign-library consumer, which datafusion-python
/// does not do. If that ever changes, callers get an explicit error rather
/// than silent misbehavior.
fn py_session_from_session(session: &dyn Session) -> DataFusionResult<PySessionContext> {
    // Guard clause: bail out early for any session that is not the
    // concrete `SessionState` we know how to rebuild a context from.
    let Some(state) = session.as_any().downcast_ref::<SessionState>() else {
        return Err(DataFusionError::Execution(
            "Cannot expose this UDTF's calling session to Python: the \
             session is not a SessionState. Drop the `session` keyword \
             from the callback signature to fall back to the \
             expression-only call form."
                .to_string(),
        ));
    };

    // `new_with_state` takes ownership, so the borrowed state is cloned.
    let ctx = SessionContext::new_with_state(state.clone());
    Ok(PySessionContext::from(ctx))
}

#[allow(clippy::result_large_err)]
fn call_python_table_function(
func: &Arc<Py<PyAny>>,
args: &[Expr],
func: &PythonTableFunctionCallable,
args: TableFunctionArgs,
) -> DataFusionResult<Arc<dyn TableProvider>> {
let args = args
let py_session = if func.inject_session_on_call {
Some(py_session_from_session(args.session())?)
} else {
None
};
let py_exprs = args
.exprs()
.iter()
.map(|arg| PyExpr::from(arg.clone()))
.collect::<Vec<_>>();

// move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::attach(|py| {
let py_args = PyTuple::new(py, args)?;
let provider_obj = func.call1(py, py_args)?;
let py_args = PyTuple::new(py, py_exprs)?;
let provider_obj = if let Some(session) = py_session {
let kwargs = PyDict::new(py);
kwargs.set_item("session", session.into_pyobject(py)?)?;
func.callable.call(py, py_args, Some(&kwargs))?
} else {
func.callable.call1(py, py_args)?
};
let provider = provider_obj.bind(py).clone();

Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider, None)?.table)
Expand All @@ -132,8 +194,8 @@ impl TableFunctionImpl for PyTableFunction {
fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
match &self.inner {
PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args),
PyTableFunctionInner::PythonFunction(obj) => {
call_python_table_function(obj, args.exprs())
PyTableFunctionInner::PythonFunction(callable) => {
call_python_table_function(callable, args)
}
}
}
Expand Down
39 changes: 39 additions & 0 deletions docs/source/user-guide/common-operations/udf-and-udfa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,42 @@ that you wish to expose via PyO3, you need to expose it as a ``PyCapsule``.
PyCapsule::new(py, provider, Some(name))
}
}

Accessing the Calling Session
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Pure-Python UDTFs can opt into receiving the calling
:py:class:`~datafusion.SessionContext` by registering with
``with_session=True``. The context is passed as a ``session`` keyword
argument on every invocation. Use it to look up registered tables,
UDFs, or session configuration from inside the callback.

.. code-block:: python

from datafusion import SessionContext, Table, udtf
from datafusion.context import TableProviderExportable
import pyarrow as pa
import pyarrow.dataset as ds

@udtf("list_tables", with_session=True)
def list_tables(*, session: SessionContext) -> TableProviderExportable:
names = sorted(session.catalog().schema().names())
batch = pa.RecordBatch.from_pydict({"name": names})
return Table(ds.dataset([batch]))

ctx = SessionContext()
ctx.register_batch("t1", pa.RecordBatch.from_pydict({"x": [1]}))
ctx.register_udtf(list_tables)
ctx.sql("SELECT * FROM list_tables()").show()

Without ``with_session=True``, the callback receives only the positional
expression arguments. The flag is opt-in so existing UDTFs keep working
unchanged.

The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext`
wrapper built from a clone of the caller's session state. The clone shares
the caller's registries (tables, UDFs, catalogs), which are
reference-counted, so everything registered on the calling session is
visible and registry mutations made through the wrapper (e.g. registering
a new table or UDF) propagate to the live session. Configuration changes
made through the wrapper (e.g. setting session options) do **not**
propagate, because the wrapper holds its own copy of the session config.
100 changes: 87 additions & 13 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,24 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
)


def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]:
    """Return ``func`` wrapped so it receives a public ``SessionContext``.

    The Rust call site passes the raw ``datafusion._internal.SessionContext``
    as the ``session`` keyword argument, whereas UDTF authors expect the
    public :class:`datafusion.SessionContext` wrapper. The returned callable
    converts the raw object on every call and then forwards all positional
    and keyword arguments to ``func`` unchanged.

    Args:
        func: The user's UDTF callback. It must accept a ``session`` keyword
            argument.

    Returns:
        A callable carrying ``func``'s metadata (via :func:`functools.wraps`)
        that requires a ``session`` keyword argument.
    """

    def _to_public_context(raw_ctx: Any) -> Any:
        # Bypass ``SessionContext.__init__`` so that no new internal context
        # is constructed; the wrapper simply adopts the one we were handed.
        public_ctx = SessionContext.__new__(SessionContext)
        public_ctx.ctx = raw_ctx
        return public_ctx

    # ``updated=()`` keeps ``func.__dict__`` from being merged into the
    # wrapper's ``__dict__``.
    @functools.wraps(func, updated=())
    def call_with_public_session(*positional: Any, session: Any, **extra: Any) -> Any:
        return func(*positional, session=_to_public_context(session), **extra)

    return call_with_public_session


class TableFunction:
"""Class for performing user-defined table functions (UDTF).

Expand All @@ -1062,14 +1080,44 @@ class TableFunction:
"""

def __init__(
self, name: str, func: Callable[[], any], ctx: SessionContext | None = None
self,
name: str,
func: Callable[..., Any],
ctx: SessionContext | None = None,
*,
with_session: bool = False,
) -> None:
"""Instantiate a user-defined table function (UDTF).

Set ``with_session=True`` to have the calling
:class:`SessionContext` passed as a ``session`` keyword argument
on each invocation. Use it inside the callback to look up
registered tables, UDFs, or session configuration. When
``with_session`` is ``False`` (the default), ``func`` is invoked
with the positional expression arguments only.

``with_session=True`` is only supported for pure-Python callables.
Passing it together with an FFI-exported table function (one
exposing ``__datafusion_table_function__``) raises
:class:`TypeError`.

Registry mutations performed through the injected session (such
as registering tables or UDFs) propagate to the caller's
:class:`SessionContext` because the registries are shared.
Configuration changes do **not** propagate; the wrapper holds
its own clone of the session config.

See :py:func:`udtf` for a convenience function and argument
descriptions.
"""
self._udtf = df_internal.TableFunction(name, func, ctx)
if with_session and hasattr(func, "__datafusion_table_function__"):
msg = (
"`with_session=True` is not supported for FFI-exported table "
"functions; session injection requires a pure-Python callable."
)
raise TypeError(msg)
registered = _wrap_session_kwarg_for_udtf(func) if with_session else func
self._udtf = df_internal.TableFunction(name, registered, ctx, with_session)

def __call__(self, *args: Expr) -> Any:
"""Execute the UDTF and return a table provider."""
Expand All @@ -1080,47 +1128,73 @@ def __call__(self, *args: Expr) -> Any:
@staticmethod
def udtf(
name: str,
*,
with_session: bool = False,
) -> Callable[..., Any]: ...

@overload
@staticmethod
def udtf(
func: Callable[[], Any],
func: Callable[..., Any],
name: str,
*,
with_session: bool = False,
) -> TableFunction: ...

@staticmethod
def udtf(*args: Any, **kwargs: Any):
"""Create a new User-Defined Table Function (UDTF)."""
def udtf(*args: Any, with_session: bool = False, **kwargs: Any):
"""Create a new User-Defined Table Function (UDTF).

Pass ``with_session=True`` to have the calling
:class:`SessionContext` injected as a ``session`` keyword
argument on each invocation.
"""
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return TableFunction._create_table_udf(*args, **kwargs)
return TableFunction._create_table_udf(
*args, with_session=with_session, **kwargs
)
if args and hasattr(args[0], "__datafusion_table_function__"):
# Case 2: We have a datafusion FFI provided function
if with_session:
msg = (
"`with_session=True` is not supported for FFI-exported "
"table functions; session injection requires a "
"pure-Python callable."
)
raise TypeError(msg)
return TableFunction(args[1], args[0])
# Case 3: Used as a decorator with parameters
return TableFunction._create_table_udf_decorator(*args, **kwargs)
return TableFunction._create_table_udf_decorator(
*args, with_session=with_session, **kwargs
)

@staticmethod
def _create_table_udf(
func: Callable[..., Any],
name: str,
*,
with_session: bool = False,
) -> TableFunction:
"""Create a TableFunction instance from function arguments."""
if not callable(func):
msg = "`func` must be callable."
raise TypeError(msg)

return TableFunction(name, func)
return TableFunction(name, func, with_session=with_session)

@staticmethod
def _create_table_udf_decorator(
name: str | None = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
"""Create a decorator for a WindowUDF."""

def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
return TableFunction._create_table_udf(func, name)
*,
with_session: bool = False,
) -> Callable[[Callable[..., Any]], TableFunction]:
"""Create a decorator for a TableFunction."""

def decorator(func: Callable[..., Any]) -> TableFunction:
return TableFunction._create_table_udf(
func, name, with_session=with_session
)

return decorator

Expand Down
Loading
Loading