diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 642afeef7..508a00d43 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -59,7 +59,8 @@ use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_python_util::{ create_logical_extension_capsule, create_physical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx, get_tokio_runtime, - physical_codec_from_pycapsule, spawn_future, wait_for_future, + physical_codec_from_pycapsule, physical_optimizer_rule_from_pycapsule, spawn_future, + wait_for_future, }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; @@ -1145,6 +1146,17 @@ impl PySessionContext { self.ctx.remove_optimizer_rule(name) } + pub fn add_physical_optimizer_rule(&self, rule: Bound<'_, PyAny>) -> PyDataFusionResult<()> { + let rule = physical_optimizer_rule_from_pycapsule(&rule)?; + let state_ref = self.ctx.state_ref(); + let mut guard = state_ref.write(); + let new_state = SessionStateBuilder::new_from_existing(guard.clone()) + .with_physical_optimizer_rule(rule) + .build(); + *guard = new_state; + Ok(()) + } + pub fn table_provider(&self, name: &str, py: Python) -> PyResult { let provider = wait_for_future(py, self.ctx.table_provider(name)) // Outer error: runtime/async failure diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 72dc9aafc..28c8834e9 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -24,7 +24,9 @@ use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::execution::context::SessionContext; use datafusion::logical_expr::Volatility; +use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion_ffi::execution::FFI_TaskContextProvider; +use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; @@ -332,6 +334,13 @@ from_pycapsule!( dyn PhysicalExtensionCodec ); +from_pycapsule!( + physical_optimizer_rule_from_pycapsule, + "datafusion_physical_optimizer_rule", + FFI_PhysicalOptimizerRule, + dyn PhysicalOptimizerRule + Send + Sync +); + try_from_pycapsule!( task_context_from_pycapsule, "datafusion_task_context_provider", diff --git a/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py b/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py new file mode 100644 index 000000000..0c877d78e --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyPhysicalOptimizerRule + + +def test_ffi_physical_optimizer_rule_runs_during_planning(): + """A rule added via add_physical_optimizer_rule is invoked while the + physical plan is built, and the query still returns correct results.""" + rule = MyPhysicalOptimizerRule() + ctx = SessionContext() + ctx.add_physical_optimizer_rule(rule) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3])], + names=["a"], + ) + ctx.register_record_batches("t", [[batch]]) + + before = rule.optimize_calls() + result = ctx.sql("SELECT a FROM t").collect() + after = rule.optimize_calls() + + assert after > before, ( + f"Expected user FFI physical optimizer rule to fire, " + f"before={before} after={after}" + ) + assert result[0].column(0).to_pylist() == [1, 2, 3] diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 3323ac982..eccf7b81a 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -22,6 +22,7 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP use crate::config::MyConfig; use crate::logical_extension_codec::MyLogicalExtensionCodec; use crate::physical_extension_codec::MyPhysicalExtensionCodec; +use crate::physical_optimizer::MyPhysicalOptimizerRule; use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; @@ -33,6 +34,7 @@ pub(crate) mod catalog_provider; pub(crate) mod config; pub(crate) mod logical_extension_codec; pub(crate) mod physical_extension_codec; +pub(crate) mod physical_optimizer; pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -55,5 +57,6 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/physical_optimizer.rs b/examples/datafusion-ffi-example/src/physical_optimizer.rs new file mode 100644 index 000000000..0acd1bb4a --- /dev/null +++ b/examples/datafusion-ffi-example/src/physical_optimizer.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use datafusion::common::Result; +use datafusion::common::config::ConfigOptions; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule; +use datafusion_python_util::get_tokio_runtime; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; + +/// A physical optimizer rule that leaves every plan unchanged but bumps a +/// shared counter each time it runs. Tests use the counter to prove that a +/// session built with this rule actually routed physical planning through a +/// user-supplied [`PhysicalOptimizerRule`] over FFI. +#[derive(Debug)] +struct CountingPhysicalOptimizerRule { + optimize_calls: Arc, +} + +impl PhysicalOptimizerRule for CountingPhysicalOptimizerRule { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + self.optimize_calls.fetch_add(1, Ordering::SeqCst); + Ok(plan) + } + + fn name(&self) -> &str { + "counting_physical_optimizer_rule" + } + + fn schema_check(&self) -> bool { + // The plan is returned unchanged, so the schema is preserved. + true + } +} + +/// Python-visible handle that produces an [`FFI_PhysicalOptimizerRule`] and +/// exposes the shared call counter. +#[pyclass( + from_py_object, + name = "MyPhysicalOptimizerRule", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug, Default, Clone)] +pub(crate) struct MyPhysicalOptimizerRule { + optimize_calls: Arc, +} + +#[pymethods] +impl MyPhysicalOptimizerRule { + #[new] + fn new() -> Self { + Self::default() + } + + fn optimize_calls(&self) -> usize { + self.optimize_calls.load(Ordering::SeqCst) + } + + fn __datafusion_physical_optimizer_rule__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let rule: Arc = + Arc::new(CountingPhysicalOptimizerRule { + optimize_calls: Arc::clone(&self.optimize_calls), + }); + + let runtime = get_tokio_runtime().handle().clone(); + let ffi = FFI_PhysicalOptimizerRule::new(rule, Some(runtime)); + + let name = cr"datafusion_physical_optimizer_rule".into(); + PyCapsule::new(py, ffi, Some(name)) + } +} diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5c3501941..0c4ad6b3d 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -130,6 +130,16 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105 +class PhysicalOptimizerRuleExportable(Protocol): + """Type hint for object that has __datafusion_physical_optimizer_rule__ PyCapsule. + + The method returns a PyCapsule wrapping an ``FFI_PhysicalOptimizerRule``, + typically produced by a separate compiled extension. + """ + + def __datafusion_physical_optimizer_rule__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -1378,6 +1388,30 @@ def remove_optimizer_rule(self, name: str) -> bool: """ return self.ctx.remove_optimizer_rule(name) + def add_physical_optimizer_rule( + self, rule: PhysicalOptimizerRuleExportable + ) -> None: + """Append a user-defined physical optimizer rule to the session. + + The rule is imported via its ``__datafusion_physical_optimizer_rule__`` + PyCapsule, typically produced by a separate compiled extension. The + underlying :class:`SessionState` is rebuilt from its current state + with the new rule appended, so previously registered tables, UDFs, + and catalogs are preserved. + + Args: + rule: Object exposing ``__datafusion_physical_optimizer_rule__``, + a :class:`PhysicalOptimizerRuleExportable`. + + Examples: + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> from my_extension import MyPhysicalOptimizerRule # doctest: +SKIP + >>> rule = MyPhysicalOptimizerRule() # doctest: +SKIP + >>> ctx.add_physical_optimizer_rule(rule) # doctest: +SKIP + """ + self.ctx.add_physical_optimizer_rule(rule) + def table_provider(self, name: str) -> Table: """Return the :py:class:`~datafusion.catalog.Table` for the given table name.