Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_python_util::{
create_logical_extension_capsule, create_physical_extension_capsule,
ffi_logical_codec_from_pycapsule, get_global_ctx, get_tokio_runtime,
physical_codec_from_pycapsule, spawn_future, wait_for_future,
physical_codec_from_pycapsule, physical_optimizer_rule_from_pycapsule, spawn_future,
wait_for_future,
};
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
Expand Down Expand Up @@ -1145,6 +1146,17 @@ impl PySessionContext {
self.ctx.remove_optimizer_rule(name)
}

pub fn add_physical_optimizer_rule(&self, rule: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
let rule = physical_optimizer_rule_from_pycapsule(&rule)?;
let state_ref = self.ctx.state_ref();
let mut guard = state_ref.write();
let new_state = SessionStateBuilder::new_from_existing(guard.clone())
.with_physical_optimizer_rule(rule)
.build();
*guard = new_state;
Ok(())
}

pub fn table_provider(&self, name: &str, py: Python) -> PyResult<PyTable> {
let provider = wait_for_future(py, self.ctx.table_provider(name))
// Outer error: runtime/async failure
Expand Down
9 changes: 9 additions & 0 deletions crates/util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::TaskContext;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
Expand Down Expand Up @@ -332,6 +334,13 @@ from_pycapsule!(
dyn PhysicalExtensionCodec
);

from_pycapsule!(
physical_optimizer_rule_from_pycapsule,
"datafusion_physical_optimizer_rule",
FFI_PhysicalOptimizerRule,
dyn PhysicalOptimizerRule + Send + Sync
);

try_from_pycapsule!(
task_context_from_pycapsule,
"datafusion_task_context_provider",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa
from datafusion import SessionContext
from datafusion_ffi_example import MyPhysicalOptimizerRule


def test_ffi_physical_optimizer_rule_runs_during_planning():
"""A rule added via add_physical_optimizer_rule is invoked while the
physical plan is built, and the query still returns correct results."""
rule = MyPhysicalOptimizerRule()
ctx = SessionContext()
ctx.add_physical_optimizer_rule(rule)
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3])],
names=["a"],
)
ctx.register_record_batches("t", [[batch]])

before = rule.optimize_calls()
result = ctx.sql("SELECT a FROM t").collect()
after = rule.optimize_calls()

assert after > before, (
f"Expected user FFI physical optimizer rule to fire, "
f"before={before} after={after}"
)
assert result[0].column(0).to_pylist() == [1, 2, 3]
3 changes: 3 additions & 0 deletions examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP
use crate::config::MyConfig;
use crate::logical_extension_codec::MyLogicalExtensionCodec;
use crate::physical_extension_codec::MyPhysicalExtensionCodec;
use crate::physical_optimizer::MyPhysicalOptimizerRule;
use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
Expand All @@ -33,6 +34,7 @@ pub(crate) mod catalog_provider;
pub(crate) mod config;
pub(crate) mod logical_extension_codec;
pub(crate) mod physical_extension_codec;
pub(crate) mod physical_optimizer;
pub(crate) mod scalar_udf;
pub(crate) mod table_function;
pub(crate) mod table_provider;
Expand All @@ -55,5 +57,6 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<MyConfig>()?;
m.add_class::<MyLogicalExtensionCodec>()?;
m.add_class::<MyPhysicalExtensionCodec>()?;
m.add_class::<MyPhysicalOptimizerRule>()?;
Ok(())
}
98 changes: 98 additions & 0 deletions examples/datafusion-ffi-example/src/physical_optimizer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use datafusion::common::Result;
use datafusion::common::config::ConfigOptions;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule;
use datafusion_python_util::get_tokio_runtime;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

/// A physical optimizer rule that leaves every plan unchanged but bumps a
/// shared counter each time it runs. Tests use the counter to prove that a
/// session built with this rule actually routed physical planning through a
/// user-supplied [`PhysicalOptimizerRule`] over FFI.
#[derive(Debug)]
struct CountingPhysicalOptimizerRule {
optimize_calls: Arc<AtomicUsize>,
}

impl PhysicalOptimizerRule for CountingPhysicalOptimizerRule {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
self.optimize_calls.fetch_add(1, Ordering::SeqCst);
Ok(plan)
}

fn name(&self) -> &str {
"counting_physical_optimizer_rule"
}

fn schema_check(&self) -> bool {
// The plan is returned unchanged, so the schema is preserved.
true
}
}

/// Python-visible handle that produces an [`FFI_PhysicalOptimizerRule`] and
/// exposes the shared call counter.
#[pyclass(
from_py_object,
name = "MyPhysicalOptimizerRule",
module = "datafusion_ffi_example",
subclass
)]
#[derive(Debug, Default, Clone)]
pub(crate) struct MyPhysicalOptimizerRule {
optimize_calls: Arc<AtomicUsize>,
}

#[pymethods]
impl MyPhysicalOptimizerRule {
#[new]
fn new() -> Self {
Self::default()
}

fn optimize_calls(&self) -> usize {
self.optimize_calls.load(Ordering::SeqCst)
}

fn __datafusion_physical_optimizer_rule__<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let rule: Arc<dyn PhysicalOptimizerRule + Send + Sync> =
Arc::new(CountingPhysicalOptimizerRule {
optimize_calls: Arc::clone(&self.optimize_calls),
});

let runtime = get_tokio_runtime().handle().clone();
let ffi = FFI_PhysicalOptimizerRule::new(rule, Some(runtime));

let name = cr"datafusion_physical_optimizer_rule".into();
PyCapsule::new(py, ffi, Some(name))
}
}
34 changes: 34 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105


class PhysicalOptimizerRuleExportable(Protocol):
"""Type hint for object that has __datafusion_physical_optimizer_rule__ PyCapsule.

The method returns a PyCapsule wrapping an ``FFI_PhysicalOptimizerRule``,
typically produced by a separate compiled extension.
"""

def __datafusion_physical_optimizer_rule__(self) -> object: ... # noqa: D105


class SessionConfig:
"""Session configuration options."""

Expand Down Expand Up @@ -1378,6 +1388,30 @@ def remove_optimizer_rule(self, name: str) -> bool:
"""
return self.ctx.remove_optimizer_rule(name)

def add_physical_optimizer_rule(
self, rule: PhysicalOptimizerRuleExportable
) -> None:
"""Append a user-defined physical optimizer rule to the session.

The rule is imported via its ``__datafusion_physical_optimizer_rule__``
PyCapsule, typically produced by a separate compiled extension. The
underlying :class:`SessionState` is rebuilt from its current state
with the new rule appended, so previously registered tables, UDFs,
and catalogs are preserved.

Args:
rule: Object exposing ``__datafusion_physical_optimizer_rule__``,
a :class:`PhysicalOptimizerRuleExportable`.

Examples:
>>> from datafusion import SessionContext
>>> ctx = SessionContext()
>>> from my_extension import MyPhysicalOptimizerRule # doctest: +SKIP
>>> rule = MyPhysicalOptimizerRule() # doctest: +SKIP
>>> ctx.add_physical_optimizer_rule(rule) # doctest: +SKIP
"""
self.ctx.add_physical_optimizer_rule(rule)

def table_provider(self, name: str) -> Table:
"""Return the :py:class:`~datafusion.catalog.Table` for the given table name.

Expand Down
Loading