From 3f0f4e3b8f5f1e6716120fe41238f1f7683e04c9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 29 May 2026 20:25:09 -0400 Subject: [PATCH 01/14] feat: expose Spark-compatible functions (#1482) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `datafusion.functions.spark` module exposing the upstream `datafusion-spark` crate's UDF/UDAF library (~87 functions across string, math, datetime, hash, array, aggregate, bitwise, bitmap, conditional, collection, conversion, json, map, url categories). For DataFrame use, import the typed Python wrappers from `datafusion.functions.spark`. For SQL use, call `SessionContext.enable_spark_functions()` to register the Spark UDFs by name (overriding DataFusion built-ins of the same name with their Spark semantics — NULL-propagating `concat`, 1-indexed `substring`, HALF_UP `round`, etc.). Co-Authored-By: Claude Opus 4.7 (1M context) --- Cargo.lock | 43 + Cargo.toml | 2 + crates/core/Cargo.toml | 1 + crates/core/src/context.rs | 15 + crates/core/src/functions.rs | 2 +- crates/core/src/lib.rs | 6 + crates/core/src/spark_functions.rs | 409 +++++ .../common-operations/functions.rst | 6 + .../user-guide/common-operations/index.rst | 1 + .../common-operations/spark-functions.rst | 121 ++ pyproject.toml | 1 + python/datafusion/context.py | 22 + .../{functions.py => functions/__init__.py} | 2 + python/datafusion/functions/spark.py | 1362 +++++++++++++++++ python/tests/test_spark_functions.py | 264 ++++ 15 files changed, 2256 insertions(+), 1 deletion(-) create mode 100644 crates/core/src/spark_functions.rs create mode 100644 docs/source/user-guide/common-operations/spark-functions.rst rename python/datafusion/{functions.py => functions/__init__.py} (99%) create mode 100644 python/datafusion/functions/spark.py create mode 100644 python/tests/test_spark_functions.py diff --git a/Cargo.lock b/Cargo.lock index d3cedb628..7f5f63e96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1505,6 +1505,7 @@ 
dependencies = [ "datafusion-ffi", "datafusion-proto", "datafusion-python-util", + "datafusion-spark", "datafusion-substrait", "futures", "log", @@ -1549,6 +1550,34 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "datafusion-spark" +version = "54.0.0" +source = "git+https://github.com/apache/datafusion?rev=1321d60cc37ee487d1e7ce7f501357c3236b2542#1321d60cc37ee487d1e7ce7f501357c3236b2542" +dependencies = [ + "arrow", + "bigdecimal", + "chrono", + "crc32fast", + "datafusion-catalog", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", + "datafusion-functions-nested", + "log", + "num-traits", + "percent-encoding", + "rand 0.9.4", + "serde_json", + "sha1", + "sha2", + "twox-hash", + "url", +] + [[package]] name = "datafusion-sql" version = "54.0.0" @@ -3515,6 +3544,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.11.3", +] + [[package]] name = "sha2" version = "0.11.0" @@ -4009,6 +4049,9 @@ name = "twox-hash" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" +dependencies = [ + "rand 0.9.4", +] [[package]] name = "typenum" diff --git a/Cargo.toml b/Cargo.toml index e72c22368..a8e8559cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ datafusion-catalog = { version = "54", default-features = false } datafusion-common = { version = "54", default-features = false } datafusion-functions-aggregate = { version = "54" } datafusion-functions-window = { version = "54" } +datafusion-spark = { version = "54" } datafusion-expr = { version = "54" } prost = "0.14.3" serde_json = "1" @@ -79,4 
+80,5 @@ datafusion-catalog = { git = "https://github.com/apache/datafusion", rev = "1321 datafusion-common = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-functions-window = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } +datafusion-spark = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-expr = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 1f5b4e305..2e8cf6c92 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -53,6 +53,7 @@ datafusion = { workspace = true, features = ["avro", "unicode_expressions"] } datafusion-substrait = { workspace = true, optional = true } datafusion-proto = { workspace = true } datafusion-ffi = { workspace = true } +datafusion-spark = { workspace = true } prost = { workspace = true } # keep in line with `datafusion-substrait` serde_json = { workspace = true } uuid = { workspace = true, features = ["v4"] } diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 4606246cf..2b7ac797e 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1059,6 +1059,21 @@ impl PySessionContext { Ok(()) } + /// Register all `datafusion-spark` UDFs/UDAFs/UDWFs, overriding any built-in + /// DataFusion functions of the same name with their Spark-semantics version. 
+ pub fn enable_spark_functions(&self) -> PyResult<()> { + for udf in datafusion_spark::all_default_scalar_functions() { + self.ctx.register_udf((*udf).clone()); + } + for udaf in datafusion_spark::all_default_aggregate_functions() { + self.ctx.register_udaf((*udaf).clone()); + } + for udwf in datafusion_spark::all_default_window_functions() { + self.ctx.register_udwf((*udwf).clone()); + } + Ok(()) + } + pub fn deregister_udaf(&self, name: &str) { self.ctx.deregister_udaf(name); } diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 5f47d123b..b0af40d62 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -31,7 +31,7 @@ use crate::expr::conditional_expr::PyCaseBuilder; use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; use crate::expr::window::PyWindowFrame; -fn add_builder_fns_to_aggregate( +pub(crate) fn add_builder_fns_to_aggregate( agg_fn: Expr, distinct: Option<bool>, filter: Option<PyExpr>, diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 79bf77717..c3c549aa8 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -48,6 +48,8 @@ pub mod physical_plan; mod pyarrow_filter_expression; pub mod pyarrow_util; mod record_batch; +#[allow(clippy::borrow_deref_ref)] +mod spark_functions; pub mod sql; pub mod store; pub mod table; @@ -123,6 +125,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { // Register the functions as a submodule let funcs = PyModule::new(py, "functions")?; functions::init_module(&funcs)?; + // Spark-compatible functions live under `functions.spark`. 
+ let spark_funcs = PyModule::new(py, "spark")?; + spark_functions::init_module(&spark_funcs)?; + funcs.add_submodule(&spark_funcs)?; m.add_submodule(&funcs)?; let store = PyModule::new(py, "object_store")?; diff --git a/crates/core/src/spark_functions.rs b/crates/core/src/spark_functions.rs new file mode 100644 index 000000000..2a4d938ff --- /dev/null +++ b/crates/core/src/spark_functions.rs @@ -0,0 +1,409 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! PyO3 wrappers for the [`datafusion-spark`] crate. +//! +//! Exposes Spark-compatible scalar and aggregate function builders for use +//! from Python under `datafusion.functions.spark`. Each scalar wrapper +//! resolves the underlying `ScalarUDF` from `datafusion_spark` and builds an +//! `Expr::ScalarFunction` directly, so behaviour matches what +//! `datafusion_spark::register_all` registers for SQL. 
+ +use datafusion::logical_expr::Expr; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion_spark::function::{ + aggregate as fn_aggregate, array as fn_array, bitmap as fn_bitmap, bitwise as fn_bitwise, + collection as fn_collection, conditional as fn_conditional, conversion as fn_conversion, + datetime as fn_datetime, hash as fn_hash, json as fn_json, map as fn_map, math as fn_math, + string as fn_string, url as fn_url, +}; +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; + +use crate::common::data_type::NullTreatment; +use crate::errors::PyDataFusionResult; +use crate::expr::PyExpr; +use crate::expr::sort_expr::PySortExpr; +use crate::functions::add_builder_fns_to_aggregate; + +/// Build an `Expr::ScalarFunction` by invoking the given `ScalarUDF` factory +/// with the supplied arguments. +macro_rules! spark_udf_fixed { + ($PY_NAME:ident, $UDF_PATH:path, $($arg:ident),+ $(,)?) => { + #[pyfunction] + fn $PY_NAME($($arg: PyExpr),+) -> PyExpr { + let udf = $UDF_PATH(); + let args: Vec<Expr> = vec![$($arg.into()),+]; + Expr::ScalarFunction(ScalarFunction::new_udf(udf, args)).into() + } + }; +} + +/// Build an `Expr::ScalarFunction` from a variadic `*args` list. +macro_rules! spark_udf_vec { + ($PY_NAME:ident, $UDF_PATH:path) => { + #[pyfunction] + #[pyo3(signature = (*args))] + fn $PY_NAME(args: Vec<PyExpr>) -> PyExpr { + let udf = $UDF_PATH(); + let args: Vec<Expr> = args.into_iter().map(Into::into).collect(); + Expr::ScalarFunction(ScalarFunction::new_udf(udf, args)).into() + } + }; +} + +/// Build an aggregate `Expr` from a Spark `AggregateUDF` factory and the +/// optional builder fields (distinct/filter/order_by/null_treatment), +/// mirroring the existing `aggregate_function!` macro for default DataFusion +/// aggregates. +macro_rules! spark_aggregate_fixed { + ($PY_NAME:ident, $UDF_PATH:path, $($arg:ident),+ $(,)?) 
=> { + #[pyfunction] + #[pyo3(signature = ($($arg),+, distinct=None, filter=None, order_by=None, null_treatment=None))] + fn $PY_NAME( + $($arg: PyExpr),+, + distinct: Option<bool>, + filter: Option<PyExpr>, + order_by: Option<Vec<PySortExpr>>, + null_treatment: Option<NullTreatment>, + ) -> PyDataFusionResult<PyExpr> { + let udf = $UDF_PATH(); + let args: Vec<Expr> = vec![$($arg.into()),+]; + let agg_fn = udf.call(args); + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) + } + }; +} + +// --------------------------------------------------------------------------- +// Aggregate functions +// --------------------------------------------------------------------------- + +spark_aggregate_fixed!(avg, fn_aggregate::avg, arg1); +spark_aggregate_fixed!(try_sum, fn_aggregate::try_sum, arg1); +spark_aggregate_fixed!(collect_list, fn_aggregate::collect_list, arg1); +spark_aggregate_fixed!(collect_set, fn_aggregate::collect_set, arg1); + +// --------------------------------------------------------------------------- +// Array functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!( + array_contains, + fn_array::spark_array_contains, + array, + element +); +spark_udf_vec!(array, fn_array::array); +spark_udf_fixed!(shuffle, fn_array::shuffle, arg1); +spark_udf_fixed!(array_repeat, fn_array::array_repeat, element, count); +spark_udf_fixed!(slice, fn_array::slice, arg_array, start, length); + +// --------------------------------------------------------------------------- +// Bitmap functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(bitmap_count, fn_bitmap::bitmap_count, arg1); +spark_udf_fixed!(bitmap_bit_position, fn_bitmap::bitmap_bit_position, arg1); +spark_udf_fixed!(bitmap_bucket_number, fn_bitmap::bitmap_bucket_number, arg1); + +// --------------------------------------------------------------------------- +// Bitwise functions +// 
--------------------------------------------------------------------------- + +spark_udf_fixed!(bit_get, fn_bitwise::bit_get, col, pos); +spark_udf_fixed!(bit_count, fn_bitwise::bit_count, col); +spark_udf_fixed!(bitwise_not, fn_bitwise::bitwise_not, col); +spark_udf_fixed!(shiftleft, fn_bitwise::shiftleft, value, shift); +spark_udf_fixed!(shiftright, fn_bitwise::shiftright, value, shift); +spark_udf_fixed!( + shiftrightunsigned, + fn_bitwise::shiftrightunsigned, + value, + shift +); + +// --------------------------------------------------------------------------- +// Collection functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(size, fn_collection::size, arg1); + +// --------------------------------------------------------------------------- +// Conditional functions +// --------------------------------------------------------------------------- + +// Python keyword `if` → exposed as `if_`. +spark_udf_fixed!(if_, fn_conditional::r#if, condition, if_true, if_false); + +// --------------------------------------------------------------------------- +// Conversion functions +// --------------------------------------------------------------------------- + +// `spark_cast` requires session ConfigOptions in upstream; use the +// crate-provided `expr_fn` helper which applies defaults. 
+#[pyfunction] +fn spark_cast(arg1: PyExpr, arg2: PyExpr) -> PyExpr { + fn_conversion::expr_fn::spark_cast(arg1.into(), arg2.into()).into() +} + +// --------------------------------------------------------------------------- +// Datetime functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(add_months, fn_datetime::add_months, start_date, num_months); +spark_udf_fixed!(date_add, fn_datetime::date_add, start_date, days); +spark_udf_fixed!(date_sub, fn_datetime::date_sub, start_date, days); +spark_udf_fixed!(hour, fn_datetime::hour, arg1); +spark_udf_fixed!(minute, fn_datetime::minute, arg1); +spark_udf_fixed!(second, fn_datetime::second, arg1); +spark_udf_fixed!(last_day, fn_datetime::last_day, arg1); +spark_udf_fixed!( + make_dt_interval, + fn_datetime::make_dt_interval, + days, + hours, + mins, + secs +); +spark_udf_fixed!( + make_interval, + fn_datetime::make_interval, + years, + months, + weeks, + days, + hours, + mins, + secs +); +spark_udf_fixed!(next_day, fn_datetime::next_day, start_date, day_of_week); +spark_udf_fixed!(date_diff, fn_datetime::date_diff, end_date, start_date); +spark_udf_fixed!(date_trunc, fn_datetime::date_trunc, fmt, ts); +spark_udf_fixed!(time_trunc, fn_datetime::time_trunc, fmt, t); +spark_udf_fixed!(trunc, fn_datetime::trunc, dt, fmt); +spark_udf_fixed!(date_part, fn_datetime::date_part, field, source); +spark_udf_fixed!(from_utc_timestamp, fn_datetime::from_utc_timestamp, ts, tz); +spark_udf_fixed!(to_utc_timestamp, fn_datetime::to_utc_timestamp, ts, tz); +spark_udf_fixed!(unix_date, fn_datetime::unix_date, dt); +spark_udf_fixed!(unix_micros, fn_datetime::unix_micros, ts); +spark_udf_fixed!(unix_millis, fn_datetime::unix_millis, ts); +spark_udf_fixed!(unix_seconds, fn_datetime::unix_seconds, ts); + +// --------------------------------------------------------------------------- +// Hash functions +// --------------------------------------------------------------------------- + 
+spark_udf_fixed!(crc32, fn_hash::crc32, arg1); +spark_udf_fixed!(sha1, fn_hash::sha1, arg1); +spark_udf_fixed!(sha2, fn_hash::sha2, arg1, bit_length); +spark_udf_vec!(xxhash64, fn_hash::xxhash64); + +// --------------------------------------------------------------------------- +// JSON functions +// --------------------------------------------------------------------------- + +spark_udf_vec!(json_tuple, fn_json::json_tuple); + +// --------------------------------------------------------------------------- +// Map functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(map_from_arrays, fn_map::map_from_arrays, keys, values); +spark_udf_fixed!(map_from_entries, fn_map::map_from_entries, arg1); +spark_udf_fixed!( + str_to_map, + fn_map::str_to_map, + text, + pair_delim, + key_value_delim +); + +// --------------------------------------------------------------------------- +// Math functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(abs, fn_math::abs, arg1); +spark_udf_fixed!(ceil, fn_math::ceil, arg1); +spark_udf_fixed!(expm1, fn_math::expm1, arg1); +spark_udf_fixed!(factorial, fn_math::factorial, arg1); +spark_udf_fixed!(floor, fn_math::floor, arg1); +spark_udf_fixed!(hex, fn_math::hex, arg1); +spark_udf_fixed!(modulus, fn_math::modulus, dividend, divisor); +spark_udf_fixed!(pmod, fn_math::pmod, dividend, divisor); +spark_udf_fixed!(rint, fn_math::rint, arg1); +spark_udf_fixed!(round, fn_math::round, value, scale); +spark_udf_fixed!(unhex, fn_math::unhex, arg1); +spark_udf_fixed!( + width_bucket, + fn_math::width_bucket, + value, + min_value, + max_value, + num_buckets +); +spark_udf_fixed!(csc, fn_math::csc, arg1); +spark_udf_fixed!(sec, fn_math::sec, arg1); +spark_udf_fixed!(negative, fn_math::negative, arg1); +spark_udf_fixed!(bin, fn_math::bin, arg1); + +// --------------------------------------------------------------------------- +// String 
functions +// --------------------------------------------------------------------------- + +spark_udf_fixed!(ascii, fn_string::ascii, arg1); +spark_udf_fixed!(base64, fn_string::base64, bin_input); +// `char` collides with the Rust primitive type in macro hygiene; rename the +// Rust ident and re-expose under the original name to Python. +#[pyfunction] +#[pyo3(name = "char")] +fn char_fn(arg1: PyExpr) -> PyExpr { + let udf = fn_string::char(); + Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![arg1.into()])).into() +} +spark_udf_vec!(concat, fn_string::concat); +spark_udf_vec!(elt, fn_string::elt); +spark_udf_fixed!(ilike, fn_string::ilike, str, pattern); +spark_udf_fixed!(length, fn_string::length, arg1); +spark_udf_fixed!(like, fn_string::like, str, pattern); +spark_udf_fixed!(luhn_check, fn_string::luhn_check, arg1); +spark_udf_vec!(format_string, fn_string::format_string); +spark_udf_fixed!(space, fn_string::space, arg1); +spark_udf_fixed!(substring, fn_string::substring, str, pos, length); +spark_udf_fixed!(unbase64, fn_string::unbase64, str); +spark_udf_fixed!(soundex, fn_string::soundex, str); +spark_udf_fixed!(is_valid_utf8, fn_string::is_valid_utf8, str); +spark_udf_fixed!(make_valid_utf8, fn_string::make_valid_utf8, str); + +// --------------------------------------------------------------------------- +// URL functions +// --------------------------------------------------------------------------- + +spark_udf_vec!(parse_url, fn_url::parse_url); +spark_udf_vec!(try_parse_url, fn_url::try_parse_url); +spark_udf_vec!(url_decode, fn_url::url_decode); +spark_udf_vec!(try_url_decode, fn_url::try_url_decode); +spark_udf_vec!(url_encode, fn_url::url_encode); + +// --------------------------------------------------------------------------- +// Module init +// --------------------------------------------------------------------------- + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + // Aggregate + 
m.add_wrapped(wrap_pyfunction!(avg))?; + m.add_wrapped(wrap_pyfunction!(try_sum))?; + m.add_wrapped(wrap_pyfunction!(collect_list))?; + m.add_wrapped(wrap_pyfunction!(collect_set))?; + // Array + m.add_wrapped(wrap_pyfunction!(array_contains))?; + m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(shuffle))?; + m.add_wrapped(wrap_pyfunction!(array_repeat))?; + m.add_wrapped(wrap_pyfunction!(slice))?; + // Bitmap + m.add_wrapped(wrap_pyfunction!(bitmap_count))?; + m.add_wrapped(wrap_pyfunction!(bitmap_bit_position))?; + m.add_wrapped(wrap_pyfunction!(bitmap_bucket_number))?; + // Bitwise + m.add_wrapped(wrap_pyfunction!(bit_get))?; + m.add_wrapped(wrap_pyfunction!(bit_count))?; + m.add_wrapped(wrap_pyfunction!(bitwise_not))?; + m.add_wrapped(wrap_pyfunction!(shiftleft))?; + m.add_wrapped(wrap_pyfunction!(shiftright))?; + m.add_wrapped(wrap_pyfunction!(shiftrightunsigned))?; + // Collection + m.add_wrapped(wrap_pyfunction!(size))?; + // Conditional + m.add_wrapped(wrap_pyfunction!(if_))?; + // Conversion + m.add_wrapped(wrap_pyfunction!(spark_cast))?; + // Datetime + m.add_wrapped(wrap_pyfunction!(add_months))?; + m.add_wrapped(wrap_pyfunction!(date_add))?; + m.add_wrapped(wrap_pyfunction!(date_sub))?; + m.add_wrapped(wrap_pyfunction!(hour))?; + m.add_wrapped(wrap_pyfunction!(minute))?; + m.add_wrapped(wrap_pyfunction!(second))?; + m.add_wrapped(wrap_pyfunction!(last_day))?; + m.add_wrapped(wrap_pyfunction!(make_dt_interval))?; + m.add_wrapped(wrap_pyfunction!(make_interval))?; + m.add_wrapped(wrap_pyfunction!(next_day))?; + m.add_wrapped(wrap_pyfunction!(date_diff))?; + m.add_wrapped(wrap_pyfunction!(date_trunc))?; + m.add_wrapped(wrap_pyfunction!(time_trunc))?; + m.add_wrapped(wrap_pyfunction!(trunc))?; + m.add_wrapped(wrap_pyfunction!(date_part))?; + m.add_wrapped(wrap_pyfunction!(from_utc_timestamp))?; + m.add_wrapped(wrap_pyfunction!(to_utc_timestamp))?; + m.add_wrapped(wrap_pyfunction!(unix_date))?; + 
m.add_wrapped(wrap_pyfunction!(unix_micros))?; + m.add_wrapped(wrap_pyfunction!(unix_millis))?; + m.add_wrapped(wrap_pyfunction!(unix_seconds))?; + // Hash + m.add_wrapped(wrap_pyfunction!(crc32))?; + m.add_wrapped(wrap_pyfunction!(sha1))?; + m.add_wrapped(wrap_pyfunction!(sha2))?; + m.add_wrapped(wrap_pyfunction!(xxhash64))?; + // JSON + m.add_wrapped(wrap_pyfunction!(json_tuple))?; + // Map + m.add_wrapped(wrap_pyfunction!(map_from_arrays))?; + m.add_wrapped(wrap_pyfunction!(map_from_entries))?; + m.add_wrapped(wrap_pyfunction!(str_to_map))?; + // Math + m.add_wrapped(wrap_pyfunction!(abs))?; + m.add_wrapped(wrap_pyfunction!(ceil))?; + m.add_wrapped(wrap_pyfunction!(expm1))?; + m.add_wrapped(wrap_pyfunction!(factorial))?; + m.add_wrapped(wrap_pyfunction!(floor))?; + m.add_wrapped(wrap_pyfunction!(hex))?; + m.add_wrapped(wrap_pyfunction!(modulus))?; + m.add_wrapped(wrap_pyfunction!(pmod))?; + m.add_wrapped(wrap_pyfunction!(rint))?; + m.add_wrapped(wrap_pyfunction!(round))?; + m.add_wrapped(wrap_pyfunction!(unhex))?; + m.add_wrapped(wrap_pyfunction!(width_bucket))?; + m.add_wrapped(wrap_pyfunction!(csc))?; + m.add_wrapped(wrap_pyfunction!(sec))?; + m.add_wrapped(wrap_pyfunction!(negative))?; + m.add_wrapped(wrap_pyfunction!(bin))?; + // String + m.add_wrapped(wrap_pyfunction!(ascii))?; + m.add_wrapped(wrap_pyfunction!(base64))?; + m.add_wrapped(wrap_pyfunction!(char_fn))?; + m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(elt))?; + m.add_wrapped(wrap_pyfunction!(ilike))?; + m.add_wrapped(wrap_pyfunction!(length))?; + m.add_wrapped(wrap_pyfunction!(like))?; + m.add_wrapped(wrap_pyfunction!(luhn_check))?; + m.add_wrapped(wrap_pyfunction!(format_string))?; + m.add_wrapped(wrap_pyfunction!(space))?; + m.add_wrapped(wrap_pyfunction!(substring))?; + m.add_wrapped(wrap_pyfunction!(unbase64))?; + m.add_wrapped(wrap_pyfunction!(soundex))?; + m.add_wrapped(wrap_pyfunction!(is_valid_utf8))?; + m.add_wrapped(wrap_pyfunction!(make_valid_utf8))?; + // 
URL + m.add_wrapped(wrap_pyfunction!(parse_url))?; + m.add_wrapped(wrap_pyfunction!(try_parse_url))?; + m.add_wrapped(wrap_pyfunction!(url_decode))?; + m.add_wrapped(wrap_pyfunction!(try_url_decode))?; + m.add_wrapped(wrap_pyfunction!(url_encode))?; + Ok(()) +} diff --git a/docs/source/user-guide/common-operations/functions.rst b/docs/source/user-guide/common-operations/functions.rst index ccb47a4e7..f656f8af7 100644 --- a/docs/source/user-guide/common-operations/functions.rst +++ b/docs/source/user-guide/common-operations/functions.rst @@ -21,6 +21,12 @@ Functions DataFusion provides a large number of built-in functions for performing complex queries without requiring user-defined functions. In here we will cover some of the more popular use cases. If you want to view all the functions go to the :py:mod:`Functions <datafusion.functions>` API Reference. +.. note:: + + For Apache Spark-compatible versions of these functions (with Spark + NULL-propagation, 1-indexed substrings, HALF_UP rounding, etc.), see + :doc:`spark-functions`. + We'll use the pokemon dataset in the following examples. .. ipython:: python diff --git a/docs/source/user-guide/common-operations/index.rst b/docs/source/user-guide/common-operations/index.rst index 7abd1f138..8b65c6de6 100644 --- a/docs/source/user-guide/common-operations/index.rst +++ b/docs/source/user-guide/common-operations/index.rst @@ -29,6 +29,7 @@ The contents of this section are designed to guide a new user through how to use expressions joins functions + spark-functions aggregations windows udf-and-udfa diff --git a/docs/source/user-guide/common-operations/spark-functions.rst b/docs/source/user-guide/common-operations/spark-functions.rst new file mode 100644 index 000000000..2c1053ef7 --- /dev/null +++ b/docs/source/user-guide/common-operations/spark-functions.rst @@ -0,0 +1,121 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. 
distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Spark-Compatible Functions +========================== + +DataFusion ships Spark-compatible versions of a wide set of SQL functions +(string, math, datetime, hash, array, aggregate) through the upstream +``datafusion-spark`` crate. ``datafusion-python`` exposes these under +``datafusion.functions.spark`` for the DataFrame API and via +:py:meth:`~datafusion.SessionContext.enable_spark_functions` for SQL. + +Why a Separate Namespace? +------------------------- + +Several Spark functions share names with DataFusion built-ins but differ in +semantics. The most common divergences: + +- ``concat`` propagates NULL. ``concat('a', NULL, 'b')`` returns NULL under + Spark semantics, whereas the DataFusion default returns ``'ab'``. +- ``substring`` is 1-indexed and supports negative positions counting from + the end of the string. +- ``round`` uses HALF_UP rounding mode (``round(2.5, 0) == 3``). +- Numeric functions (``floor``, ``ceil``, ``mod``) follow Spark's edge-case + handling for negative values and decimals. + +Enabling Spark functions does not affect the DataFrame API: you choose which +implementation to call by which module you import from. + +DataFrame API +------------- + +Import the submodule and call functions directly. 
Returned values are +:py:class:`~datafusion.expr.Expr` instances that compose with the rest of +the DataFrame API. + +.. code-block:: python + + from datafusion import SessionContext, col, lit + from datafusion.functions import spark + + ctx = SessionContext() + df = ctx.from_pydict({"s": ["hello", "world"]}) + + # SHA-256 hash with Spark semantics + df.select(spark.sha2(col("s"), lit(256)).alias("h")).show() + + # 1-indexed substring + df.select(spark.substring(col("s"), lit(1), lit(3)).alias("p")).show() + +SQL +--- + +To use Spark functions in SQL queries, call +:py:meth:`~datafusion.SessionContext.enable_spark_functions` on the context. +This registers every Spark UDF/UDAF/UDWF, overriding any DataFusion built-in +of the same name. + +.. code-block:: python + + from datafusion import SessionContext + + ctx = SessionContext() + ctx.enable_spark_functions() + + ctx.sql("SELECT sha2('hello', 256)").show() + ctx.sql("SELECT concat('a', NULL, 'b')").show() # -> NULL, not 'ab' + +The override applies for the lifetime of the session. To call DataFusion's +built-in versions afterwards, create a fresh ``SessionContext``. + +Function Categories +------------------- + +The full list is available in the +:py:mod:`API reference <datafusion.functions.spark>`. Highlights by +category: + +- **String**: ``ascii``, ``base64``, ``char``, ``concat``, ``elt``, + ``format_string``, ``ilike``, ``is_valid_utf8``, ``length``, ``like``, + ``luhn_check``, ``make_valid_utf8``, ``soundex``, ``space``, + ``substring``, ``unbase64``. +- **Math**: ``abs``, ``bin``, ``ceil``, ``csc``, ``expm1``, ``factorial``, + ``floor``, ``hex``, ``modulus``, ``negative``, ``pmod``, ``rint``, + ``round``, ``sec``, ``unhex``, ``width_bucket``. 
+- **Datetime**: ``add_months``, ``date_add``, ``date_diff``, ``date_part``, + ``date_sub``, ``date_trunc``, ``from_utc_timestamp``, ``hour``, + ``last_day``, ``make_dt_interval``, ``make_interval``, ``minute``, + ``next_day``, ``second``, ``time_trunc``, ``to_utc_timestamp``, + ``trunc``, ``unix_date``, ``unix_micros``, ``unix_millis``, + ``unix_seconds``. +- **Hash**: ``crc32``, ``sha1``, ``sha2``, ``xxhash64``. +- **Array**: ``array``, ``array_contains``, ``array_repeat``, ``shuffle``, + ``slice``. +- **Aggregate**: ``avg``, ``collect_list``, ``collect_set``, ``try_sum``. +- **Bitwise**: ``bit_count``, ``bit_get``, ``bitwise_not``, ``shiftleft``, + ``shiftright``, ``shiftrightunsigned``. +- **Bitmap**: ``bitmap_bit_position``, ``bitmap_bucket_number``, + ``bitmap_count``. +- **Collection**: ``size``. +- **Conditional**: ``if_`` (exposed under that name because ``if`` is a + Python keyword). +- **Conversion**: ``spark_cast``. +- **JSON**: ``json_tuple``. +- **Map**: ``map_from_arrays``, ``map_from_entries``, ``str_to_map``. +- **URL**: ``parse_url``, ``try_parse_url``, ``url_decode``, + ``try_url_decode``, ``url_encode``. diff --git a/pyproject.toml b/pyproject.toml index 2b6a976db..d33d38e7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] [tool.codespell] skip = [ "*/tests/test_functions.py", + "*/tests/test_spark_functions.py", "*/target", "./uv.lock", "uv.lock", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 52bd600c3..71b78c0c8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1297,6 +1297,28 @@ def register_udaf(self, udaf: AggregateUDF) -> None: """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf._udaf) + def enable_spark_functions(self) -> None: + """Register all Spark-compatible functions for SQL access. 
+ + Registers every UDF/UDAF/UDWF from the ``datafusion-spark`` crate, + overriding any DataFusion built-ins of the same name with their + Spark-semantics version (e.g. ``substring`` becomes 1-indexed, + ``concat`` propagates NULL, ``round`` uses HALF_UP rounding). + + For DataFrame use, import the typed wrappers from + :py:mod:`datafusion.functions.spark` directly; this method is only + needed for SQL queries. + + Examples: + >>> ctx = dfn.SessionContext() + >>> ctx.enable_spark_functions() + >>> ctx.sql( + ... "SELECT sha2('hello', 256) AS h" + ... ).collect_column("h")[0].as_py() + '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824' + """ + self.ctx.enable_spark_functions() + def deregister_udaf(self, name: str) -> None: """Remove a user-defined aggregate function from the session. diff --git a/python/datafusion/functions.py b/python/datafusion/functions/__init__.py similarity index 99% rename from python/datafusion/functions.py rename to python/datafusion/functions/__init__.py index c11a5c6cd..87957ae41 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions/__init__.py @@ -55,6 +55,7 @@ sort_list_to_raw_sort_list, sort_or_default, ) +from datafusion.functions import spark __all__ = [ "abs", @@ -312,6 +313,7 @@ "signum", "sin", "sinh", + "spark", "split_part", "sqrt", "starts_with", diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py new file mode 100644 index 000000000..c2badd811 --- /dev/null +++ b/python/datafusion/functions/spark.py @@ -0,0 +1,1362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Spark-compatible function bindings. + +These functions mirror the semantics of their Apache Spark counterparts +exactly. Some override DataFusion built-ins (``substring`` is 1-indexed, +``concat`` propagates NULL, ``round`` uses HALF_UP rounding, etc.), which is +why they live in a separate namespace rather than replacing the defaults. + +For DataFrame use, import this module and call functions directly. For SQL +use, call :py:meth:`datafusion.SessionContext.enable_spark_functions` to +register the Spark UDFs by name (overriding any built-ins with matching +names) before issuing SQL queries. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from datafusion._internal import functions as _functions +from datafusion.expr import Expr, sort_list_to_raw_sort_list + +if TYPE_CHECKING: + from datafusion.common import NullTreatment + from datafusion.expr import SortKey + +_f = _functions.spark + + +def _filter_raw(filter: Expr | None) -> Any: + return filter.expr if filter is not None else None + + +# --------------------------------------------------------------------------- +# Aggregate functions +# --------------------------------------------------------------------------- + + +def avg( + expression: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``avg``: returns the mean of a numeric column. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.avg(dfn.col("a")).alias("v")]) + >>> r.collect_column("v")[0].as_py() + 2.0 + """ + return Expr( + _f.avg( + expression.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def try_sum( + expression: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``try_sum``: sum that returns NULL on overflow. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.try_sum(dfn.col("a")).alias("v")]) + >>> r.collect_column("v")[0].as_py() + 6 + """ + return Expr( + _f.try_sum( + expression.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def collect_list( + expression: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``collect_list``: collect values into an array (preserves dups). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 2]}) + >>> r = df.aggregate( + ... 
[], [dfn.functions.spark.collect_list(dfn.col("a")).alias("v")]) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 2] + """ + return Expr( + _f.collect_list( + expression.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def collect_set( + expression: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``collect_set``: collect distinct values into an array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 2, 3]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.collect_set(dfn.col("a")).alias("v")]) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 3] + """ + return Expr( + _f.collect_set( + expression.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +# --------------------------------------------------------------------------- +# Array functions +# --------------------------------------------------------------------------- + + +def array_contains(array: Expr, element: Expr) -> Expr: + """Spark ``array_contains``: true if the array contains the element. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array_contains( + ... dfn.functions.spark.array(dfn.lit(1), dfn.lit(2)), + ... dfn.lit(1), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.array_contains(array.expr, element.expr)) + + +def array(*args: Expr) -> Expr: + """Spark ``array``: builds an array from the given elements. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array( + ... 
dfn.lit(1), dfn.lit(2), dfn.lit(3) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [1, 2, 3] + """ + return Expr(_f.array(*[a.expr for a in args])) + + +def shuffle(array: Expr) -> Expr: + """Spark ``shuffle``: returns a random permutation of the input array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shuffle( + ... dfn.functions.spark.array(dfn.lit(1), dfn.lit(2), dfn.lit(3)) + ... ).alias("v") + ... ) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 3] + """ + return Expr(_f.shuffle(array.expr)) + + +def array_repeat(element: Expr, count: Expr) -> Expr: + """Spark ``array_repeat``: array of ``element`` repeated ``count`` times. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array_repeat(dfn.lit("a"), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + ['a', 'a', 'a'] + """ + return Expr(_f.array_repeat(element.expr, count.expr)) + + +def slice(array: Expr, start: Expr, length: Expr) -> Expr: + """Spark ``slice``: subset of the array from 1-indexed ``start`` with ``length``. + + Negative ``start`` counts from the end. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.slice( + ... dfn.functions.spark.array( + ... dfn.lit(1), dfn.lit(2), dfn.lit(3), dfn.lit(4)), + ... dfn.lit(2), dfn.lit(2), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [2, 3] + """ + return Expr(_f.slice(array.expr, start.expr, length.expr)) + + +# --------------------------------------------------------------------------- +# Bitmap functions +# --------------------------------------------------------------------------- + + +def bitmap_count(arg: Expr) -> Expr: + """Spark ``bitmap_count``: number of set bits in a bitmap. 
+ + Examples: + >>> dfn.functions.spark.bitmap_count(dfn.col("b")) # doctest: +SKIP + """ + return Expr(_f.bitmap_count(arg.expr)) + + +def bitmap_bit_position(arg: Expr) -> Expr: + """Spark ``bitmap_bit_position``: bit position for a child expression. + + Examples: + >>> dfn.functions.spark.bitmap_bit_position(dfn.col("b")) # doctest: +SKIP + """ + return Expr(_f.bitmap_bit_position(arg.expr)) + + +def bitmap_bucket_number(arg: Expr) -> Expr: + """Spark ``bitmap_bucket_number``: bucket number for a child expression. + + Examples: + >>> dfn.functions.spark.bitmap_bucket_number(dfn.col("b")) # doctest: +SKIP + """ + return Expr(_f.bitmap_bucket_number(arg.expr)) + + +# --------------------------------------------------------------------------- +# Bitwise functions +# --------------------------------------------------------------------------- + + +def bit_get(col: Expr, pos: Expr) -> Expr: + """Spark ``bit_get``: returns the bit (0 or 1) at ``pos``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bit_get(dfn.lit(5), dfn.lit(0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.bit_get(col.expr, pos.expr)) + + +def bit_count(col: Expr) -> Expr: + """Spark ``bit_count``: number of bits set in the integer's binary form. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bit_count(dfn.lit(7)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr(_f.bit_count(col.expr)) + + +def bitwise_not(col: Expr) -> Expr: + """Spark ``~``: bitwise NOT. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bitwise_not(dfn.lit(0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + -1 + """ + return Expr(_f.bitwise_not(col.expr)) + + +def shiftleft(value: Expr, shift: Expr) -> Expr: + """Spark ``shiftleft``: ``value`` shifted left by ``shift`` bits. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftleft(dfn.lit(1), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 8 + """ + return Expr(_f.shiftleft(value.expr, shift.expr)) + + +def shiftright(value: Expr, shift: Expr) -> Expr: + """Spark ``shiftright``: arithmetic right shift. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftright(dfn.lit(8), dfn.lit(2)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.shiftright(value.expr, shift.expr)) + + +def shiftrightunsigned(value: Expr, shift: Expr) -> Expr: + """Spark ``shiftrightunsigned``: logical (unsigned) right shift. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftrightunsigned( + ... dfn.lit(8), dfn.lit(2) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.shiftrightunsigned(value.expr, shift.expr)) + + +# --------------------------------------------------------------------------- +# Collection / Conditional / Conversion +# --------------------------------------------------------------------------- + + +def size(arg: Expr) -> Expr: + """Spark ``size``: length of an array or map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.size( + ... dfn.functions.spark.array(dfn.lit(1), dfn.lit(2), dfn.lit(3)) + ... 
).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr(_f.size(arg.expr)) + + +def if_(condition: Expr, if_true: Expr, if_false: Expr) -> Expr: + """Spark ``if``: returns ``if_true`` when ``condition`` is true, else ``if_false``. + + Exposed as ``if_`` because ``if`` is a Python keyword. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2]}) + >>> r = df.select( + ... dfn.functions.spark.if_( + ... dfn.col("a") > dfn.lit(1), dfn.lit("big"), dfn.lit("small") + ... ).alias("v") + ... ) + >>> r.collect_column("v").to_pylist() + ['small', 'big'] + """ + return Expr(_f.if_(condition.expr, if_true.expr, if_false.expr)) + + +def spark_cast(arg: Expr, type_str: Expr) -> Expr: + """Spark ``cast``: cast ``arg`` to the type named by ``type_str``. + + Uses Spark cast semantics (e.g. overflow returns NULL, not error). + + Examples: + >>> dfn.functions.spark.spark_cast( # doctest: +SKIP + ... dfn.col("x"), dfn.lit("string") + ... ) + """ + return Expr(_f.spark_cast(arg.expr, type_str.expr)) + + +# --------------------------------------------------------------------------- +# Datetime functions +# --------------------------------------------------------------------------- + + +def add_months(start_date: Expr, num_months: Expr) -> Expr: + """Spark ``add_months``: date + N months. + + Examples: + >>> dfn.functions.spark.add_months(dfn.col("d"), dfn.lit(1)) # doctest: +SKIP + """ + return Expr(_f.add_months(start_date.expr, num_months.expr)) + + +def date_add(start_date: Expr, days: Expr) -> Expr: + """Spark ``date_add``: date + N days. + + Examples: + >>> dfn.functions.spark.date_add(dfn.col("d"), dfn.lit(7)) # doctest: +SKIP + """ + return Expr(_f.date_add(start_date.expr, days.expr)) + + +def date_sub(start_date: Expr, days: Expr) -> Expr: + """Spark ``date_sub``: date - N days. 
+ + Examples: + >>> dfn.functions.spark.date_sub(dfn.col("d"), dfn.lit(7)) # doctest: +SKIP + """ + return Expr(_f.date_sub(start_date.expr, days.expr)) + + +def hour(arg: Expr) -> Expr: + """Spark ``hour``: extract hour component of a timestamp. + + Examples: + >>> dfn.functions.spark.hour(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.hour(arg.expr)) + + +def minute(arg: Expr) -> Expr: + """Spark ``minute``: extract minute component of a timestamp. + + Examples: + >>> dfn.functions.spark.minute(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.minute(arg.expr)) + + +def second(arg: Expr) -> Expr: + """Spark ``second``: extract second component of a timestamp. + + Examples: + >>> dfn.functions.spark.second(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.second(arg.expr)) + + +def last_day(arg: Expr) -> Expr: + """Spark ``last_day``: last day of the month containing the date. + + Examples: + >>> dfn.functions.spark.last_day(dfn.col("d")) # doctest: +SKIP + """ + return Expr(_f.last_day(arg.expr)) + + +def make_dt_interval(days: Expr, hours: Expr, mins: Expr, secs: Expr) -> Expr: + """Spark ``make_dt_interval``: day-time interval from components. + + Examples: + >>> dfn.functions.spark.make_dt_interval( + ... dfn.lit(1), dfn.lit(2), dfn.lit(3), dfn.lit(4)) # doctest: +SKIP + """ + return Expr(_f.make_dt_interval(days.expr, hours.expr, mins.expr, secs.expr)) + + +def make_interval( + years: Expr, + months: Expr, + weeks: Expr, + days: Expr, + hours: Expr, + mins: Expr, + secs: Expr, +) -> Expr: + """Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts. + + Examples: + >>> dfn.functions.spark.make_interval( + ... dfn.lit(1), dfn.lit(0), dfn.lit(0), + ... 
dfn.lit(0), dfn.lit(0), dfn.lit(0), dfn.lit(0)) # doctest: +SKIP + """ + return Expr( + _f.make_interval( + years.expr, + months.expr, + weeks.expr, + days.expr, + hours.expr, + mins.expr, + secs.expr, + ) + ) + + +def next_day(start_date: Expr, day_of_week: Expr) -> Expr: + """Spark ``next_day``: first date after ``start_date`` named ``day_of_week``. + + Examples: + >>> dfn.functions.spark.next_day( # doctest: +SKIP + ... dfn.col("d"), dfn.lit("Sunday") + ... ) + """ + return Expr(_f.next_day(start_date.expr, day_of_week.expr)) + + +def date_diff(end_date: Expr, start_date: Expr) -> Expr: + """Spark ``date_diff``: number of days from ``start_date`` to ``end_date``. + + Examples: + >>> dfn.functions.spark.date_diff(dfn.col("e"), dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.date_diff(end_date.expr, start_date.expr)) + + +def date_trunc(fmt: Expr, ts: Expr) -> Expr: + """Spark ``date_trunc``: truncate timestamp to unit ``fmt``. + + Examples: + >>> dfn.functions.spark.date_trunc( # doctest: +SKIP + ... dfn.lit("year"), dfn.col("ts") + ... ) + """ + return Expr(_f.date_trunc(fmt.expr, ts.expr)) + + +def time_trunc(fmt: Expr, t: Expr) -> Expr: + """Spark ``time_trunc``: truncate time value to unit ``fmt``. + + Examples: + >>> dfn.functions.spark.time_trunc( # doctest: +SKIP + ... dfn.lit("hour"), dfn.col("t") + ... ) + """ + return Expr(_f.time_trunc(fmt.expr, t.expr)) + + +def trunc(dt: Expr, fmt: Expr) -> Expr: + """Spark ``trunc``: truncate date to unit ``fmt``. + + Examples: + >>> dfn.functions.spark.trunc(dfn.col("d"), dfn.lit("YEAR")) # doctest: +SKIP + """ + return Expr(_f.trunc(dt.expr, fmt.expr)) + + +def date_part(field: Expr, source: Expr) -> Expr: + """Spark ``date_part``: extract ``field`` from a date/time/timestamp. + + Examples: + >>> dfn.functions.spark.date_part( # doctest: +SKIP + ... dfn.lit("year"), dfn.col("d") + ... 
) + """ + return Expr(_f.date_part(field.expr, source.expr)) + + +def from_utc_timestamp(ts: Expr, tz: Expr) -> Expr: + """Spark ``from_utc_timestamp``: interpret ``ts`` as UTC, convert to ``tz``. + + Examples: + >>> dfn.functions.spark.from_utc_timestamp( # doctest: +SKIP + ... dfn.col("ts"), dfn.lit("PST") + ... ) + """ + return Expr(_f.from_utc_timestamp(ts.expr, tz.expr)) + + +def to_utc_timestamp(ts: Expr, tz: Expr) -> Expr: + """Spark ``to_utc_timestamp``: interpret ``ts`` as ``tz``, convert to UTC. + + Examples: + >>> dfn.functions.spark.to_utc_timestamp( # doctest: +SKIP + ... dfn.col("ts"), dfn.lit("PST") + ... ) + """ + return Expr(_f.to_utc_timestamp(ts.expr, tz.expr)) + + +def unix_date(dt: Expr) -> Expr: + """Spark ``unix_date``: days since 1970-01-01 for ``dt``. + + Examples: + >>> dfn.functions.spark.unix_date(dfn.col("d")) # doctest: +SKIP + """ + return Expr(_f.unix_date(dt.expr)) + + +def unix_micros(ts: Expr) -> Expr: + """Spark ``unix_micros``: microseconds since epoch for ``ts``. + + Examples: + >>> dfn.functions.spark.unix_micros(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.unix_micros(ts.expr)) + + +def unix_millis(ts: Expr) -> Expr: + """Spark ``unix_millis``: milliseconds since epoch for ``ts``. + + Examples: + >>> dfn.functions.spark.unix_millis(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.unix_millis(ts.expr)) + + +def unix_seconds(ts: Expr) -> Expr: + """Spark ``unix_seconds``: seconds since epoch for ``ts``. + + Examples: + >>> dfn.functions.spark.unix_seconds(dfn.col("ts")) # doctest: +SKIP + """ + return Expr(_f.unix_seconds(ts.expr)) + + +# --------------------------------------------------------------------------- +# Hash functions +# --------------------------------------------------------------------------- + + +def crc32(arg: Expr) -> Expr: + """Spark ``crc32``: cyclic redundancy check value as a bigint. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["ABC"]}) + >>> r = df.select(dfn.functions.spark.crc32(dfn.col("s")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2743272264 + """ + return Expr(_f.crc32(arg.expr)) + + +def sha1(arg: Expr) -> Expr: + """Spark ``sha1``: SHA-1 hash as a hex string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["hello"]}) + >>> r = df.select(dfn.functions.spark.sha1(dfn.col("s")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d' + """ + return Expr(_f.sha1(arg.expr)) + + +def sha2(arg: Expr, bit_length: Expr) -> Expr: + """Spark ``sha2``: SHA-2 family hash (224, 256, 384, 512). Bit length 0 = 256. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["hello"]}) + >>> r = df.select( + ... dfn.functions.spark.sha2(dfn.col("s"), dfn.lit(256)).alias("v")) + >>> r.collect_column("v")[0].as_py() + '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824' + """ + return Expr(_f.sha2(arg.expr, bit_length.expr)) + + +def xxhash64(*args: Expr) -> Expr: + """Spark ``xxhash64``: 64-bit xxHash of the arguments. + + Examples: + >>> dfn.functions.spark.xxhash64(dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.xxhash64(*[a.expr for a in args])) + + +# --------------------------------------------------------------------------- +# JSON functions +# --------------------------------------------------------------------------- + + +def json_tuple(*args: Expr) -> Expr: + """Spark ``json_tuple``: extract top-level fields from a JSON string. + + Examples: + >>> dfn.functions.spark.json_tuple( # doctest: +SKIP + ... dfn.col("j"), dfn.lit("a"), dfn.lit("b") + ... 
) + """ + return Expr(_f.json_tuple(*[a.expr for a in args])) + + +# --------------------------------------------------------------------------- +# Map functions +# --------------------------------------------------------------------------- + + +def map_from_arrays(keys: Expr, values: Expr) -> Expr: + """Spark ``map_from_arrays``: build a map from parallel key/value arrays. + + Examples: + >>> dfn.functions.spark.map_from_arrays( # doctest: +SKIP + ... dfn.col("k"), dfn.col("v") + ... ) + """ + return Expr(_f.map_from_arrays(keys.expr, values.expr)) + + +def map_from_entries(arg: Expr) -> Expr: + """Spark ``map_from_entries``: build a map from an array of key/value structs. + + Examples: + >>> dfn.functions.spark.map_from_entries(dfn.col("entries")) # doctest: +SKIP + """ + return Expr(_f.map_from_entries(arg.expr)) + + +def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr: + """Spark ``str_to_map``: split text into key/value pairs using delimiters. + + Examples: + >>> dfn.functions.spark.str_to_map( + ... dfn.col("s"), dfn.lit(","), dfn.lit(":")) # doctest: +SKIP + """ + return Expr(_f.str_to_map(text.expr, pair_delim.expr, key_value_delim.expr)) + + +# --------------------------------------------------------------------------- +# Math functions +# --------------------------------------------------------------------------- + + +def abs(arg: Expr) -> Expr: + """Spark ``abs``: absolute value. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.abs(dfn.lit(-5)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 + """ + return Expr(_f.abs(arg.expr)) + + +def ceil(arg: Expr) -> Expr: + """Spark ``ceil``: smallest integer ≥ arg. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.ceil(dfn.lit(1.2)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.ceil(arg.expr)) + + +def expm1(arg: Expr) -> Expr: + """Spark ``expm1``: exp(arg) - 1. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.expm1(dfn.lit(0.0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 0.0 + """ + return Expr(_f.expm1(arg.expr)) + + +def factorial(arg: Expr) -> Expr: + """Spark ``factorial``: n! for n in [0..20], else NULL. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.factorial( + ... dfn.lit(pa.scalar(5, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 120 + """ + return Expr(_f.factorial(arg.expr)) + + +def floor(arg: Expr) -> Expr: + """Spark ``floor``: largest integer ≤ arg. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.floor(dfn.lit(1.8)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.floor(arg.expr)) + + +def hex(arg: Expr) -> Expr: + """Spark ``hex``: hexadecimal representation. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.hex(dfn.lit(255)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'FF' + """ + return Expr(_f.hex(arg.expr)) + + +def modulus(dividend: Expr, divisor: Expr) -> Expr: + """Spark ``mod``: remainder of ``dividend / divisor`` (sign follows dividend). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.modulus(dfn.lit(10), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.modulus(dividend.expr, divisor.expr)) + + +def pmod(dividend: Expr, divisor: Expr) -> Expr: + """Spark ``pmod``: positive remainder of division. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.pmod(dfn.lit(-1), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.pmod(dividend.expr, divisor.expr)) + + +def rint(arg: Expr) -> Expr: + """Spark ``rint``: round to nearest mathematical integer (as double). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.rint(dfn.lit(2.5)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2.0 + """ + return Expr(_f.rint(arg.expr)) + + +def round(value: Expr, scale: Expr) -> Expr: + """Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.round(dfn.lit(2.5), dfn.lit(0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 3.0 + """ + return Expr(_f.round(value.expr, scale.expr)) + + +def unhex(arg: Expr) -> Expr: + r"""Spark ``unhex``: convert hexadecimal string to binary. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.unhex(dfn.lit("FF")).alias("v")) + >>> r.collect_column("v")[0].as_py() + b'\xff' + """ + return Expr(_f.unhex(arg.expr)) + + +def width_bucket( + value: Expr, min_value: Expr, max_value: Expr, num_buckets: Expr +) -> Expr: + """Spark ``width_bucket``: bucket number for ``value`` in equi-width histogram. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.width_bucket( + ... 
dfn.lit(5.0), dfn.lit(0.0), dfn.lit(10.0), dfn.lit(5) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr( + _f.width_bucket(value.expr, min_value.expr, max_value.expr, num_buckets.expr) + ) + + +def csc(arg: Expr) -> Expr: + """Spark ``csc``: cosecant. + + Examples: + >>> dfn.functions.spark.csc(dfn.lit(1.0)) # doctest: +SKIP + """ + return Expr(_f.csc(arg.expr)) + + +def sec(arg: Expr) -> Expr: + """Spark ``sec``: secant. + + Examples: + >>> dfn.functions.spark.sec(dfn.lit(0.0)) # doctest: +SKIP + """ + return Expr(_f.sec(arg.expr)) + + +def negative(arg: Expr) -> Expr: + """Spark ``negative``: unary minus. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.negative(dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + -3 + """ + return Expr(_f.negative(arg.expr)) + + +def bin(arg: Expr) -> Expr: + """Spark ``bin``: binary string representation of a long. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bin(dfn.lit(7)).alias("v")) + >>> r.collect_column("v")[0].as_py() + '111' + """ + return Expr(_f.bin(arg.expr)) + + +# --------------------------------------------------------------------------- +# String functions +# --------------------------------------------------------------------------- + + +def ascii(arg: Expr) -> Expr: + """Spark ``ascii``: code point of the first character. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.ascii(dfn.lit("A")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 65 + """ + return Expr(_f.ascii(arg.expr)) + + +def base64(bin_input: Expr) -> Expr: + """Spark ``base64``: encode binary as a base64 string. 
+ + Examples: + >>> dfn.functions.spark.base64(dfn.col("b")) # doctest: +SKIP + """ + return Expr(_f.base64(bin_input.expr)) + + +def char(arg: Expr) -> Expr: + """Spark ``char``: ASCII character for a code point (mod 256). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.char(dfn.lit(65)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'A' + """ + return Expr(_f.char(arg.expr)) + + +def concat(*args: Expr) -> Expr: + """Spark ``concat``: concatenates strings; NULL if any input is NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.concat(dfn.lit("a"), dfn.lit("b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'ab' + """ + return Expr(_f.concat(*[a.expr for a in args])) + + +def elt(*args: Expr) -> Expr: + """Spark ``elt``: returns the n-th input (1-indexed). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.elt( + ... dfn.lit(2), dfn.lit("a"), dfn.lit("b") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'b' + """ + return Expr(_f.elt(*[a.expr for a in args])) + + +def ilike(string: Expr, pattern: Expr) -> Expr: + """Spark ``ilike``: case-insensitive pattern match. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.ilike(dfn.lit("HELLO"), dfn.lit("h%")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.ilike(string.expr, pattern.expr)) + + +def length(arg: Expr) -> Expr: + """Spark ``length``: character length of a string, or byte length of binary. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.length(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 + """ + return Expr(_f.length(arg.expr)) + + +def like(string: Expr, pattern: Expr) -> Expr: + """Spark ``like``: case-sensitive pattern match. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.like(dfn.lit("hello"), dfn.lit("h%")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.like(string.expr, pattern.expr)) + + +def luhn_check(arg: Expr) -> Expr: + """Spark ``luhn_check``: true if the digit string passes the Luhn check. + + Examples: + >>> dfn.functions.spark.luhn_check(dfn.col("card")) # doctest: +SKIP + """ + return Expr(_f.luhn_check(arg.expr)) + + +def format_string(*args: Expr) -> Expr: + """Spark ``format_string``: printf-style format string. + + First arg is the format, remaining args are values to substitute. + + Examples: + >>> dfn.functions.spark.format_string( # doctest: +SKIP + ... dfn.lit("%d/%d"), dfn.lit(1), dfn.lit(2) + ... ) + """ + return Expr(_f.format_string(*[a.expr for a in args])) + + +def space(arg: Expr) -> Expr: + """Spark ``space``: string of n spaces. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.space( + ... dfn.lit(pa.scalar(3, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + ' ' + """ + return Expr(_f.space(arg.expr)) + + +def substring(string: Expr, pos: Expr, length: Expr) -> Expr: + """Spark ``substring``: 1-indexed substring starting at ``pos`` of given ``length``. + + Negative ``pos`` counts from the end. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.substring( + ... 
dfn.lit("hello"), dfn.lit(1), dfn.lit(3) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'hel' + """ + return Expr(_f.substring(string.expr, pos.expr, length.expr)) + + +def unbase64(arg: Expr) -> Expr: + """Spark ``unbase64``: decode a base64 string to binary. + + Examples: + >>> dfn.functions.spark.unbase64(dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.unbase64(arg.expr)) + + +def soundex(arg: Expr) -> Expr: + """Spark ``soundex``: Soundex phonetic code. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.soundex(dfn.lit("Robert")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'R163' + """ + return Expr(_f.soundex(arg.expr)) + + +def is_valid_utf8(arg: Expr) -> Expr: + """Spark ``is_valid_utf8``: true if the string is valid UTF-8. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.is_valid_utf8(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.is_valid_utf8(arg.expr)) + + +def make_valid_utf8(arg: Expr) -> Expr: + """Spark ``make_valid_utf8``: replace invalid UTF-8 bytes with U+FFFD. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.make_valid_utf8(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'hello' + """ + return Expr(_f.make_valid_utf8(arg.expr)) + + +# --------------------------------------------------------------------------- +# URL functions +# --------------------------------------------------------------------------- + + +def parse_url(*args: Expr) -> Expr: + """Spark ``parse_url``: extract a part from a URL; errors on invalid URLs. + + Examples: + >>> dfn.functions.spark.parse_url( # doctest: +SKIP + ... dfn.col("u"), dfn.lit("HOST") + ... 
) + """ + return Expr(_f.parse_url(*[a.expr for a in args])) + + +def try_parse_url(*args: Expr) -> Expr: + """Spark ``try_parse_url``: like ``parse_url`` but returns NULL on invalid URLs. + + Examples: + >>> dfn.functions.spark.try_parse_url( # doctest: +SKIP + ... dfn.col("u"), dfn.lit("HOST") + ... ) + """ + return Expr(_f.try_parse_url(*[a.expr for a in args])) + + +def url_decode(*args: Expr) -> Expr: + """Spark ``url_decode``: decode an application/x-www-form-urlencoded string. + + Examples: + >>> dfn.functions.spark.url_decode(dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.url_decode(*[a.expr for a in args])) + + +def try_url_decode(*args: Expr) -> Expr: + """Spark ``try_url_decode``: like ``url_decode``; returns NULL on invalid input. + + Examples: + >>> dfn.functions.spark.try_url_decode(dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.try_url_decode(*[a.expr for a in args])) + + +def url_encode(*args: Expr) -> Expr: + """Spark ``url_encode``: encode a string in application/x-www-form-urlencoded. 
+ + Examples: + >>> dfn.functions.spark.url_encode(dfn.col("s")) # doctest: +SKIP + """ + return Expr(_f.url_encode(*[a.expr for a in args])) + + +__all__ = [ + # Math + "abs", + # Datetime + "add_months", + "array", + # Array + "array_contains", + "array_repeat", + # String + "ascii", + # Aggregate + "avg", + "base64", + "bin", + "bit_count", + # Bitwise + "bit_get", + "bitmap_bit_position", + "bitmap_bucket_number", + # Bitmap + "bitmap_count", + "bitwise_not", + "ceil", + "char", + "collect_list", + "collect_set", + "concat", + # Hash + "crc32", + "csc", + "date_add", + "date_diff", + "date_part", + "date_sub", + "date_trunc", + "elt", + "expm1", + "factorial", + "floor", + "format_string", + "from_utc_timestamp", + "hex", + "hour", + "if_", + "ilike", + "is_valid_utf8", + # JSON + "json_tuple", + "last_day", + "length", + "like", + "luhn_check", + "make_dt_interval", + "make_interval", + "make_valid_utf8", + # Map + "map_from_arrays", + "map_from_entries", + "minute", + "modulus", + "negative", + "next_day", + # URL + "parse_url", + "pmod", + "rint", + "round", + "sec", + "second", + "sha1", + "sha2", + "shiftleft", + "shiftright", + "shiftrightunsigned", + "shuffle", + # Collection / Conditional / Conversion + "size", + "slice", + "soundex", + "space", + "spark_cast", + "str_to_map", + "substring", + "time_trunc", + "to_utc_timestamp", + "trunc", + "try_parse_url", + "try_sum", + "try_url_decode", + "unbase64", + "unhex", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "url_decode", + "url_encode", + "width_bucket", + "xxhash64", +] diff --git a/python/tests/test_spark_functions.py b/python/tests/test_spark_functions.py new file mode 100644 index 000000000..f9ccb41a1 --- /dev/null +++ b/python/tests/test_spark_functions.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the Spark-compatible function bindings.""" + +import pyarrow as pa +import pytest +from datafusion import SessionContext, col, lit +from datafusion import functions as f +from datafusion.functions import spark + + +@pytest.fixture +def df(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays( + [ + pa.array(["hello", "WORLD", "abc"]), + pa.array([1, 2, 3]), + pa.array([-5, 0, 7]), + pa.array([1.0, 2.5, 3.0]), + pa.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ], + names=["s", "i", "n", "f", "a"], + ) + return ctx.create_dataframe([[batch]]) + + +def _val(df, expr): + return df.select(expr.alias("v")).collect_column("v")[0].as_py() + + +# --------------------------------------------------------------------------- +# Math +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.abs(lit(-5)), 5), + (lambda: spark.ceil(lit(1.2)), 2), + (lambda: spark.floor(lit(1.8)), 1), + (lambda: spark.bin(lit(7)), "111"), + (lambda: spark.hex(lit(255)), "FF"), + (lambda: spark.modulus(lit(10), lit(3)), 1), + (lambda: spark.pmod(lit(-1), lit(3)), 2), + (lambda: spark.rint(lit(2.5)), 2.0), + (lambda: spark.round(lit(2.5), lit(0)), 3.0), + (lambda: spark.negative(lit(3)), -3), + 
], +) +def test_math(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +def test_factorial(df): + # factorial wants Int32; lit(int) is Int64 by default. + expr = spark.factorial(lit(pa.scalar(5, type=pa.int32()))) + assert _val(df, expr) == 120 + + +# --------------------------------------------------------------------------- +# String +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.ascii(lit("A")), 65), + (lambda: spark.char(lit(65)), "A"), + (lambda: spark.length(lit("hello")), 5), + (lambda: spark.like(lit("hello"), lit("h%")), True), + (lambda: spark.ilike(lit("HELLO"), lit("h%")), True), + (lambda: spark.substring(lit("hello"), lit(1), lit(3)), "hel"), + (lambda: spark.soundex(lit("Robert")), "R163"), + (lambda: spark.is_valid_utf8(lit("hi")), True), + (lambda: spark.concat(lit("a"), lit("b")), "ab"), + (lambda: spark.elt(lit(2), lit("a"), lit("b")), "b"), + ], +) +def test_string(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +def test_space(df): + expr = spark.space(lit(pa.scalar(3, type=pa.int32()))) + assert _val(df, expr) == " " + + +# --------------------------------------------------------------------------- +# Hash +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + ( + lambda: spark.sha1(lit("hello")), + "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", + ), + ( + lambda: spark.sha2(lit("hello"), lit(256)), + "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ), + (lambda: spark.crc32(lit("ABC")), 2743272264), + ], +) +def test_hash(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +# --------------------------------------------------------------------------- +# Bitwise +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.bit_count(lit(7)), 3), + (lambda: spark.bit_get(lit(5), lit(0)), 1), + (lambda: spark.bitwise_not(lit(0)), -1), + (lambda: spark.shiftleft(lit(1), lit(3)), 8), + (lambda: spark.shiftright(lit(8), lit(2)), 2), + (lambda: spark.shiftrightunsigned(lit(8), lit(2)), 2), + ], +) +def test_bitwise(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +# --------------------------------------------------------------------------- +# Array / collection / conditional +# --------------------------------------------------------------------------- + + +def test_array_and_size(df): + arr = spark.array(lit(1), lit(2), lit(3)) + assert _val(df, arr) == [1, 2, 3] + assert _val(df, spark.size(arr)) == 3 + assert _val(df, spark.array_contains(arr, lit(2))) is True + + +def test_slice(df): + arr = spark.array(lit(1), lit(2), lit(3), lit(4)) + assert _val(df, spark.slice(arr, lit(2), lit(2))) == [2, 3] + + +def test_array_repeat(df): + assert _val(df, spark.array_repeat(lit("a"), lit(3))) == ["a", "a", "a"] + + +def test_if(df): + assert _val(df, spark.if_(lit(True), lit("yes"), lit("no"))) == "yes" + assert _val(df, spark.if_(lit(False), lit("yes"), lit("no"))) == "no" + + +# --------------------------------------------------------------------------- +# Aggregate +# --------------------------------------------------------------------------- + + +def test_aggregate(df): + r = df.aggregate( + [], + [ + spark.avg(col("f")).alias("avg"), + spark.try_sum(col("i")).alias("sum"), + ], + ).collect() + rec = pa.Table.from_batches(r) + assert rec.column("avg")[0].as_py() == pytest.approx(2.1666666, rel=1e-3) + assert rec.column("sum")[0].as_py() == 6 + + +def test_collect_list_set(df): + r = df.aggregate( + [], + [ + spark.collect_list(col("i")).alias("cl"), + spark.collect_set(col("i")).alias("cs"), + ], + 
).collect() + rec = pa.Table.from_batches(r) + assert sorted(rec.column("cl")[0].as_py()) == [1, 2, 3] + assert sorted(rec.column("cs")[0].as_py()) == [1, 2, 3] + + +# --------------------------------------------------------------------------- +# Spark-semantics conflicts vs DataFusion defaults +# --------------------------------------------------------------------------- + + +def test_concat_null_propagates(): + """Spark concat returns NULL on any NULL input; default skips NULLs.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + default_out = ( + df.select(f.concat(lit("a"), lit(None), lit("b")).alias("v")) + .collect_column("v")[0] + .as_py() + ) + spark_out = _val(df, spark.concat(lit("a"), lit(None), lit("b"))) + assert default_out == "ab" + assert spark_out is None + + +def test_round_half_up(): + """Spark round uses HALF_UP rounding (2.5 → 3).""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.round(lit(2.5), lit(0))) == 3.0 + + +# --------------------------------------------------------------------------- +# SQL path via enable_spark_functions +# --------------------------------------------------------------------------- + + +def test_sql_requires_enable(): + """Spark-only function (xxhash64) is not in SQL namespace by default.""" + ctx = SessionContext() + with pytest.raises(Exception, match=r"(?i)xxhash64|Invalid function"): + ctx.sql("SELECT xxhash64('hello')").collect() + + +def test_sql_after_enable(): + ctx = SessionContext() + ctx.enable_spark_functions() + out = ctx.sql("SELECT sha2('hello', 256) AS h").collect_column("h")[0].as_py() + assert out == ("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824") + + +def test_sql_concat_semantics_override(): + """After enable, SQL concat propagates NULL like Spark.""" + ctx = SessionContext() + default_out = ( + ctx.sql("SELECT concat('a', NULL, 'b') AS c").collect_column("c")[0].as_py() + ) + assert default_out == "ab" + + ctx2 = 
SessionContext() + ctx2.enable_spark_functions() + spark_out = ( + ctx2.sql("SELECT concat('a', NULL, 'b') AS c").collect_column("c")[0].as_py() + ) + assert spark_out is None From d2f7c0d89214db943a527f9a36356552493b2a4e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:08:24 -0400 Subject: [PATCH 02/14] chore: drop unused borrow_deref_ref allows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seven `#[allow(clippy::borrow_deref_ref)]` attributes on module declarations in `crates/core/src/lib.rs` had become stale — the only remaining lint hit was a redundant `&*x.as_str()` pattern in `parse_file_compression_type`. Rewriting that call to `&x.unwrap_or_default()` lets every allow come off, removing noise that new modules were copying without need. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/context.rs | 7 +++---- crates/core/src/lib.rs | 8 -------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 2b7ac797e..d64ce587a 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1576,10 +1576,9 @@ impl PySessionContext { pub fn parse_file_compression_type( file_compression_type: Option, ) -> Result { - FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str()) - .map_err(|_| { - PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd") - }) + FileCompressionType::from_str(&file_compression_type.unwrap_or_default()).map_err(|_| { + PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd") + }) } impl From for SessionContext { diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c3c549aa8..8f6b77b42 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -26,21 +26,16 @@ pub use datafusion_substrait; use mimalloc::MiMalloc; use pyo3::prelude::*; -#[allow(clippy::borrow_deref_ref)] pub mod catalog; pub 
mod codec; pub mod common; -#[allow(clippy::borrow_deref_ref)] pub mod context; -#[allow(clippy::borrow_deref_ref)] pub mod dataframe; mod dataset; mod dataset_exec; pub mod errors; -#[allow(clippy::borrow_deref_ref)] pub mod expr; -#[allow(clippy::borrow_deref_ref)] mod functions; pub mod metrics; mod options; @@ -48,7 +43,6 @@ pub mod physical_plan; mod pyarrow_filter_expression; pub mod pyarrow_util; mod record_batch; -#[allow(clippy::borrow_deref_ref)] mod spark_functions; pub mod sql; pub mod store; @@ -58,9 +52,7 @@ pub mod unparser; mod array; #[cfg(feature = "substrait")] pub mod substrait; -#[allow(clippy::borrow_deref_ref)] mod udaf; -#[allow(clippy::borrow_deref_ref)] mod udf; pub mod udtf; mod udwf; From c8ca27f90a496903550ab5ccd9485d109d1e206d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:32:43 -0400 Subject: [PATCH 03/14] refactor: tighten spark_functions macros via expr_fn Switch most spark wrappers from UDF-direct path (which forced `spark_udf_fixed!(name, fn_category::name, args...)` repetition) to a `spark_expr_fn!` macro that mirrors the existing `expr_fn!` macro in `functions.rs`, so calls collapse to `spark_expr_fn!(sha2, arg1 bit_length);`. UDF-direct retained for genuinely variadic functions whose upstream `expr_fn` wrappers were generated with a single-`Expr` arm by `export_functions!` (concat, array, xxhash64, parse_url family, etc.) so that the Python side keeps its `*args` ergonomics. Aggregates collapse the same way via `spark_aggregate!` mirroring `aggregate_function!`. Net 173 lines removed. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/spark_functions.rs | 300 ++++++++++++----------------- 1 file changed, 127 insertions(+), 173 deletions(-) diff --git a/crates/core/src/spark_functions.rs b/crates/core/src/spark_functions.rs index 2a4d938ff..e7cb94f8c 100644 --- a/crates/core/src/spark_functions.rs +++ b/crates/core/src/spark_functions.rs @@ -18,19 +18,11 @@ //! 
PyO3 wrappers for the [`datafusion-spark`] crate. //! //! Exposes Spark-compatible scalar and aggregate function builders for use -//! from Python under `datafusion.functions.spark`. Each scalar wrapper -//! resolves the underlying `ScalarUDF` from `datafusion_spark` and builds an -//! `Expr::ScalarFunction` directly, so behaviour matches what -//! `datafusion_spark::register_all` registers for SQL. +//! from Python under `datafusion.functions.spark`. use datafusion::logical_expr::Expr; use datafusion::logical_expr::expr::ScalarFunction; -use datafusion_spark::function::{ - aggregate as fn_aggregate, array as fn_array, bitmap as fn_bitmap, bitwise as fn_bitwise, - collection as fn_collection, conditional as fn_conditional, conversion as fn_conversion, - datetime as fn_datetime, hash as fn_hash, json as fn_json, map as fn_map, math as fn_math, - string as fn_string, url as fn_url, -}; +use datafusion_spark::{expr_fn, function as udf}; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -40,20 +32,26 @@ use crate::expr::PyExpr; use crate::expr::sort_expr::PySortExpr; use crate::functions::add_builder_fns_to_aggregate; -/// Build an `Expr::ScalarFunction` by invoking the given `ScalarUDF` factory -/// with the supplied arguments. -macro_rules! spark_udf_fixed { - ($PY_NAME:ident, $UDF_PATH:path, $($arg:ident),+ $(,)?) => { +/// Generates a [pyo3] wrapper for [datafusion_spark::expr_fn]. +/// +/// These functions have explicit named arguments and mirror the upstream +/// `expr_fn::$FUNC` signature. +macro_rules! spark_expr_fn { + ($FUNC:ident) => { + spark_expr_fn!($FUNC,); + }; + ($FUNC:ident, $($arg:ident)*) => { #[pyfunction] - fn $PY_NAME($($arg: PyExpr),+) -> PyExpr { - let udf = $UDF_PATH(); - let args: Vec = vec![$($arg.into()),+]; - Expr::ScalarFunction(ScalarFunction::new_udf(udf, args)).into() + fn $FUNC($($arg: PyExpr),*) -> PyExpr { + expr_fn::$FUNC($($arg.into()),*).into() } }; } -/// Build an `Expr::ScalarFunction` from a variadic `*args` list. 
+/// Generates a variadic [pyo3] wrapper that calls the [`ScalarUDF`] factory +/// directly. Required for functions whose upstream `expr_fn` wrapper accepts +/// a single `Expr` instead of `Vec` (an upstream `export_functions!` +/// macro-arm quirk), so we bypass it to get true Python `*args` semantics. macro_rules! spark_udf_vec { ($PY_NAME:ident, $UDF_PATH:path) => { #[pyfunction] @@ -66,24 +64,24 @@ macro_rules! spark_udf_vec { }; } -/// Build an aggregate `Expr` from a Spark `AggregateUDF` factory and the -/// optional builder fields (distinct/filter/order_by/null_treatment), -/// mirroring the existing `aggregate_function!` macro for default DataFusion -/// aggregates. -macro_rules! spark_aggregate_fixed { - ($PY_NAME:ident, $UDF_PATH:path, $($arg:ident),+ $(,)?) => { +/// Generates a [pyo3] wrapper for Spark aggregate functions. Mirrors +/// [`crate::functions::aggregate_function`] but points at +/// [`datafusion_spark::expr_fn`]. +macro_rules! spark_aggregate { + ($NAME:ident) => { + spark_aggregate!($NAME, expr); + }; + ($NAME:ident, $($arg:ident)*) => { #[pyfunction] - #[pyo3(signature = ($($arg),+, distinct=None, filter=None, order_by=None, null_treatment=None))] - fn $PY_NAME( - $($arg: PyExpr),+, + #[pyo3(signature = ($($arg),*, distinct=None, filter=None, order_by=None, null_treatment=None))] + fn $NAME( + $($arg: PyExpr),*, distinct: Option, filter: Option, order_by: Option>, null_treatment: Option, ) -> PyDataFusionResult { - let udf = $UDF_PATH(); - let args: Vec = vec![$($arg.into()),+]; - let agg_fn = udf.call(args); + let agg_fn = expr_fn::$NAME($($arg.into()),*); add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) } }; @@ -93,211 +91,167 @@ macro_rules! 
spark_aggregate_fixed { // Aggregate functions // --------------------------------------------------------------------------- -spark_aggregate_fixed!(avg, fn_aggregate::avg, arg1); -spark_aggregate_fixed!(try_sum, fn_aggregate::try_sum, arg1); -spark_aggregate_fixed!(collect_list, fn_aggregate::collect_list, arg1); -spark_aggregate_fixed!(collect_set, fn_aggregate::collect_set, arg1); +spark_aggregate!(avg, arg1); +spark_aggregate!(try_sum, arg1); +spark_aggregate!(collect_list, arg1); +spark_aggregate!(collect_set, arg1); // --------------------------------------------------------------------------- // Array functions // --------------------------------------------------------------------------- -spark_udf_fixed!( - array_contains, - fn_array::spark_array_contains, - array, - element -); -spark_udf_vec!(array, fn_array::array); -spark_udf_fixed!(shuffle, fn_array::shuffle, arg1); -spark_udf_fixed!(array_repeat, fn_array::array_repeat, element, count); -spark_udf_fixed!(slice, fn_array::slice, arg_array, start, length); +// Upstream factory is `spark_array_contains`; expose under the Spark SQL +// name `array_contains` on the Python side. 
+#[pyfunction] +fn array_contains(arr: PyExpr, element: PyExpr) -> PyExpr { + expr_fn::spark_array_contains(arr.into(), element.into()).into() +} +spark_udf_vec!(array, udf::array::array); +spark_expr_fn!(shuffle, arg1); +spark_expr_fn!(array_repeat, element count); +spark_expr_fn!(slice, arr start length); // --------------------------------------------------------------------------- // Bitmap functions // --------------------------------------------------------------------------- -spark_udf_fixed!(bitmap_count, fn_bitmap::bitmap_count, arg1); -spark_udf_fixed!(bitmap_bit_position, fn_bitmap::bitmap_bit_position, arg1); -spark_udf_fixed!(bitmap_bucket_number, fn_bitmap::bitmap_bucket_number, arg1); +spark_expr_fn!(bitmap_count, arg1); +spark_expr_fn!(bitmap_bit_position, arg1); +spark_expr_fn!(bitmap_bucket_number, arg1); // --------------------------------------------------------------------------- // Bitwise functions // --------------------------------------------------------------------------- -spark_udf_fixed!(bit_get, fn_bitwise::bit_get, col, pos); -spark_udf_fixed!(bit_count, fn_bitwise::bit_count, col); -spark_udf_fixed!(bitwise_not, fn_bitwise::bitwise_not, col); -spark_udf_fixed!(shiftleft, fn_bitwise::shiftleft, value, shift); -spark_udf_fixed!(shiftright, fn_bitwise::shiftright, value, shift); -spark_udf_fixed!( - shiftrightunsigned, - fn_bitwise::shiftrightunsigned, - value, - shift -); +spark_expr_fn!(bit_get, col pos); +spark_expr_fn!(bit_count, col); +spark_expr_fn!(bitwise_not, col); +spark_expr_fn!(shiftleft, value shift); +spark_expr_fn!(shiftright, value shift); +spark_expr_fn!(shiftrightunsigned, value shift); // --------------------------------------------------------------------------- -// Collection functions +// Collection / Conditional / Conversion // --------------------------------------------------------------------------- -spark_udf_fixed!(size, fn_collection::size, arg1); +spark_expr_fn!(size, arg1); -// 
--------------------------------------------------------------------------- -// Conditional functions -// --------------------------------------------------------------------------- - -// Python keyword `if` → exposed as `if_`. -spark_udf_fixed!(if_, fn_conditional::r#if, condition, if_true, if_false); - -// --------------------------------------------------------------------------- -// Conversion functions -// --------------------------------------------------------------------------- - -// `spark_cast` requires session ConfigOptions in upstream; use the -// crate-provided `expr_fn` helper which applies defaults. +// Python keyword `if` → exposed as `if_`. Upstream Rust ident is `r#if`. #[pyfunction] -fn spark_cast(arg1: PyExpr, arg2: PyExpr) -> PyExpr { - fn_conversion::expr_fn::spark_cast(arg1.into(), arg2.into()).into() +fn if_(condition: PyExpr, if_true: PyExpr, if_false: PyExpr) -> PyExpr { + expr_fn::r#if(condition.into(), if_true.into(), if_false.into()).into() } +// `spark_cast` is config-injected by the upstream `expr_fn` helper; defaults +// applied automatically there. 
+spark_expr_fn!(spark_cast, arg1 arg2); + // --------------------------------------------------------------------------- // Datetime functions // --------------------------------------------------------------------------- -spark_udf_fixed!(add_months, fn_datetime::add_months, start_date, num_months); -spark_udf_fixed!(date_add, fn_datetime::date_add, start_date, days); -spark_udf_fixed!(date_sub, fn_datetime::date_sub, start_date, days); -spark_udf_fixed!(hour, fn_datetime::hour, arg1); -spark_udf_fixed!(minute, fn_datetime::minute, arg1); -spark_udf_fixed!(second, fn_datetime::second, arg1); -spark_udf_fixed!(last_day, fn_datetime::last_day, arg1); -spark_udf_fixed!( - make_dt_interval, - fn_datetime::make_dt_interval, - days, - hours, - mins, - secs -); -spark_udf_fixed!( - make_interval, - fn_datetime::make_interval, - years, - months, - weeks, - days, - hours, - mins, - secs -); -spark_udf_fixed!(next_day, fn_datetime::next_day, start_date, day_of_week); -spark_udf_fixed!(date_diff, fn_datetime::date_diff, end_date, start_date); -spark_udf_fixed!(date_trunc, fn_datetime::date_trunc, fmt, ts); -spark_udf_fixed!(time_trunc, fn_datetime::time_trunc, fmt, t); -spark_udf_fixed!(trunc, fn_datetime::trunc, dt, fmt); -spark_udf_fixed!(date_part, fn_datetime::date_part, field, source); -spark_udf_fixed!(from_utc_timestamp, fn_datetime::from_utc_timestamp, ts, tz); -spark_udf_fixed!(to_utc_timestamp, fn_datetime::to_utc_timestamp, ts, tz); -spark_udf_fixed!(unix_date, fn_datetime::unix_date, dt); -spark_udf_fixed!(unix_micros, fn_datetime::unix_micros, ts); -spark_udf_fixed!(unix_millis, fn_datetime::unix_millis, ts); -spark_udf_fixed!(unix_seconds, fn_datetime::unix_seconds, ts); +spark_expr_fn!(add_months, start_date num_months); +spark_expr_fn!(date_add, start_date days); +spark_expr_fn!(date_sub, start_date days); +spark_expr_fn!(hour, arg1); +spark_expr_fn!(minute, arg1); +spark_expr_fn!(second, arg1); +spark_expr_fn!(last_day, arg1); 
+spark_expr_fn!(make_dt_interval, days hours mins secs); +spark_expr_fn!(make_interval, years months weeks days hours mins secs); +spark_expr_fn!(next_day, start_date day_of_week); +spark_expr_fn!(date_diff, end_date start_date); +spark_expr_fn!(date_trunc, fmt ts); +spark_expr_fn!(time_trunc, fmt t); +spark_expr_fn!(trunc, dt fmt); +spark_expr_fn!(date_part, field source); +spark_expr_fn!(from_utc_timestamp, ts tz); +spark_expr_fn!(to_utc_timestamp, ts tz); +spark_expr_fn!(unix_date, dt); +spark_expr_fn!(unix_micros, ts); +spark_expr_fn!(unix_millis, ts); +spark_expr_fn!(unix_seconds, ts); // --------------------------------------------------------------------------- // Hash functions // --------------------------------------------------------------------------- -spark_udf_fixed!(crc32, fn_hash::crc32, arg1); -spark_udf_fixed!(sha1, fn_hash::sha1, arg1); -spark_udf_fixed!(sha2, fn_hash::sha2, arg1, bit_length); -spark_udf_vec!(xxhash64, fn_hash::xxhash64); +spark_expr_fn!(crc32, arg1); +spark_expr_fn!(sha1, arg1); +spark_expr_fn!(sha2, arg1 bit_length); +spark_udf_vec!(xxhash64, udf::hash::xxhash64); // --------------------------------------------------------------------------- // JSON functions // --------------------------------------------------------------------------- -spark_udf_vec!(json_tuple, fn_json::json_tuple); +spark_udf_vec!(json_tuple, udf::json::json_tuple); // --------------------------------------------------------------------------- // Map functions // --------------------------------------------------------------------------- -spark_udf_fixed!(map_from_arrays, fn_map::map_from_arrays, keys, values); -spark_udf_fixed!(map_from_entries, fn_map::map_from_entries, arg1); -spark_udf_fixed!( - str_to_map, - fn_map::str_to_map, - text, - pair_delim, - key_value_delim -); +spark_expr_fn!(map_from_arrays, keys values); +spark_expr_fn!(map_from_entries, arg1); +spark_expr_fn!(str_to_map, text pair_delim key_value_delim); // 
--------------------------------------------------------------------------- // Math functions // --------------------------------------------------------------------------- -spark_udf_fixed!(abs, fn_math::abs, arg1); -spark_udf_fixed!(ceil, fn_math::ceil, arg1); -spark_udf_fixed!(expm1, fn_math::expm1, arg1); -spark_udf_fixed!(factorial, fn_math::factorial, arg1); -spark_udf_fixed!(floor, fn_math::floor, arg1); -spark_udf_fixed!(hex, fn_math::hex, arg1); -spark_udf_fixed!(modulus, fn_math::modulus, dividend, divisor); -spark_udf_fixed!(pmod, fn_math::pmod, dividend, divisor); -spark_udf_fixed!(rint, fn_math::rint, arg1); -spark_udf_fixed!(round, fn_math::round, value, scale); -spark_udf_fixed!(unhex, fn_math::unhex, arg1); -spark_udf_fixed!( - width_bucket, - fn_math::width_bucket, - value, - min_value, - max_value, - num_buckets -); -spark_udf_fixed!(csc, fn_math::csc, arg1); -spark_udf_fixed!(sec, fn_math::sec, arg1); -spark_udf_fixed!(negative, fn_math::negative, arg1); -spark_udf_fixed!(bin, fn_math::bin, arg1); +spark_expr_fn!(abs, arg1); +spark_expr_fn!(ceil, arg1); +spark_expr_fn!(expm1, arg1); +spark_expr_fn!(factorial, arg1); +spark_expr_fn!(floor, arg1); +spark_expr_fn!(hex, arg1); +spark_expr_fn!(modulus, dividend divisor); +spark_expr_fn!(pmod, dividend divisor); +spark_expr_fn!(rint, arg1); +spark_expr_fn!(round, value scale); +spark_expr_fn!(unhex, arg1); +spark_expr_fn!(width_bucket, value min_value max_value num_buckets); +spark_expr_fn!(csc, arg1); +spark_expr_fn!(sec, arg1); +spark_expr_fn!(negative, arg1); +spark_expr_fn!(bin, arg1); // --------------------------------------------------------------------------- // String functions // --------------------------------------------------------------------------- -spark_udf_fixed!(ascii, fn_string::ascii, arg1); -spark_udf_fixed!(base64, fn_string::base64, bin_input); +spark_expr_fn!(ascii, arg1); +spark_expr_fn!(base64, bin_input); // `char` collides with the Rust primitive type in macro hygiene; 
rename the // Rust ident and re-expose under the original name to Python. #[pyfunction] #[pyo3(name = "char")] fn char_fn(arg1: PyExpr) -> PyExpr { - let udf = fn_string::char(); - Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![arg1.into()])).into() + expr_fn::char(arg1.into()).into() } -spark_udf_vec!(concat, fn_string::concat); -spark_udf_vec!(elt, fn_string::elt); -spark_udf_fixed!(ilike, fn_string::ilike, str, pattern); -spark_udf_fixed!(length, fn_string::length, arg1); -spark_udf_fixed!(like, fn_string::like, str, pattern); -spark_udf_fixed!(luhn_check, fn_string::luhn_check, arg1); -spark_udf_vec!(format_string, fn_string::format_string); -spark_udf_fixed!(space, fn_string::space, arg1); -spark_udf_fixed!(substring, fn_string::substring, str, pos, length); -spark_udf_fixed!(unbase64, fn_string::unbase64, str); -spark_udf_fixed!(soundex, fn_string::soundex, str); -spark_udf_fixed!(is_valid_utf8, fn_string::is_valid_utf8, str); -spark_udf_fixed!(make_valid_utf8, fn_string::make_valid_utf8, str); +spark_udf_vec!(concat, udf::string::concat); +spark_udf_vec!(elt, udf::string::elt); +spark_expr_fn!(ilike, str pattern); +spark_expr_fn!(length, arg1); +spark_expr_fn!(like, str pattern); +spark_expr_fn!(luhn_check, arg1); +spark_udf_vec!(format_string, udf::string::format_string); +spark_expr_fn!(space, arg1); +spark_expr_fn!(substring, str pos length); +spark_expr_fn!(unbase64, str); +spark_expr_fn!(soundex, str); +spark_expr_fn!(is_valid_utf8, str); +spark_expr_fn!(make_valid_utf8, str); // --------------------------------------------------------------------------- // URL functions // --------------------------------------------------------------------------- -spark_udf_vec!(parse_url, fn_url::parse_url); -spark_udf_vec!(try_parse_url, fn_url::try_parse_url); -spark_udf_vec!(url_decode, fn_url::url_decode); -spark_udf_vec!(try_url_decode, fn_url::try_url_decode); -spark_udf_vec!(url_encode, fn_url::url_encode); +spark_udf_vec!(parse_url, 
udf::url::parse_url); +spark_udf_vec!(try_parse_url, udf::url::try_parse_url); +spark_udf_vec!(url_decode, udf::url::url_decode); +spark_udf_vec!(try_url_decode, udf::url::try_url_decode); +spark_udf_vec!(url_encode, udf::url::url_encode); // --------------------------------------------------------------------------- // Module init From 45f0b68c14488a295c123b39fb49e83caa4604ed Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:35:54 -0400 Subject: [PATCH 04/14] docs: clarify spark functions cover DataFrame API too The intro wording implied "SQL functions" only; the same wrappers are the primary entry point for the DataFrame API as well. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../user-guide/common-operations/spark-functions.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/common-operations/spark-functions.rst b/docs/source/user-guide/common-operations/spark-functions.rst index 2c1053ef7..44fccfe8a 100644 --- a/docs/source/user-guide/common-operations/spark-functions.rst +++ b/docs/source/user-guide/common-operations/spark-functions.rst @@ -18,11 +18,12 @@ Spark-Compatible Functions ========================== -DataFusion ships Spark-compatible versions of a wide set of SQL functions +DataFusion ships Spark-compatible versions of a wide set of functions (string, math, datetime, hash, array, aggregate) through the upstream ``datafusion-spark`` crate. ``datafusion-python`` exposes these under -``datafusion.functions.spark`` for the DataFrame API and via -:py:meth:`~datafusion.SessionContext.enable_spark_functions` for SQL. +``datafusion.functions.spark`` for use from the DataFrame API, and via +:py:meth:`~datafusion.SessionContext.enable_spark_functions` for use from +SQL. Why a Separate Namespace? 
------------------------- From 7a7f6b11308e03015c592ae91aedb1b42e320816 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:37:07 -0400 Subject: [PATCH 05/14] docs: rewrite spark DataFrame intro for users Replace API-speak ("Import the submodule", "Returned values are Expr instances that compose") with a concrete description of where users can actually drop these calls. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../source/user-guide/common-operations/spark-functions.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/common-operations/spark-functions.rst b/docs/source/user-guide/common-operations/spark-functions.rst index 44fccfe8a..ff707ae92 100644 --- a/docs/source/user-guide/common-operations/spark-functions.rst +++ b/docs/source/user-guide/common-operations/spark-functions.rst @@ -45,9 +45,9 @@ implementation to call by which module you import from. DataFrame API ------------- -Import the submodule and call functions directly. Returned values are -:py:class:`~datafusion.expr.Expr` instances that compose with the rest of -the DataFrame API. +Import ``spark`` and use it like any other functions module. The Spark +functions can go anywhere you'd put a DataFusion expression — inside +``select``, ``filter``, ``with_column``, ``aggregate``, and so on. .. code-block:: python From 1e246e3d403248b0d1cc81d90cf5468e17c98b2a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:39:39 -0400 Subject: [PATCH 06/14] docs: defer spark function list to API reference Hand-maintained category list would drift from the actual module as upstream `datafusion-spark` adds/removes functions. Replace with a pointer to the AutoAPI-generated reference, which renders from the module itself. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../common-operations/spark-functions.rst | 42 +++---------------- 1 file changed, 6 insertions(+), 36 deletions(-) diff --git a/docs/source/user-guide/common-operations/spark-functions.rst b/docs/source/user-guide/common-operations/spark-functions.rst index ff707ae92..cb056bfc5 100644 --- a/docs/source/user-guide/common-operations/spark-functions.rst +++ b/docs/source/user-guide/common-operations/spark-functions.rst @@ -84,39 +84,9 @@ of the same name. The override applies for the lifetime of the session. To call DataFusion's built-in versions afterwards, create a fresh ``SessionContext``. -Function Categories -------------------- - -The full list is available in the -:py:mod:`API reference <datafusion.functions.spark>`. Highlights by -category: - -- **String**: ``ascii``, ``base64``, ``char``, ``concat``, ``elt``, - ``format_string``, ``ilike``, ``is_valid_utf8``, ``length``, ``like``, - ``luhn_check``, ``make_valid_utf8``, ``soundex``, ``space``, - ``substring``, ``unbase64``. -- **Math**: ``abs``, ``bin``, ``ceil``, ``csc``, ``expm1``, ``factorial``, - ``floor``, ``hex``, ``modulus``, ``negative``, ``pmod``, ``rint``, - ``round``, ``sec``, ``unhex``, ``width_bucket``. -- **Datetime**: ``add_months``, ``date_add``, ``date_diff``, ``date_part``, - ``date_sub``, ``date_trunc``, ``from_utc_timestamp``, ``hour``, - ``last_day``, ``make_dt_interval``, ``make_interval``, ``minute``, - ``next_day``, ``second``, ``time_trunc``, ``to_utc_timestamp``, - ``trunc``, ``unix_date``, ``unix_micros``, ``unix_millis``, - ``unix_seconds``. -- **Hash**: ``crc32``, ``sha1``, ``sha2``, ``xxhash64``. -- **Array**: ``array``, ``array_contains``, ``array_repeat``, ``shuffle``, - ``slice``. -- **Aggregate**: ``avg``, ``collect_list``, ``collect_set``, ``try_sum``. -- **Bitwise**: ``bit_count``, ``bit_get``, ``bitwise_not``, ``shiftleft``, - ``shiftright``, ``shiftrightunsigned``. -- **Bitmap**: ``bitmap_bit_position``, ``bitmap_bucket_number``, - ``bitmap_count``.
-- **Collection**: ``size``. -- **Conditional**: ``if_`` (exposed under that name because ``if`` is a - Python keyword). -- **Conversion**: ``spark_cast``. -- **JSON**: ``json_tuple``. -- **Map**: ``map_from_arrays``, ``map_from_entries``, ``str_to_map``. -- **URL**: ``parse_url``, ``try_parse_url``, ``url_decode``, - ``try_url_decode``, ``url_encode``. +Function Reference +------------------ + +The full, up-to-date list of available Spark functions — with signatures +and per-function docstrings — lives in the +:py:mod:`datafusion.functions.spark` API reference. From f8f9d7a9a835fa649d470eace48d4bebd6a27a9a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 06:53:41 -0400 Subject: [PATCH 07/14] test: replace spark doctest skips with verified examples 38 wrappers carried `# doctest: +SKIP` because outputs weren't verified at authoring time. Run each with concrete inputs, capture actual outputs, and inline the values so the doctests execute and stay correct. Covers datetime (20), URL (5), bitmap (3), map (3), and remaining hash, JSON, math, string, conversion, and format_string cases. Net new doctest coverage: 65 examples now run that were skipped before; total skipped across the suite drops from 53 to 12. Co-Authored-By: Claude Opus 4.7 (1M context) --- python/datafusion/functions/spark.py | 422 +++++++++++++++++++++++---- 1 file changed, 359 insertions(+), 63 deletions(-) diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py index c2badd811..2368d79aa 100644 --- a/python/datafusion/functions/spark.py +++ b/python/datafusion/functions/spark.py @@ -261,10 +261,15 @@ def slice(array: Expr, start: Expr, length: Expr) -> Expr: def bitmap_count(arg: Expr) -> Expr: - """Spark ``bitmap_count``: number of set bits in a bitmap. + r"""Spark ``bitmap_count``: number of set bits in a bitmap. 
Examples: - >>> dfn.functions.spark.bitmap_count(dfn.col("b")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bitmap_count(dfn.lit(b"\xff")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 8 """ return Expr(_f.bitmap_count(arg.expr)) @@ -273,7 +278,12 @@ def bitmap_bit_position(arg: Expr) -> Expr: """Spark ``bitmap_bit_position``: bit position for a child expression. Examples: - >>> dfn.functions.spark.bitmap_bit_position(dfn.col("b")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bitmap_bit_position(dfn.lit(15)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 14 """ return Expr(_f.bitmap_bit_position(arg.expr)) @@ -282,7 +292,12 @@ def bitmap_bucket_number(arg: Expr) -> Expr: """Spark ``bitmap_bucket_number``: bucket number for a child expression. Examples: - >>> dfn.functions.spark.bitmap_bucket_number(dfn.col("b")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bitmap_bucket_number(dfn.lit(15)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 """ return Expr(_f.bitmap_bucket_number(arg.expr)) @@ -423,10 +438,18 @@ def spark_cast(arg: Expr, type_str: Expr) -> Expr: Uses Spark cast semantics (e.g. overflow returns NULL, not error). + Currently only supports casting numeric values to ``"timestamp"``. + Examples: - >>> dfn.functions.spark.spark_cast( # doctest: +SKIP - ... dfn.col("x"), dfn.lit("string") + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.spark_cast( + ... dfn.lit(1579098645), dfn.lit("timestamp") + ... ).alias("v") ... 
) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45, tzinfo=) """ return Expr(_f.spark_cast(arg.expr, type_str.expr)) @@ -440,7 +463,18 @@ def add_months(start_date: Expr, num_months: Expr) -> Expr: """Spark ``add_months``: date + N months. Examples: - >>> dfn.functions.spark.add_months(dfn.col("d"), dfn.lit(1)) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.add_months( + ... d, dfn.lit(pa.scalar(2, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 3, 15) """ return Expr(_f.add_months(start_date.expr, num_months.expr)) @@ -449,7 +483,18 @@ def date_add(start_date: Expr, days: Expr) -> Expr: """Spark ``date_add``: date + N days. Examples: - >>> dfn.functions.spark.date_add(dfn.col("d"), dfn.lit(7)) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.date_add( + ... d, dfn.lit(pa.scalar(5, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 20) """ return Expr(_f.date_add(start_date.expr, days.expr)) @@ -458,7 +503,18 @@ def date_sub(start_date: Expr, days: Expr) -> Expr: """Spark ``date_sub``: date - N days. Examples: - >>> dfn.functions.spark.date_sub(dfn.col("d"), dfn.lit(7)) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.date_sub( + ... d, dfn.lit(pa.scalar(5, type=pa.int32())) + ... 
).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 10) """ return Expr(_f.date_sub(start_date.expr, days.expr)) @@ -467,7 +523,16 @@ def hour(arg: Expr) -> Expr: """Spark ``hour``: extract hour component of a timestamp. Examples: - >>> dfn.functions.spark.hour(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.hour(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 14 """ return Expr(_f.hour(arg.expr)) @@ -476,7 +541,16 @@ def minute(arg: Expr) -> Expr: """Spark ``minute``: extract minute component of a timestamp. Examples: - >>> dfn.functions.spark.minute(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.minute(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 30 """ return Expr(_f.minute(arg.expr)) @@ -485,7 +559,16 @@ def second(arg: Expr) -> Expr: """Spark ``second``: extract second component of a timestamp. Examples: - >>> dfn.functions.spark.second(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.second(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 45 """ return Expr(_f.second(arg.expr)) @@ -494,7 +577,14 @@ def last_day(arg: Expr) -> Expr: """Spark ``last_day``: last day of the month containing the date. 
Examples: - >>> dfn.functions.spark.last_day(dfn.col("d")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.last_day(d).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 31) """ return Expr(_f.last_day(arg.expr)) @@ -503,8 +593,17 @@ def make_dt_interval(days: Expr, hours: Expr, mins: Expr, secs: Expr) -> Expr: """Spark ``make_dt_interval``: day-time interval from components. Examples: - >>> dfn.functions.spark.make_dt_interval( - ... dfn.lit(1), dfn.lit(2), dfn.lit(3), dfn.lit(4)) # doctest: +SKIP + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) + >>> r = df.select( + ... dfn.functions.spark.make_dt_interval( + ... i32(1), i32(2), i32(3), dfn.lit(4.5) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.timedelta(days=1, seconds=7384, microseconds=500000) """ return Expr(_f.make_dt_interval(days.expr, hours.expr, mins.expr, secs.expr)) @@ -521,9 +620,18 @@ def make_interval( """Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts. Examples: - >>> dfn.functions.spark.make_interval( - ... dfn.lit(1), dfn.lit(0), dfn.lit(0), - ... dfn.lit(0), dfn.lit(0), dfn.lit(0), dfn.lit(0)) # doctest: +SKIP + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) + >>> r = df.select( + ... dfn.functions.spark.make_interval( + ... i32(1), i32(0), i32(0), i32(0), + ... i32(0), i32(0), dfn.lit(0.0) + ... ).alias("v") + ... 
) + >>> r.collect_column("v")[0].as_py().months + 12 """ return Expr( _f.make_interval( @@ -542,9 +650,15 @@ def next_day(start_date: Expr, day_of_week: Expr) -> Expr: """Spark ``next_day``: first date after ``start_date`` named ``day_of_week``. Examples: - >>> dfn.functions.spark.next_day( # doctest: +SKIP - ... dfn.col("d"), dfn.lit("Sunday") - ... ) + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.next_day(d, dfn.lit("Mon")).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 20) """ return Expr(_f.next_day(start_date.expr, day_of_week.expr)) @@ -553,7 +667,15 @@ def date_diff(end_date: Expr, start_date: Expr) -> Expr: """Spark ``date_diff``: number of days from ``start_date`` to ``end_date``. Examples: - >>> dfn.functions.spark.date_diff(dfn.col("e"), dfn.col("s")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> end = dfn.lit(pa.scalar(date(2020, 1, 20), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.date_diff(end, d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 """ return Expr(_f.date_diff(end_date.expr, start_date.expr)) @@ -562,9 +684,17 @@ def date_trunc(fmt: Expr, ts: Expr) -> Expr: """Spark ``date_trunc``: truncate timestamp to unit ``fmt``. Examples: - >>> dfn.functions.spark.date_trunc( # doctest: +SKIP - ... dfn.lit("year"), dfn.col("ts") - ... ) + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... 
dfn.functions.spark.date_trunc(dfn.lit("month"), ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 1, 0, 0) """ return Expr(_f.date_trunc(fmt.expr, ts.expr)) @@ -573,9 +703,15 @@ def time_trunc(fmt: Expr, t: Expr) -> Expr: """Spark ``time_trunc``: truncate time value to unit ``fmt``. Examples: - >>> dfn.functions.spark.time_trunc( # doctest: +SKIP - ... dfn.lit("hour"), dfn.col("t") - ... ) + >>> import pyarrow as pa + >>> from datetime import time + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> t = dfn.lit(pa.scalar(time(14, 30, 45), type=pa.time64('us'))) + >>> r = df.select( + ... dfn.functions.spark.time_trunc(dfn.lit("hour"), t).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.time(14, 0) """ return Expr(_f.time_trunc(fmt.expr, t.expr)) @@ -584,7 +720,15 @@ def trunc(dt: Expr, fmt: Expr) -> Expr: """Spark ``trunc``: truncate date to unit ``fmt``. Examples: - >>> dfn.functions.spark.trunc(dfn.col("d"), dfn.lit("YEAR")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.trunc(d, dfn.lit("YEAR")).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 1) """ return Expr(_f.trunc(dt.expr, fmt.expr)) @@ -593,9 +737,15 @@ def date_part(field: Expr, source: Expr) -> Expr: """Spark ``date_part``: extract ``field`` from a date/time/timestamp. Examples: - >>> dfn.functions.spark.date_part( # doctest: +SKIP - ... dfn.lit("year"), dfn.col("d") - ... ) + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... 
dfn.functions.spark.date_part(dfn.lit("year"), d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2020 """ return Expr(_f.date_part(field.expr, source.expr)) @@ -604,9 +754,20 @@ def from_utc_timestamp(ts: Expr, tz: Expr) -> Expr: """Spark ``from_utc_timestamp``: interpret ``ts`` as UTC, convert to ``tz``. Examples: - >>> dfn.functions.spark.from_utc_timestamp( # doctest: +SKIP - ... dfn.col("ts"), dfn.lit("PST") + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... dfn.functions.spark.from_utc_timestamp( + ... ts, dfn.lit("UTC") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45) """ return Expr(_f.from_utc_timestamp(ts.expr, tz.expr)) @@ -615,9 +776,20 @@ def to_utc_timestamp(ts: Expr, tz: Expr) -> Expr: """Spark ``to_utc_timestamp``: interpret ``ts`` as ``tz``, convert to UTC. Examples: - >>> dfn.functions.spark.to_utc_timestamp( # doctest: +SKIP - ... dfn.col("ts"), dfn.lit("PST") + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... dfn.functions.spark.to_utc_timestamp( + ... ts, dfn.lit("UTC") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45) """ return Expr(_f.to_utc_timestamp(ts.expr, tz.expr)) @@ -626,7 +798,14 @@ def unix_date(dt: Expr) -> Expr: """Spark ``unix_date``: days since 1970-01-01 for ``dt``. 
Examples: - >>> dfn.functions.spark.unix_date(dfn.col("d")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.unix_date(d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 18276 """ return Expr(_f.unix_date(dt.expr)) @@ -635,7 +814,16 @@ def unix_micros(ts: Expr) -> Expr: """Spark ``unix_micros``: microseconds since epoch for ``ts``. Examples: - >>> dfn.functions.spark.unix_micros(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_micros(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645000000 """ return Expr(_f.unix_micros(ts.expr)) @@ -644,7 +832,16 @@ def unix_millis(ts: Expr) -> Expr: """Spark ``unix_millis``: milliseconds since epoch for ``ts``. Examples: - >>> dfn.functions.spark.unix_millis(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_millis(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645000 """ return Expr(_f.unix_millis(ts.expr)) @@ -653,7 +850,16 @@ def unix_seconds(ts: Expr) -> Expr: """Spark ``unix_seconds``: seconds since epoch for ``ts``. Examples: - >>> dfn.functions.spark.unix_seconds(dfn.col("ts")) # doctest: +SKIP + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... 
pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_seconds(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645 """ return Expr(_f.unix_seconds(ts.expr)) @@ -707,7 +913,12 @@ def xxhash64(*args: Expr) -> Expr: """Spark ``xxhash64``: 64-bit xxHash of the arguments. Examples: - >>> dfn.functions.spark.xxhash64(dfn.col("s")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.xxhash64(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + -4367754540140381902 """ return Expr(_f.xxhash64(*[a.expr for a in args])) @@ -721,9 +932,15 @@ def json_tuple(*args: Expr) -> Expr: """Spark ``json_tuple``: extract top-level fields from a JSON string. Examples: - >>> dfn.functions.spark.json_tuple( # doctest: +SKIP - ... dfn.col("j"), dfn.lit("a"), dfn.lit("b") + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.json_tuple( + ... dfn.lit('{"a":1,"b":"x"}'), dfn.lit("a"), dfn.lit("b") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + {'c0': '1', 'c1': 'x'} """ return Expr(_f.json_tuple(*[a.expr for a in args])) @@ -737,9 +954,14 @@ def map_from_arrays(keys: Expr, values: Expr) -> Expr: """Spark ``map_from_arrays``: build a map from parallel key/value arrays. Examples: - >>> dfn.functions.spark.map_from_arrays( # doctest: +SKIP - ... dfn.col("k"), dfn.col("v") - ... ) + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> keys = dfn.functions.spark.array(dfn.lit("a"), dfn.lit("b")) + >>> vals = dfn.functions.spark.array(dfn.lit(1), dfn.lit(2)) + >>> r = df.select( + ... 
dfn.functions.spark.map_from_arrays(keys, vals).alias("v")) + >>> r.collect_column("v")[0].as_py() + [('a', 1), ('b', 2)] """ return Expr(_f.map_from_arrays(keys.expr, values.expr)) @@ -748,7 +970,16 @@ def map_from_entries(arg: Expr) -> Expr: """Spark ``map_from_entries``: build a map from an array of key/value structs. Examples: - >>> dfn.functions.spark.map_from_entries(dfn.col("entries")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.map_from_arrays( + ... dfn.functions.spark.array(dfn.lit("a")), + ... dfn.functions.spark.array(dfn.lit(1)), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [('a', 1)] """ return Expr(_f.map_from_entries(arg.expr)) @@ -757,8 +988,15 @@ def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr: """Spark ``str_to_map``: split text into key/value pairs using delimiters. Examples: - >>> dfn.functions.spark.str_to_map( - ... dfn.col("s"), dfn.lit(","), dfn.lit(":")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.str_to_map( + ... dfn.lit("a:1,b:2"), dfn.lit(","), dfn.lit(":") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [('a', '1'), ('b', '2')] """ return Expr(_f.str_to_map(text.expr, pair_delim.expr, key_value_delim.expr)) @@ -944,7 +1182,11 @@ def csc(arg: Expr) -> Expr: """Spark ``csc``: cosecant. Examples: - >>> dfn.functions.spark.csc(dfn.lit(1.0)) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.csc(dfn.lit(1.5708)).alias("v")) + >>> f"{r.collect_column('v')[0].as_py():.4f}" + '1.0000' """ return Expr(_f.csc(arg.expr)) @@ -953,7 +1195,11 @@ def sec(arg: Expr) -> Expr: """Spark ``sec``: secant. 
Examples: - >>> dfn.functions.spark.sec(dfn.lit(0.0)) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.sec(dfn.lit(0.0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1.0 """ return Expr(_f.sec(arg.expr)) @@ -1006,7 +1252,11 @@ def base64(bin_input: Expr) -> Expr: """Spark ``base64``: encode binary as a base64 string. Examples: - >>> dfn.functions.spark.base64(dfn.col("b")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.base64(dfn.lit(b"hi")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'aGk=' """ return Expr(_f.base64(bin_input.expr)) @@ -1100,7 +1350,15 @@ def luhn_check(arg: Expr) -> Expr: """Spark ``luhn_check``: true if the digit string passes the Luhn check. Examples: - >>> dfn.functions.spark.luhn_check(dfn.col("card")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.luhn_check( + ... dfn.lit("4111111111111111") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + True """ return Expr(_f.luhn_check(arg.expr)) @@ -1111,9 +1369,15 @@ def format_string(*args: Expr) -> Expr: First arg is the format, remaining args are values to substitute. Examples: - >>> dfn.functions.spark.format_string( # doctest: +SKIP - ... dfn.lit("%d/%d"), dfn.lit(1), dfn.lit(2) + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.format_string( + ... dfn.lit("%d-%s"), dfn.lit(42), dfn.lit("hi") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + '42-hi' """ return Expr(_f.format_string(*[a.expr for a in args])) @@ -1159,7 +1423,12 @@ def unbase64(arg: Expr) -> Expr: """Spark ``unbase64``: decode a base64 string to binary. 
Examples: - >>> dfn.functions.spark.unbase64(dfn.col("s")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.unbase64(dfn.lit("aGk=")).alias("v")) + >>> r.collect_column("v")[0].as_py() + b'hi' """ return Expr(_f.unbase64(arg.expr)) @@ -1214,9 +1483,15 @@ def parse_url(*args: Expr) -> Expr: """Spark ``parse_url``: extract a part from a URL; errors on invalid URLs. Examples: - >>> dfn.functions.spark.parse_url( # doctest: +SKIP - ... dfn.col("u"), dfn.lit("HOST") + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.parse_url( + ... dfn.lit("http://example.com/path?q=1"), dfn.lit("HOST") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + 'example.com' """ return Expr(_f.parse_url(*[a.expr for a in args])) @@ -1225,9 +1500,15 @@ def try_parse_url(*args: Expr) -> Expr: """Spark ``try_parse_url``: like ``parse_url`` but returns NULL on invalid URLs. Examples: - >>> dfn.functions.spark.try_parse_url( # doctest: +SKIP - ... dfn.col("u"), dfn.lit("HOST") + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.try_parse_url( + ... dfn.lit("http://example.com/"), dfn.lit("HOST") + ... ).alias("v") ... ) + >>> r.collect_column("v")[0].as_py() + 'example.com' """ return Expr(_f.try_parse_url(*[a.expr for a in args])) @@ -1236,7 +1517,12 @@ def url_decode(*args: Expr) -> Expr: """Spark ``url_decode``: decode an application/x-www-form-urlencoded string. Examples: - >>> dfn.functions.spark.url_decode(dfn.col("s")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.url_decode(dfn.lit("a%20b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a b' """ return Expr(_f.url_decode(*[a.expr for a in args])) @@ -1245,7 +1531,12 @@ def try_url_decode(*args: Expr) -> Expr: """Spark ``try_url_decode``: like ``url_decode``; returns NULL on invalid input. Examples: - >>> dfn.functions.spark.try_url_decode(dfn.col("s")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.try_url_decode(dfn.lit("a%20b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a b' """ return Expr(_f.try_url_decode(*[a.expr for a in args])) @@ -1254,7 +1545,12 @@ def url_encode(*args: Expr) -> Expr: """Spark ``url_encode``: encode a string in application/x-www-form-urlencoded. Examples: - >>> dfn.functions.spark.url_encode(dfn.col("s")) # doctest: +SKIP + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.url_encode(dfn.lit("a b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a+b' """ return Expr(_f.url_encode(*[a.expr for a in args])) From ea256861d03cf1f8a3dcf031ae89337a6935146b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 09:32:08 -0400 Subject: [PATCH 08/14] refactor(spark): rename function params to match pyspark MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Align positional parameter names in `functions.spark` with pyspark.sql.functions: - aggregate first positional → `col` (avg, try_sum, collect_list, collect_set) - unary `arg` → `col` across math/string/byte/datetime helpers - multi-arg renames: array_contains (col, value), array (*cols), shuffle (col), array_repeat (col, count), slice (x, start, length), shiftleft/right/rightunsigned (col, numBits), add_months (start, months), date_add/sub (start, days), date_diff (end, start), date_trunc (format, timestamp), time_trunc (unit, time), trunc (date, format), 
next_day (date, dayOfWeek), from/to_utc_timestamp (timestamp, tz), sha2 (col, numBits), xxhash64 (*cols), map_from_arrays (col1, col2), width_bucket (v, min, max, numBucket), substring (str, pos, len), concat (*cols), elt (*inputs), is_valid_utf8/make_valid_utf8 (str) Bodies updated to reference the new names; positional callers unaffected. This finishes Category 1 / Category 4 (spark-side BOTH-bucket) renames from PYSPARK_ALIGNMENT_PLAN.md PR 1. Co-Authored-By: Claude Opus 4.7 (1M context) --- python/datafusion/functions/spark.py | 273 ++++++++++++++------------- 1 file changed, 137 insertions(+), 136 deletions(-) diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py index 2368d79aa..09467b981 100644 --- a/python/datafusion/functions/spark.py +++ b/python/datafusion/functions/spark.py @@ -52,7 +52,7 @@ def _filter_raw(filter: Expr | None) -> Any: def avg( - expression: Expr, + col: Expr, distinct: bool | None = None, filter: Expr | None = None, order_by: list[SortKey] | SortKey | None = None, @@ -70,7 +70,7 @@ def avg( """ return Expr( _f.avg( - expression.expr, + col.expr, distinct=distinct, filter=_filter_raw(filter), order_by=sort_list_to_raw_sort_list(order_by), @@ -80,7 +80,7 @@ def avg( def try_sum( - expression: Expr, + col: Expr, distinct: bool | None = None, filter: Expr | None = None, order_by: list[SortKey] | SortKey | None = None, @@ -98,7 +98,7 @@ def try_sum( """ return Expr( _f.try_sum( - expression.expr, + col.expr, distinct=distinct, filter=_filter_raw(filter), order_by=sort_list_to_raw_sort_list(order_by), @@ -108,7 +108,7 @@ def try_sum( def collect_list( - expression: Expr, + col: Expr, distinct: bool | None = None, filter: Expr | None = None, order_by: list[SortKey] | SortKey | None = None, @@ -126,7 +126,7 @@ def collect_list( """ return Expr( _f.collect_list( - expression.expr, + col.expr, distinct=distinct, filter=_filter_raw(filter), order_by=sort_list_to_raw_sort_list(order_by), @@ -136,7 +136,7 @@ def 
collect_list( def collect_set( - expression: Expr, + col: Expr, distinct: bool | None = None, filter: Expr | None = None, order_by: list[SortKey] | SortKey | None = None, @@ -154,7 +154,7 @@ def collect_set( """ return Expr( _f.collect_set( - expression.expr, + col.expr, distinct=distinct, filter=_filter_raw(filter), order_by=sort_list_to_raw_sort_list(order_by), @@ -168,7 +168,7 @@ def collect_set( # --------------------------------------------------------------------------- -def array_contains(array: Expr, element: Expr) -> Expr: +def array_contains(col: Expr, value: Expr) -> Expr: """Spark ``array_contains``: true if the array contains the element. Examples: @@ -183,10 +183,10 @@ def array_contains(array: Expr, element: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() True """ - return Expr(_f.array_contains(array.expr, element.expr)) + return Expr(_f.array_contains(col.expr, value.expr)) -def array(*args: Expr) -> Expr: +def array(*cols: Expr) -> Expr: """Spark ``array``: builds an array from the given elements. Examples: @@ -200,10 +200,10 @@ def array(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() [1, 2, 3] """ - return Expr(_f.array(*[a.expr for a in args])) + return Expr(_f.array(*[c.expr for c in cols])) -def shuffle(array: Expr) -> Expr: +def shuffle(col: Expr) -> Expr: """Spark ``shuffle``: returns a random permutation of the input array. Examples: @@ -217,10 +217,10 @@ def shuffle(array: Expr) -> Expr: >>> sorted(r.collect_column("v")[0].as_py()) [1, 2, 3] """ - return Expr(_f.shuffle(array.expr)) + return Expr(_f.shuffle(col.expr)) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(col: Expr, count: Expr) -> Expr: """Spark ``array_repeat``: array of ``element`` repeated ``count`` times. 
Examples: @@ -231,10 +231,10 @@ def array_repeat(element: Expr, count: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() ['a', 'a', 'a'] """ - return Expr(_f.array_repeat(element.expr, count.expr)) + return Expr(_f.array_repeat(col.expr, count.expr)) -def slice(array: Expr, start: Expr, length: Expr) -> Expr: +def slice(x: Expr, start: Expr, length: Expr) -> Expr: """Spark ``slice``: subset of the array from 1-indexed ``start`` with ``length``. Negative ``start`` counts from the end. @@ -252,7 +252,7 @@ def slice(array: Expr, start: Expr, length: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() [2, 3] """ - return Expr(_f.slice(array.expr, start.expr, length.expr)) + return Expr(_f.slice(x.expr, start.expr, length.expr)) # --------------------------------------------------------------------------- @@ -260,7 +260,7 @@ def slice(array: Expr, start: Expr, length: Expr) -> Expr: # --------------------------------------------------------------------------- -def bitmap_count(arg: Expr) -> Expr: +def bitmap_count(col: Expr) -> Expr: r"""Spark ``bitmap_count``: number of set bits in a bitmap. Examples: @@ -271,10 +271,10 @@ def bitmap_count(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 8 """ - return Expr(_f.bitmap_count(arg.expr)) + return Expr(_f.bitmap_count(col.expr)) -def bitmap_bit_position(arg: Expr) -> Expr: +def bitmap_bit_position(col: Expr) -> Expr: """Spark ``bitmap_bit_position``: bit position for a child expression. Examples: @@ -285,10 +285,10 @@ def bitmap_bit_position(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 14 """ - return Expr(_f.bitmap_bit_position(arg.expr)) + return Expr(_f.bitmap_bit_position(col.expr)) -def bitmap_bucket_number(arg: Expr) -> Expr: +def bitmap_bucket_number(col: Expr) -> Expr: """Spark ``bitmap_bucket_number``: bucket number for a child expression. 
Examples: @@ -299,7 +299,7 @@ def bitmap_bucket_number(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1 """ - return Expr(_f.bitmap_bucket_number(arg.expr)) + return Expr(_f.bitmap_bucket_number(col.expr)) # --------------------------------------------------------------------------- @@ -347,7 +347,7 @@ def bitwise_not(col: Expr) -> Expr: return Expr(_f.bitwise_not(col.expr)) -def shiftleft(value: Expr, shift: Expr) -> Expr: +def shiftleft(col: Expr, numBits: Expr) -> Expr: # noqa: N803 """Spark ``shiftleft``: ``value`` shifted left by ``shift`` bits. Examples: @@ -358,10 +358,10 @@ def shiftleft(value: Expr, shift: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 8 """ - return Expr(_f.shiftleft(value.expr, shift.expr)) + return Expr(_f.shiftleft(col.expr, numBits.expr)) -def shiftright(value: Expr, shift: Expr) -> Expr: +def shiftright(col: Expr, numBits: Expr) -> Expr: # noqa: N803 """Spark ``shiftright``: arithmetic right shift. Examples: @@ -372,10 +372,10 @@ def shiftright(value: Expr, shift: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 2 """ - return Expr(_f.shiftright(value.expr, shift.expr)) + return Expr(_f.shiftright(col.expr, numBits.expr)) -def shiftrightunsigned(value: Expr, shift: Expr) -> Expr: +def shiftrightunsigned(col: Expr, numBits: Expr) -> Expr: # noqa: N803 """Spark ``shiftrightunsigned``: logical (unsigned) right shift. Examples: @@ -389,7 +389,7 @@ def shiftrightunsigned(value: Expr, shift: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 2 """ - return Expr(_f.shiftrightunsigned(value.expr, shift.expr)) + return Expr(_f.shiftrightunsigned(col.expr, numBits.expr)) # --------------------------------------------------------------------------- @@ -397,7 +397,7 @@ def shiftrightunsigned(value: Expr, shift: Expr) -> Expr: # --------------------------------------------------------------------------- -def size(arg: Expr) -> Expr: +def size(col: Expr) -> Expr: """Spark ``size``: length of an array or map. 
Examples: @@ -411,7 +411,7 @@ def size(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 3 """ - return Expr(_f.size(arg.expr)) + return Expr(_f.size(col.expr)) def if_(condition: Expr, if_true: Expr, if_false: Expr) -> Expr: @@ -459,7 +459,7 @@ def spark_cast(arg: Expr, type_str: Expr) -> Expr: # --------------------------------------------------------------------------- -def add_months(start_date: Expr, num_months: Expr) -> Expr: +def add_months(start: Expr, months: Expr) -> Expr: """Spark ``add_months``: date + N months. Examples: @@ -476,10 +476,10 @@ def add_months(start_date: Expr, num_months: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 3, 15) """ - return Expr(_f.add_months(start_date.expr, num_months.expr)) + return Expr(_f.add_months(start.expr, months.expr)) -def date_add(start_date: Expr, days: Expr) -> Expr: +def date_add(start: Expr, days: Expr) -> Expr: """Spark ``date_add``: date + N days. Examples: @@ -496,10 +496,10 @@ def date_add(start_date: Expr, days: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 1, 20) """ - return Expr(_f.date_add(start_date.expr, days.expr)) + return Expr(_f.date_add(start.expr, days.expr)) -def date_sub(start_date: Expr, days: Expr) -> Expr: +def date_sub(start: Expr, days: Expr) -> Expr: """Spark ``date_sub``: date - N days. Examples: @@ -516,10 +516,10 @@ def date_sub(start_date: Expr, days: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 1, 10) """ - return Expr(_f.date_sub(start_date.expr, days.expr)) + return Expr(_f.date_sub(start.expr, days.expr)) -def hour(arg: Expr) -> Expr: +def hour(col: Expr) -> Expr: """Spark ``hour``: extract hour component of a timestamp. 
Examples: @@ -534,10 +534,10 @@ def hour(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 14 """ - return Expr(_f.hour(arg.expr)) + return Expr(_f.hour(col.expr)) -def minute(arg: Expr) -> Expr: +def minute(col: Expr) -> Expr: """Spark ``minute``: extract minute component of a timestamp. Examples: @@ -552,10 +552,10 @@ def minute(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 30 """ - return Expr(_f.minute(arg.expr)) + return Expr(_f.minute(col.expr)) -def second(arg: Expr) -> Expr: +def second(col: Expr) -> Expr: """Spark ``second``: extract second component of a timestamp. Examples: @@ -570,10 +570,10 @@ def second(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 45 """ - return Expr(_f.second(arg.expr)) + return Expr(_f.second(col.expr)) -def last_day(arg: Expr) -> Expr: +def last_day(col: Expr) -> Expr: """Spark ``last_day``: last day of the month containing the date. Examples: @@ -586,7 +586,7 @@ def last_day(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 1, 31) """ - return Expr(_f.last_day(arg.expr)) + return Expr(_f.last_day(col.expr)) def make_dt_interval(days: Expr, hours: Expr, mins: Expr, secs: Expr) -> Expr: @@ -646,7 +646,7 @@ def make_interval( ) -def next_day(start_date: Expr, day_of_week: Expr) -> Expr: +def next_day(date: Expr, dayOfWeek: Expr) -> Expr: # noqa: N803 """Spark ``next_day``: first date after ``start_date`` named ``day_of_week``. Examples: @@ -660,10 +660,10 @@ def next_day(start_date: Expr, day_of_week: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 1, 20) """ - return Expr(_f.next_day(start_date.expr, day_of_week.expr)) + return Expr(_f.next_day(date.expr, dayOfWeek.expr)) -def date_diff(end_date: Expr, start_date: Expr) -> Expr: +def date_diff(end: Expr, start: Expr) -> Expr: """Spark ``date_diff``: number of days from ``start_date`` to ``end_date``. 
Examples: @@ -677,10 +677,10 @@ def date_diff(end_date: Expr, start_date: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 5 """ - return Expr(_f.date_diff(end_date.expr, start_date.expr)) + return Expr(_f.date_diff(end.expr, start.expr)) -def date_trunc(fmt: Expr, ts: Expr) -> Expr: +def date_trunc(format: Expr, timestamp: Expr) -> Expr: """Spark ``date_trunc``: truncate timestamp to unit ``fmt``. Examples: @@ -696,10 +696,10 @@ def date_trunc(fmt: Expr, ts: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.datetime(2020, 1, 1, 0, 0) """ - return Expr(_f.date_trunc(fmt.expr, ts.expr)) + return Expr(_f.date_trunc(format.expr, timestamp.expr)) -def time_trunc(fmt: Expr, t: Expr) -> Expr: +def time_trunc(unit: Expr, time: Expr) -> Expr: """Spark ``time_trunc``: truncate time value to unit ``fmt``. Examples: @@ -713,10 +713,10 @@ def time_trunc(fmt: Expr, t: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.time(14, 0) """ - return Expr(_f.time_trunc(fmt.expr, t.expr)) + return Expr(_f.time_trunc(unit.expr, time.expr)) -def trunc(dt: Expr, fmt: Expr) -> Expr: +def trunc(date: Expr, format: Expr) -> Expr: """Spark ``trunc``: truncate date to unit ``fmt``. Examples: @@ -730,7 +730,7 @@ def trunc(dt: Expr, fmt: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.date(2020, 1, 1) """ - return Expr(_f.trunc(dt.expr, fmt.expr)) + return Expr(_f.trunc(date.expr, format.expr)) def date_part(field: Expr, source: Expr) -> Expr: @@ -750,7 +750,7 @@ def date_part(field: Expr, source: Expr) -> Expr: return Expr(_f.date_part(field.expr, source.expr)) -def from_utc_timestamp(ts: Expr, tz: Expr) -> Expr: +def from_utc_timestamp(timestamp: Expr, tz: Expr) -> Expr: """Spark ``from_utc_timestamp``: interpret ``ts`` as UTC, convert to ``tz``. 
Examples: @@ -769,10 +769,10 @@ def from_utc_timestamp(ts: Expr, tz: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.datetime(2020, 1, 15, 14, 30, 45) """ - return Expr(_f.from_utc_timestamp(ts.expr, tz.expr)) + return Expr(_f.from_utc_timestamp(timestamp.expr, tz.expr)) -def to_utc_timestamp(ts: Expr, tz: Expr) -> Expr: +def to_utc_timestamp(timestamp: Expr, tz: Expr) -> Expr: """Spark ``to_utc_timestamp``: interpret ``ts`` as ``tz``, convert to UTC. Examples: @@ -791,10 +791,10 @@ def to_utc_timestamp(ts: Expr, tz: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() datetime.datetime(2020, 1, 15, 14, 30, 45) """ - return Expr(_f.to_utc_timestamp(ts.expr, tz.expr)) + return Expr(_f.to_utc_timestamp(timestamp.expr, tz.expr)) -def unix_date(dt: Expr) -> Expr: +def unix_date(col: Expr) -> Expr: """Spark ``unix_date``: days since 1970-01-01 for ``dt``. Examples: @@ -807,10 +807,10 @@ def unix_date(dt: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 18276 """ - return Expr(_f.unix_date(dt.expr)) + return Expr(_f.unix_date(col.expr)) -def unix_micros(ts: Expr) -> Expr: +def unix_micros(col: Expr) -> Expr: """Spark ``unix_micros``: microseconds since epoch for ``ts``. Examples: @@ -825,10 +825,10 @@ def unix_micros(ts: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1579098645000000 """ - return Expr(_f.unix_micros(ts.expr)) + return Expr(_f.unix_micros(col.expr)) -def unix_millis(ts: Expr) -> Expr: +def unix_millis(col: Expr) -> Expr: """Spark ``unix_millis``: milliseconds since epoch for ``ts``. Examples: @@ -843,10 +843,10 @@ def unix_millis(ts: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1579098645000 """ - return Expr(_f.unix_millis(ts.expr)) + return Expr(_f.unix_millis(col.expr)) -def unix_seconds(ts: Expr) -> Expr: +def unix_seconds(col: Expr) -> Expr: """Spark ``unix_seconds``: seconds since epoch for ``ts``. 
Examples: @@ -861,7 +861,7 @@ def unix_seconds(ts: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1579098645 """ - return Expr(_f.unix_seconds(ts.expr)) + return Expr(_f.unix_seconds(col.expr)) # --------------------------------------------------------------------------- @@ -869,7 +869,7 @@ def unix_seconds(ts: Expr) -> Expr: # --------------------------------------------------------------------------- -def crc32(arg: Expr) -> Expr: +def crc32(col: Expr) -> Expr: """Spark ``crc32``: cyclic redundancy check value as a bigint. Examples: @@ -879,10 +879,10 @@ def crc32(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 2743272264 """ - return Expr(_f.crc32(arg.expr)) + return Expr(_f.crc32(col.expr)) -def sha1(arg: Expr) -> Expr: +def sha1(col: Expr) -> Expr: """Spark ``sha1``: SHA-1 hash as a hex string. Examples: @@ -892,10 +892,10 @@ def sha1(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d' """ - return Expr(_f.sha1(arg.expr)) + return Expr(_f.sha1(col.expr)) -def sha2(arg: Expr, bit_length: Expr) -> Expr: +def sha2(col: Expr, numBits: Expr) -> Expr: # noqa: N803 """Spark ``sha2``: SHA-2 family hash (224, 256, 384, 512). Bit length 0 = 256. Examples: @@ -906,10 +906,10 @@ def sha2(arg: Expr, bit_length: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824' """ - return Expr(_f.sha2(arg.expr, bit_length.expr)) + return Expr(_f.sha2(col.expr, numBits.expr)) -def xxhash64(*args: Expr) -> Expr: +def xxhash64(*cols: Expr) -> Expr: """Spark ``xxhash64``: 64-bit xxHash of the arguments. 
Examples: @@ -920,7 +920,7 @@ def xxhash64(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() -4367754540140381902 """ - return Expr(_f.xxhash64(*[a.expr for a in args])) + return Expr(_f.xxhash64(*[c.expr for c in cols])) # --------------------------------------------------------------------------- @@ -950,7 +950,7 @@ def json_tuple(*args: Expr) -> Expr: # --------------------------------------------------------------------------- -def map_from_arrays(keys: Expr, values: Expr) -> Expr: +def map_from_arrays(col1: Expr, col2: Expr) -> Expr: """Spark ``map_from_arrays``: build a map from parallel key/value arrays. Examples: @@ -963,10 +963,10 @@ def map_from_arrays(keys: Expr, values: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() [('a', 1), ('b', 2)] """ - return Expr(_f.map_from_arrays(keys.expr, values.expr)) + return Expr(_f.map_from_arrays(col1.expr, col2.expr)) -def map_from_entries(arg: Expr) -> Expr: +def map_from_entries(col: Expr) -> Expr: """Spark ``map_from_entries``: build a map from an array of key/value structs. Examples: @@ -981,7 +981,7 @@ def map_from_entries(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() [('a', 1)] """ - return Expr(_f.map_from_entries(arg.expr)) + return Expr(_f.map_from_entries(col.expr)) def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr: @@ -1006,7 +1006,7 @@ def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr: # --------------------------------------------------------------------------- -def abs(arg: Expr) -> Expr: +def abs(col: Expr) -> Expr: """Spark ``abs``: absolute value. Examples: @@ -1016,10 +1016,10 @@ def abs(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 5 """ - return Expr(_f.abs(arg.expr)) + return Expr(_f.abs(col.expr)) -def ceil(arg: Expr) -> Expr: +def ceil(col: Expr) -> Expr: """Spark ``ceil``: smallest integer ≥ arg. 
Examples: @@ -1029,10 +1029,10 @@ def ceil(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 2 """ - return Expr(_f.ceil(arg.expr)) + return Expr(_f.ceil(col.expr)) -def expm1(arg: Expr) -> Expr: +def expm1(col: Expr) -> Expr: """Spark ``expm1``: exp(arg) - 1. Examples: @@ -1042,10 +1042,10 @@ def expm1(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 0.0 """ - return Expr(_f.expm1(arg.expr)) + return Expr(_f.expm1(col.expr)) -def factorial(arg: Expr) -> Expr: +def factorial(col: Expr) -> Expr: """Spark ``factorial``: n! for n in [0..20], else NULL. Examples: @@ -1060,10 +1060,10 @@ def factorial(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 120 """ - return Expr(_f.factorial(arg.expr)) + return Expr(_f.factorial(col.expr)) -def floor(arg: Expr) -> Expr: +def floor(col: Expr) -> Expr: """Spark ``floor``: largest integer ≤ arg. Examples: @@ -1073,10 +1073,10 @@ def floor(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1 """ - return Expr(_f.floor(arg.expr)) + return Expr(_f.floor(col.expr)) -def hex(arg: Expr) -> Expr: +def hex(col: Expr) -> Expr: """Spark ``hex``: hexadecimal representation. Examples: @@ -1086,7 +1086,7 @@ def hex(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'FF' """ - return Expr(_f.hex(arg.expr)) + return Expr(_f.hex(col.expr)) def modulus(dividend: Expr, divisor: Expr) -> Expr: @@ -1117,7 +1117,7 @@ def pmod(dividend: Expr, divisor: Expr) -> Expr: return Expr(_f.pmod(dividend.expr, divisor.expr)) -def rint(arg: Expr) -> Expr: +def rint(col: Expr) -> Expr: """Spark ``rint``: round to nearest mathematical integer (as double). Examples: @@ -1127,10 +1127,10 @@ def rint(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 2.0 """ - return Expr(_f.rint(arg.expr)) + return Expr(_f.rint(col.expr)) -def round(value: Expr, scale: Expr) -> Expr: +def round(col: Expr, scale: Expr) -> Expr: """Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding. 
Examples: @@ -1141,10 +1141,10 @@ def round(value: Expr, scale: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 3.0 """ - return Expr(_f.round(value.expr, scale.expr)) + return Expr(_f.round(col.expr, scale.expr)) -def unhex(arg: Expr) -> Expr: +def unhex(col: Expr) -> Expr: r"""Spark ``unhex``: convert hexadecimal string to binary. Examples: @@ -1154,11 +1154,14 @@ def unhex(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() b'\xff' """ - return Expr(_f.unhex(arg.expr)) + return Expr(_f.unhex(col.expr)) def width_bucket( - value: Expr, min_value: Expr, max_value: Expr, num_buckets: Expr + v: Expr, + min: Expr, + max: Expr, + numBucket: Expr, # noqa: N803 ) -> Expr: """Spark ``width_bucket``: bucket number for ``value`` in equi-width histogram. @@ -1173,12 +1176,10 @@ def width_bucket( >>> r.collect_column("v")[0].as_py() 3 """ - return Expr( - _f.width_bucket(value.expr, min_value.expr, max_value.expr, num_buckets.expr) - ) + return Expr(_f.width_bucket(v.expr, min.expr, max.expr, numBucket.expr)) -def csc(arg: Expr) -> Expr: +def csc(col: Expr) -> Expr: """Spark ``csc``: cosecant. Examples: @@ -1188,10 +1189,10 @@ def csc(arg: Expr) -> Expr: >>> f"{r.collect_column('v')[0].as_py():.4f}" '1.0000' """ - return Expr(_f.csc(arg.expr)) + return Expr(_f.csc(col.expr)) -def sec(arg: Expr) -> Expr: +def sec(col: Expr) -> Expr: """Spark ``sec``: secant. Examples: @@ -1201,10 +1202,10 @@ def sec(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 1.0 """ - return Expr(_f.sec(arg.expr)) + return Expr(_f.sec(col.expr)) -def negative(arg: Expr) -> Expr: +def negative(col: Expr) -> Expr: """Spark ``negative``: unary minus. Examples: @@ -1214,10 +1215,10 @@ def negative(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() -3 """ - return Expr(_f.negative(arg.expr)) + return Expr(_f.negative(col.expr)) -def bin(arg: Expr) -> Expr: +def bin(col: Expr) -> Expr: """Spark ``bin``: binary string representation of a long. 
Examples: @@ -1227,7 +1228,7 @@ def bin(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() '111' """ - return Expr(_f.bin(arg.expr)) + return Expr(_f.bin(col.expr)) # --------------------------------------------------------------------------- @@ -1235,7 +1236,7 @@ def bin(arg: Expr) -> Expr: # --------------------------------------------------------------------------- -def ascii(arg: Expr) -> Expr: +def ascii(col: Expr) -> Expr: """Spark ``ascii``: code point of the first character. Examples: @@ -1245,10 +1246,10 @@ def ascii(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 65 """ - return Expr(_f.ascii(arg.expr)) + return Expr(_f.ascii(col.expr)) -def base64(bin_input: Expr) -> Expr: +def base64(col: Expr) -> Expr: """Spark ``base64``: encode binary as a base64 string. Examples: @@ -1258,10 +1259,10 @@ def base64(bin_input: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'aGk=' """ - return Expr(_f.base64(bin_input.expr)) + return Expr(_f.base64(col.expr)) -def char(arg: Expr) -> Expr: +def char(col: Expr) -> Expr: """Spark ``char``: ASCII character for a code point (mod 256). Examples: @@ -1271,10 +1272,10 @@ def char(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'A' """ - return Expr(_f.char(arg.expr)) + return Expr(_f.char(col.expr)) -def concat(*args: Expr) -> Expr: +def concat(*cols: Expr) -> Expr: """Spark ``concat``: concatenates strings; NULL if any input is NULL. Examples: @@ -1285,10 +1286,10 @@ def concat(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'ab' """ - return Expr(_f.concat(*[a.expr for a in args])) + return Expr(_f.concat(*[c.expr for c in cols])) -def elt(*args: Expr) -> Expr: +def elt(*inputs: Expr) -> Expr: """Spark ``elt``: returns the n-th input (1-indexed). 
Examples: @@ -1302,7 +1303,7 @@ def elt(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'b' """ - return Expr(_f.elt(*[a.expr for a in args])) + return Expr(_f.elt(*[i.expr for i in inputs])) def ilike(string: Expr, pattern: Expr) -> Expr: @@ -1319,7 +1320,7 @@ def ilike(string: Expr, pattern: Expr) -> Expr: return Expr(_f.ilike(string.expr, pattern.expr)) -def length(arg: Expr) -> Expr: +def length(col: Expr) -> Expr: """Spark ``length``: character length of a string, or byte length of binary. Examples: @@ -1329,7 +1330,7 @@ def length(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 5 """ - return Expr(_f.length(arg.expr)) + return Expr(_f.length(col.expr)) def like(string: Expr, pattern: Expr) -> Expr: @@ -1346,7 +1347,7 @@ def like(string: Expr, pattern: Expr) -> Expr: return Expr(_f.like(string.expr, pattern.expr)) -def luhn_check(arg: Expr) -> Expr: +def luhn_check(col: Expr) -> Expr: """Spark ``luhn_check``: true if the digit string passes the Luhn check. Examples: @@ -1360,7 +1361,7 @@ def luhn_check(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() True """ - return Expr(_f.luhn_check(arg.expr)) + return Expr(_f.luhn_check(col.expr)) def format_string(*args: Expr) -> Expr: @@ -1382,7 +1383,7 @@ def format_string(*args: Expr) -> Expr: return Expr(_f.format_string(*[a.expr for a in args])) -def space(arg: Expr) -> Expr: +def space(col: Expr) -> Expr: """Spark ``space``: string of n spaces. Examples: @@ -1397,10 +1398,10 @@ def space(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() ' ' """ - return Expr(_f.space(arg.expr)) + return Expr(_f.space(col.expr)) -def substring(string: Expr, pos: Expr, length: Expr) -> Expr: +def substring(str: Expr, pos: Expr, len: Expr) -> Expr: """Spark ``substring``: 1-indexed substring starting at ``pos`` of given ``length``. Negative ``pos`` counts from the end. 
@@ -1416,10 +1417,10 @@ def substring(string: Expr, pos: Expr, length: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'hel' """ - return Expr(_f.substring(string.expr, pos.expr, length.expr)) + return Expr(_f.substring(str.expr, pos.expr, len.expr)) -def unbase64(arg: Expr) -> Expr: +def unbase64(col: Expr) -> Expr: """Spark ``unbase64``: decode a base64 string to binary. Examples: @@ -1430,10 +1431,10 @@ def unbase64(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() b'hi' """ - return Expr(_f.unbase64(arg.expr)) + return Expr(_f.unbase64(col.expr)) -def soundex(arg: Expr) -> Expr: +def soundex(col: Expr) -> Expr: """Spark ``soundex``: Soundex phonetic code. Examples: @@ -1443,10 +1444,10 @@ def soundex(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'R163' """ - return Expr(_f.soundex(arg.expr)) + return Expr(_f.soundex(col.expr)) -def is_valid_utf8(arg: Expr) -> Expr: +def is_valid_utf8(str: Expr) -> Expr: """Spark ``is_valid_utf8``: true if the string is valid UTF-8. Examples: @@ -1457,10 +1458,10 @@ def is_valid_utf8(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() True """ - return Expr(_f.is_valid_utf8(arg.expr)) + return Expr(_f.is_valid_utf8(str.expr)) -def make_valid_utf8(arg: Expr) -> Expr: +def make_valid_utf8(str: Expr) -> Expr: """Spark ``make_valid_utf8``: replace invalid UTF-8 bytes with U+FFFD. 
Examples: @@ -1471,7 +1472,7 @@ def make_valid_utf8(arg: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'hello' """ - return Expr(_f.make_valid_utf8(arg.expr)) + return Expr(_f.make_valid_utf8(str.expr)) # --------------------------------------------------------------------------- From e9113dd7ea2983429a76a6bc9bcd81e88f382fe7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 09:35:13 -0400 Subject: [PATCH 09/14] feat(spark): make pyspark-optional params optional MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Match pyspark's optional-parameter surface in the spark namespace: - make_dt_interval, make_interval: all parts default to zero (int32 0 / lit 0.0) - str_to_map: pair_delim defaults to ',', key_value_delim defaults to ':' - round: scale defaults to 0 (HALF_UP rounding to nearest integer) - shuffle: accepts `seed` kwarg for pyspark parity; raises NotImplementedError for non-None values until the Rust binding supports it - like, ilike: accept `escapeChar` for pyspark parity; same NotImplementedError guard; first positional renamed `string` → `str` to match pyspark ceil/floor `scale=` deferred — the underlying Rust expr_fn is single-arg. Added a module-level `_ZERO_I32` literal to avoid rebuilding the pyarrow int32 zero scalar on every call. Tests: positional-compat coverage for aggregates (`spark.avg(col)` etc.), defaults-omitted cases for the optional-arg functions, and NotImplementedError cases for `shuffle(seed=)` and `like/ilike(escapeChar=)`. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- python/datafusion/functions/spark.py | 149 ++++++++++++++++++++------- python/tests/test_spark_functions.py | 67 ++++++++++++ 2 files changed, 179 insertions(+), 37 deletions(-) diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py index 09467b981..9fbde0f35 100644 --- a/python/datafusion/functions/spark.py +++ b/python/datafusion/functions/spark.py @@ -32,6 +32,8 @@ from typing import TYPE_CHECKING, Any +import pyarrow as pa + from datafusion._internal import functions as _functions from datafusion.expr import Expr, sort_list_to_raw_sort_list @@ -41,6 +43,9 @@ _f = _functions.spark +# Reused int32 literal so optional-arg defaults don't rebuild it per call. +_ZERO_I32 = Expr.literal(pa.scalar(0, type=pa.int32())) + def _filter_raw(filter: Expr | None) -> Any: return filter.expr if filter is not None else None @@ -203,9 +208,12 @@ def array(*cols: Expr) -> Expr: return Expr(_f.array(*[c.expr for c in cols])) -def shuffle(col: Expr) -> Expr: +def shuffle(col: Expr, seed: int | None = None) -> Expr: """Spark ``shuffle``: returns a random permutation of the input array. + ``seed`` is accepted for pyspark parity but is not yet wired through the + Rust binding; passing a non-``None`` value raises ``NotImplementedError``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) @@ -217,6 +225,9 @@ def shuffle(col: Expr) -> Expr: >>> sorted(r.collect_column("v")[0].as_py()) [1, 2, 3] """ + if seed is not None: + msg = "shuffle(seed=...) 
is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) return Expr(_f.shuffle(col.expr)) @@ -589,59 +600,78 @@ def last_day(col: Expr) -> Expr: return Expr(_f.last_day(col.expr)) -def make_dt_interval(days: Expr, hours: Expr, mins: Expr, secs: Expr) -> Expr: +def make_dt_interval( + days: Expr | None = None, + hours: Expr | None = None, + mins: Expr | None = None, + secs: Expr | None = None, +) -> Expr: """Spark ``make_dt_interval``: day-time interval from components. + All parts are optional; omitted parts default to zero, matching pyspark. + Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.make_dt_interval().alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.timedelta(0) + + >>> import pyarrow as pa >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) >>> r = df.select( ... dfn.functions.spark.make_dt_interval( - ... i32(1), i32(2), i32(3), dfn.lit(4.5) + ... days=i32(1), hours=i32(2), mins=i32(3), secs=dfn.lit(4.5) ... ).alias("v") ... ) >>> r.collect_column("v")[0].as_py() datetime.timedelta(days=1, seconds=7384, microseconds=500000) """ - return Expr(_f.make_dt_interval(days.expr, hours.expr, mins.expr, secs.expr)) + return Expr( + _f.make_dt_interval( + (days if days is not None else _ZERO_I32).expr, + (hours if hours is not None else _ZERO_I32).expr, + (mins if mins is not None else _ZERO_I32).expr, + (secs if secs is not None else Expr.literal(0.0)).expr, + ) + ) def make_interval( - years: Expr, - months: Expr, - weeks: Expr, - days: Expr, - hours: Expr, - mins: Expr, - secs: Expr, + years: Expr | None = None, + months: Expr | None = None, + weeks: Expr | None = None, + days: Expr | None = None, + hours: Expr | None = None, + mins: Expr | None = None, + secs: Expr | None = None, ) -> Expr: """Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts. 
+ All parts are optional; omitted parts default to zero, matching pyspark. + Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.make_interval().alias("v")) + >>> r.collect_column("v")[0].as_py().months + 0 + + >>> import pyarrow as pa >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) - >>> r = df.select( - ... dfn.functions.spark.make_interval( - ... i32(1), i32(0), i32(0), i32(0), - ... i32(0), i32(0), dfn.lit(0.0) - ... ).alias("v") - ... ) + >>> r = df.select(dfn.functions.spark.make_interval(years=i32(1)).alias("v")) >>> r.collect_column("v")[0].as_py().months 12 """ return Expr( _f.make_interval( - years.expr, - months.expr, - weeks.expr, - days.expr, - hours.expr, - mins.expr, - secs.expr, + (years if years is not None else _ZERO_I32).expr, + (months if months is not None else _ZERO_I32).expr, + (weeks if weeks is not None else _ZERO_I32).expr, + (days if days is not None else _ZERO_I32).expr, + (hours if hours is not None else _ZERO_I32).expr, + (mins if mins is not None else _ZERO_I32).expr, + (secs if secs is not None else Expr.literal(0.0)).expr, ) ) @@ -984,21 +1014,36 @@ def map_from_entries(col: Expr) -> Expr: return Expr(_f.map_from_entries(col.expr)) -def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr: +def str_to_map( + text: Expr, + pair_delim: Expr | None = None, + key_value_delim: Expr | None = None, +) -> Expr: """Spark ``str_to_map``: split text into key/value pairs using delimiters. + Delimiters default to ``","`` and ``":"`` when omitted, matching pyspark. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.str_to_map(dfn.lit("a:1,b:2")).alias("v")) + >>> r.collect_column("v")[0].as_py() + [('a', '1'), ('b', '2')] + >>> r = df.select( ... dfn.functions.spark.str_to_map( - ... dfn.lit("a:1,b:2"), dfn.lit(","), dfn.lit(":") + ... 
dfn.lit("a=1;b=2"), + ... pair_delim=dfn.lit(";"), + ... key_value_delim=dfn.lit("="), ... ).alias("v") ... ) >>> r.collect_column("v")[0].as_py() [('a', '1'), ('b', '2')] """ - return Expr(_f.str_to_map(text.expr, pair_delim.expr, key_value_delim.expr)) + pd = pair_delim if pair_delim is not None else Expr.literal(",") + kvd = key_value_delim if key_value_delim is not None else Expr.literal(":") + return Expr(_f.str_to_map(text.expr, pd.expr, kvd.expr)) # --------------------------------------------------------------------------- @@ -1130,18 +1175,28 @@ def rint(col: Expr) -> Expr: return Expr(_f.rint(col.expr)) -def round(col: Expr, scale: Expr) -> Expr: +def round(col: Expr, scale: Expr | None = None) -> Expr: """Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding. + ``scale`` defaults to zero when omitted, matching pyspark. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) - >>> r = df.select( - ... dfn.functions.spark.round(dfn.lit(2.5), dfn.lit(0)).alias("v")) + >>> r = df.select(dfn.functions.spark.round(dfn.lit(2.5)).alias("v")) >>> r.collect_column("v")[0].as_py() 3.0 + + >>> r = df.select( + ... dfn.functions.spark.round( + ... dfn.lit(2.345), scale=dfn.lit(2) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 2.35 """ - return Expr(_f.round(col.expr, scale.expr)) + scale_expr = scale if scale is not None else _ZERO_I32 + return Expr(_f.round(col.expr, scale_expr.expr)) def unhex(col: Expr) -> Expr: @@ -1306,9 +1361,16 @@ def elt(*inputs: Expr) -> Expr: return Expr(_f.elt(*[i.expr for i in inputs])) -def ilike(string: Expr, pattern: Expr) -> Expr: +def ilike( + str: Expr, + pattern: Expr, + escapeChar: str | None = None, # noqa: N803 +) -> Expr: """Spark ``ilike``: case-insensitive pattern match. + ``escapeChar`` is accepted for pyspark parity but is not yet wired through + the Rust binding; passing a non-``None`` value raises ``NotImplementedError``. 
+ Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) @@ -1317,7 +1379,10 @@ def ilike(string: Expr, pattern: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() True """ - return Expr(_f.ilike(string.expr, pattern.expr)) + if escapeChar is not None: + msg = "ilike(escapeChar=...) is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) + return Expr(_f.ilike(str.expr, pattern.expr)) def length(col: Expr) -> Expr: @@ -1333,9 +1398,16 @@ def length(col: Expr) -> Expr: return Expr(_f.length(col.expr)) -def like(string: Expr, pattern: Expr) -> Expr: +def like( + str: Expr, + pattern: Expr, + escapeChar: str | None = None, # noqa: N803 +) -> Expr: """Spark ``like``: case-sensitive pattern match. + ``escapeChar`` is accepted for pyspark parity but is not yet wired through + the Rust binding; passing a non-``None`` value raises ``NotImplementedError``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) @@ -1344,7 +1416,10 @@ def like(string: Expr, pattern: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() True """ - return Expr(_f.like(string.expr, pattern.expr)) + if escapeChar is not None: + msg = "like(escapeChar=...) 
is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) + return Expr(_f.like(str.expr, pattern.expr)) def luhn_check(col: Expr) -> Expr: diff --git a/python/tests/test_spark_functions.py b/python/tests/test_spark_functions.py index f9ccb41a1..0b58a28da 100644 --- a/python/tests/test_spark_functions.py +++ b/python/tests/test_spark_functions.py @@ -229,6 +229,73 @@ def test_round_half_up(): assert _val(df, spark.round(lit(2.5), lit(0))) == 3.0 +# --------------------------------------------------------------------------- +# Optional parameter defaults / NotImplementedError +# --------------------------------------------------------------------------- + + +def test_round_scale_default(): + """spark.round defaults scale to 0.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.round(lit(2.5))) == 3.0 + + +def test_make_dt_interval_defaults(): + """spark.make_dt_interval with no args returns a zero day-time interval.""" + import datetime as dt + + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.make_dt_interval()) == dt.timedelta(0) + + +def test_make_interval_defaults(): + """spark.make_interval with no args returns a zero interval.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.make_interval()).months == 0 + + +def test_str_to_map_defaults(): + """spark.str_to_map defaults delimiters to ',' and ':'.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.str_to_map(lit("a:1,b:2"))) == [("a", "1"), ("b", "2")] + + +def test_shuffle_seed_raises(): + """spark.shuffle(seed=...) 
raises NotImplementedError until Rust supports it.""" + with pytest.raises(NotImplementedError, match="seed"): + spark.shuffle(spark.array(lit(1), lit(2)), seed=1) + + +def test_like_escape_raises(): + """spark.like/ilike escapeChar raises NotImplementedError until Rust supports.""" + with pytest.raises(NotImplementedError, match="escapeChar"): + spark.like(lit("a"), lit("a"), escapeChar="\\") + with pytest.raises(NotImplementedError, match="escapeChar"): + spark.ilike(lit("a"), lit("a"), escapeChar="\\") + + +def test_aggregate_positional_compat(): + """Pyspark-style positional calls still work after the rename to ``col``.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + out = df.aggregate( + [], + [ + spark.avg(col("a")).alias("av"), + spark.try_sum(col("a")).alias("ts"), + spark.collect_list(col("a")).alias("cl"), + spark.collect_set(col("a")).alias("cs"), + ], + ).collect() + rec = pa.Table.from_batches(out) + assert rec.column("av")[0].as_py() == 2.0 + assert rec.column("ts")[0].as_py() == 6.0 + + # --------------------------------------------------------------------------- # SQL path via enable_spark_functions # --------------------------------------------------------------------------- From f4b5119b5201489f3dc2dec429a1ed3654ea09f4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 09:36:46 -0400 Subject: [PATCH 10/14] refactor(spark): reshape varargs to match pyspark signatures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace generic ``*args`` with explicit pyspark-style signatures: - json_tuple(col, *fields) — first positional is the JSON expr - format_string(format, *cols) — `format` is the printf template; a plain ``str`` is auto-promoted to a literal - parse_url(url, partToExtract, key=None) — `key` is optional and only meaningful with ``partToExtract='QUERY'`` - try_parse_url(url, partToExtract, key=None) — same shape - url_decode(str), try_url_decode(str), 
url_encode(str) — single-argument forms (multi-arg calls were always semantically wrong) Tests cover the three-arg parse_url path and the plain-str format_string auto-promotion. Co-Authored-By: Claude Opus 4.7 (1M context) --- python/datafusion/functions/spark.py | 60 ++++++++++++++++++++-------- python/tests/test_spark_functions.py | 16 ++++++++ 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py index 9fbde0f35..4c7127344 100644 --- a/python/datafusion/functions/spark.py +++ b/python/datafusion/functions/spark.py @@ -958,7 +958,7 @@ def xxhash64(*cols: Expr) -> Expr: # --------------------------------------------------------------------------- -def json_tuple(*args: Expr) -> Expr: +def json_tuple(col: Expr, *fields: Expr) -> Expr: """Spark ``json_tuple``: extract top-level fields from a JSON string. Examples: @@ -972,7 +972,7 @@ def json_tuple(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() {'c0': '1', 'c1': 'x'} """ - return Expr(_f.json_tuple(*[a.expr for a in args])) + return Expr(_f.json_tuple(col.expr, *[f.expr for f in fields])) # --------------------------------------------------------------------------- @@ -1439,23 +1439,25 @@ def luhn_check(col: Expr) -> Expr: return Expr(_f.luhn_check(col.expr)) -def format_string(*args: Expr) -> Expr: +def format_string(format: str | Expr, *cols: Expr) -> Expr: """Spark ``format_string``: printf-style format string. - First arg is the format, remaining args are values to substitute. + ``format`` is the printf-style template (a plain ``str`` is auto-promoted + to a literal expression); remaining args are values to substitute. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) >>> r = df.select( ... dfn.functions.spark.format_string( - ... dfn.lit("%d-%s"), dfn.lit(42), dfn.lit("hi") + ... "%d-%s", dfn.lit(42), dfn.lit("hi") ... ).alias("v") ... 
) >>> r.collect_column("v")[0].as_py() '42-hi' """ - return Expr(_f.format_string(*[a.expr for a in args])) + fmt_expr = format if isinstance(format, Expr) else Expr.literal(format) + return Expr(_f.format_string(fmt_expr.expr, *[c.expr for c in cols])) def space(col: Expr) -> Expr: @@ -1555,9 +1557,17 @@ def make_valid_utf8(str: Expr) -> Expr: # --------------------------------------------------------------------------- -def parse_url(*args: Expr) -> Expr: +def parse_url( + url: Expr, + partToExtract: Expr, # noqa: N803 + key: Expr | None = None, +) -> Expr: """Spark ``parse_url``: extract a part from a URL; errors on invalid URLs. + ``partToExtract`` is one of ``"HOST"``, ``"PATH"``, ``"QUERY"``, + ``"REF"``, ``"PROTOCOL"``, ``"FILE"``, ``"AUTHORITY"``, ``"USERINFO"``. + Pass ``key`` only with ``"QUERY"`` to extract a single parameter. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"x": [1]}) @@ -1568,11 +1578,27 @@ def parse_url(*args: Expr) -> Expr: ... ) >>> r.collect_column("v")[0].as_py() 'example.com' + + >>> r = df.select( + ... dfn.functions.spark.parse_url( + ... dfn.lit("http://example.com/path?q=1"), + ... dfn.lit("QUERY"), + ... key=dfn.lit("q"), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + '1' """ - return Expr(_f.parse_url(*[a.expr for a in args])) + if key is None: + return Expr(_f.parse_url(url.expr, partToExtract.expr)) + return Expr(_f.parse_url(url.expr, partToExtract.expr, key.expr)) -def try_parse_url(*args: Expr) -> Expr: +def try_parse_url( + url: Expr, + partToExtract: Expr, # noqa: N803 + key: Expr | None = None, +) -> Expr: """Spark ``try_parse_url``: like ``parse_url`` but returns NULL on invalid URLs. 
Examples: @@ -1586,10 +1612,12 @@ def try_parse_url(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'example.com' """ - return Expr(_f.try_parse_url(*[a.expr for a in args])) + if key is None: + return Expr(_f.try_parse_url(url.expr, partToExtract.expr)) + return Expr(_f.try_parse_url(url.expr, partToExtract.expr, key.expr)) -def url_decode(*args: Expr) -> Expr: +def url_decode(str: Expr) -> Expr: """Spark ``url_decode``: decode an application/x-www-form-urlencoded string. Examples: @@ -1600,10 +1628,10 @@ def url_decode(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'a b' """ - return Expr(_f.url_decode(*[a.expr for a in args])) + return Expr(_f.url_decode(str.expr)) -def try_url_decode(*args: Expr) -> Expr: +def try_url_decode(str: Expr) -> Expr: """Spark ``try_url_decode``: like ``url_decode``; returns NULL on invalid input. Examples: @@ -1614,10 +1642,10 @@ def try_url_decode(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'a b' """ - return Expr(_f.try_url_decode(*[a.expr for a in args])) + return Expr(_f.try_url_decode(str.expr)) -def url_encode(*args: Expr) -> Expr: +def url_encode(str: Expr) -> Expr: """Spark ``url_encode``: encode a string in application/x-www-form-urlencoded. Examples: @@ -1628,7 +1656,7 @@ def url_encode(*args: Expr) -> Expr: >>> r.collect_column("v")[0].as_py() 'a+b' """ - return Expr(_f.url_encode(*[a.expr for a in args])) + return Expr(_f.url_encode(str.expr)) __all__ = [ diff --git a/python/tests/test_spark_functions.py b/python/tests/test_spark_functions.py index 0b58a28da..193d21b2c 100644 --- a/python/tests/test_spark_functions.py +++ b/python/tests/test_spark_functions.py @@ -278,6 +278,22 @@ def test_like_escape_raises(): spark.ilike(lit("a"), lit("a"), escapeChar="\\") +def test_parse_url_three_arg(): + """parse_url(url, partToExtract, key=...) 
extracts query parameters.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + url = lit("http://example.com/p?q=hello&n=1") + assert _val(df, spark.parse_url(url, lit("QUERY"), key=lit("q"))) == "hello" + assert _val(df, spark.try_parse_url(url, lit("QUERY"), key=lit("n"))) == "1" + + +def test_format_string_plain_str_format(): + """format_string accepts a plain str format that is auto-promoted to lit.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.format_string("%d-%s", lit(42), lit("hi"))) == "42-hi" + + def test_aggregate_positional_compat(): """Pyspark-style positional calls still work after the rename to ``col``.""" ctx = SessionContext() From e2eceb40c0fca3155cab1009dea58791339cc604 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 09:45:03 -0400 Subject: [PATCH 11/14] docs(skills): cover the new spark function namespace MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `functions.spark` mirrors `pyspark.sql.functions` and now ships on this branch. Update every skill that references the function surface: - skills/datafusion_python/SKILL.md (user-facing): add an import reference, a Core Abstractions row, and a "Spark-Compatible Functions" subsection listing coverage by category, the SQL-vs-DataFrame usage (`enable_spark_functions`), and the divergent-semantics table (concat NULL, round HALF_UP, trunc) so callers know which namespace to pick. - .ai/skills/check-upstream/SKILL.md: new area for the `datafusion-spark` crate with the coverage policy (parity with pyspark, extras allowed when positional pyspark calls still work). Hygiene check also now spans `functions/spark.py`'s `__all__`. - .ai/skills/audit-skill-md/SKILL.md: add `functions.spark` to the surface table and a `spark-functions` scope so this audit also validates the new subsection and divergent-semantics table. 
- .ai/skills/make-pythonic/SKILL.md: explicit scope note that the spark namespace is a deliberate pyspark mirror — generic native-type coercion does not apply there. Path references updated to the new `functions/__init__.py` module layout. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai/skills/audit-skill-md/SKILL.md | 12 +++++-- .ai/skills/check-upstream/SKILL.md | 50 +++++++++++++++++++++++--- .ai/skills/make-pythonic/SKILL.md | 27 ++++++++++++-- skills/datafusion_python/SKILL.md | 57 ++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 11 deletions(-) diff --git a/.ai/skills/audit-skill-md/SKILL.md b/.ai/skills/audit-skill-md/SKILL.md index 30e1a90fd..ba5255a59 100644 --- a/.ai/skills/audit-skill-md/SKILL.md +++ b/.ai/skills/audit-skill-md/SKILL.md @@ -48,7 +48,8 @@ exposed at the package root), include it. | `SessionContext` | `python/datafusion/context.py` | "Data Loading" | | `DataFrame` | `python/datafusion/dataframe.py` | "DataFrame Operations Quick Reference", "Executing and Collecting Results", "Idiomatic Patterns" | | `Expr` | `python/datafusion/expr.py` | "Expression Building", "Common Pitfalls" | -| `functions` | `python/datafusion/functions.py` | "Available Functions (Categorized)", scattered uses throughout | +| `functions` | `python/datafusion/functions/__init__.py` | "Available Functions (Categorized)", scattered uses throughout | +| `functions.spark` | `python/datafusion/functions/spark.py` | "Available Functions (Categorized)" → "Spark-Compatible Functions" subsection | | Top-level helpers (`col`, `lit`, `WindowFrame`, ...) | `python/datafusion/__init__.py` | "Import Conventions", "Core Abstractions" | ## Scope argument @@ -61,7 +62,8 @@ is given or `all` is specified, audit every area. 
| `session-context` | `SessionContext` methods and the "Data Loading" section | | `dataframe` | `DataFrame` methods and the operations / executing / patterns sections | | `expr` | `Expr` methods/operators and the "Expression Building" section | -| `functions` | `functions.py` `__all__` and the "Available Functions (Categorized)" section | +| `functions` | `functions/__init__.py` `__all__` and the "Available Functions (Categorized)" section | +| `spark-functions` | `functions/spark.py` `__all__`, the "Spark-Compatible Functions" subsection, and the divergent-semantics table | | `patterns` | "Idiomatic Patterns" section — confirm patterns still match recommended style | | `pitfalls` | "Common Pitfalls" — confirm each pitfall still reproduces, drop ones fixed upstream | | `version-notes` | Cross-check version annotations (see below) | @@ -123,7 +125,11 @@ For each function name, method name, or import shown in `SKILL.md`, verify it still exists in the current API: - Function names mentioned in prose or in the categorized list should appear - in `python/datafusion/functions.py`'s `__all__`. + in `python/datafusion/functions/__init__.py`'s `__all__`. +- Spark function names mentioned in the "Spark-Compatible Functions" + subsection should appear in `python/datafusion/functions/spark.py`'s + `__all__`. Also confirm the divergent-semantics table still matches the + current spark vs. main signatures. - Method calls in code blocks should resolve against the current class. - Imports (`from datafusion import ...`) should succeed against the current `__init__.py`. diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index 24b4e1bb1..a2ff373df 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -209,18 +209,58 @@ These upstream FFI types have been reviewed and do not need to be independently - FFI example in `examples/datafusion-ffi-example/` - Type appears in union type hints where accepted -### 8. 
`__all__` Hygiene (functions.py) +### 8. Spark-Compatible Functions (`datafusion-spark` crate) + +**Upstream source of truth:** +- Crate source: https://github.com/apache/datafusion/tree/main/datafusion/spark/src +- Rust docs: https://docs.rs/datafusion-spark/latest/datafusion_spark/ + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions/spark.py` — each function wraps + a call to `datafusion._internal.functions.spark`; the public surface is + the module's `__all__` list. +- Rust bindings: `crates/core/src/spark_functions.rs` — `#[pyfunction]` + definitions registered via `init_module()` and re-exported under + `datafusion._internal.functions.spark`. + +**Coverage policy:** The spark namespace mirrors +`pyspark.sql.functions` parameter names and shapes exactly so pyspark +callers can paste code unchanged. Extras over pyspark are permitted as +long as positional pyspark calls still work — for example, the spark +`avg` / `try_sum` / `collect_list` / `collect_set` retain the +`distinct`/`filter`/`order_by`/`null_treatment` kwargs from the main +namespace while pyspark's single-positional form continues to work. + +**How to check:** +1. Fetch the upstream `datafusion-spark` function list from the crate + source under `datafusion/spark/src/function/` (each subdirectory is a + category: `string/`, `math/`, `datetime/`, etc.). The crate's + `function.rs` collects all `ScalarUDF` factories. +2. Cross-reference against `pyspark.sql.functions` for the public-facing + shape — pyspark is the contract this namespace is matching. +3. Compare against the functions listed in + `python/datafusion/functions/spark.py`'s `__all__`. A function is + covered if it exists in the Python `spark` namespace, even if it + aliases another function's Rust binding. +4. Report functions that are missing from the Python spark namespace. + +**Cross-cutting reference:** The longer-form roadmap for spark coverage +lives in `PYSPARK_ALIGNMENT_PLAN.md` (root of repo). 
Use it as the source +of truth for which gaps are intentionally deferred vs. ready to land. + +### 9. `__all__` Hygiene (functions.py and functions/spark.py) Independent of upstream parity, also flag public `def` symbols in -`python/datafusion/functions.py` that are missing from the module's -`__all__`. These are functions a user can call but that do not show up in +`python/datafusion/functions/__init__.py` **and** `python/datafusion/functions/spark.py` +that are missing from that file's `__all__`. These are functions a user +can call but that do not show up in `from datafusion.functions import *`, in tab-completion against the namespace, or in generated API docs — typically an oversight rather than an intentional omission. **How to check:** -1. Grep for `^def ([a-z_][a-z0-9_]*)\(` in `python/datafusion/functions.py` - to enumerate every public function definition. +1. Grep for `^def ([a-z_][a-z0-9_]*)\(` in each file to enumerate every + public function definition. 2. Read the `__all__` list at the top of the same file. 3. Report any function in (1) that is not in (2). Skip private helpers (names starting with `_`). diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 57145ac6c..44bfebdb8 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -29,9 +29,30 @@ You are improving the datafusion-python API to feel more natural to Python users **Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals. +## Scope: `functions` vs `functions.spark` + +This skill targets the **default `datafusion.functions` namespace** (file: +`python/datafusion/functions/__init__.py`). 
Do **not** apply pythonic +coercion to `python/datafusion/functions/spark.py` — that namespace is a +deliberate mirror of `pyspark.sql.functions`, so its parameter names, +order, and types must match pyspark exactly. Adding `Expr | int` style +unions there would diverge from the pyspark contract callers rely on. + +Two exceptions where pythonic-style additions in `functions.spark` are +still on-brand: +- **Pyspark itself accepts a native type.** Pyspark's `format_string` + takes `format: str | Column`; the spark wrapper already auto-promotes a + plain `str` to a literal — keep parity. +- **Strictly additive optional kwargs** that pyspark also has (e.g. + `like(escapeChar=...)`). These belong in the [PYSPARK_ALIGNMENT_PLAN.md] + follow-up PRs, not in a make-pythonic pass. + +If the user explicitly scopes to "spark", validate parity with pyspark +rather than applying generic coercion. + ## How to Identify Candidates -The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`. +The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions/__init__.py`. For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**: @@ -309,7 +330,7 @@ For each function being updated: ### Step 1: Analyze the Function -1. Read the current Python function signature in `python/datafusion/functions.py` +1. Read the current Python function signature in `python/datafusion/functions/__init__.py` 2. Read the Rust binding in `crates/core/src/functions.rs` 3. Optionally check the upstream DataFusion docs for the function 4. 
Determine which category (A, B, or C) applies to each parameter @@ -346,7 +367,7 @@ dfn.functions.left(dfn.col("a"), 3) After making changes, run the doctests to verify: ```bash -python -m pytest --doctest-modules python/datafusion/functions.py -v +python -m pytest --doctest-modules python/datafusion/functions/__init__.py -v ``` ## Coercion Helper Pattern diff --git a/skills/datafusion_python/SKILL.md b/skills/datafusion_python/SKILL.md index 1aeb78777..c2f0a10c3 100644 --- a/skills/datafusion_python/SKILL.md +++ b/skills/datafusion_python/SKILL.md @@ -26,12 +26,14 @@ can interoperate with DataFusion. | `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods | | `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` | | `functions` | 290+ built-in scalar, aggregate, and window functions. | `from datafusion import functions as F` | +| `functions.spark` | PySpark-compatible function surface (parameter names match `pyspark.sql.functions`). | `from datafusion.functions import spark` | ## Import Conventions ```python from datafusion import SessionContext, col, lit from datafusion import functions as F +from datafusion.functions import spark # only when porting pyspark code ``` ## Data Loading @@ -762,3 +764,58 @@ F.left(col("c_phone"), lit(2)) # prefix shortcut **Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, `to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length` + +### Spark-Compatible Functions + +A separate `datafusion.functions.spark` namespace mirrors the +`pyspark.sql.functions` API for callers porting code from PySpark. 
+ +```python +from datafusion.functions import spark +``` + +Use it for DataFrame work; for SQL, register the Spark UDFs first: + +```python +ctx = SessionContext() +ctx.enable_spark_functions() # makes Spark UDFs visible to SQL +ctx.sql("SELECT sha2('hello', 256)").show() +``` + +Coverage spans aggregate (`avg`, `try_sum`, `collect_list`, `collect_set`), +array (`array`, `array_contains`, `array_repeat`, `shuffle`, `slice`, +`size`), bitmap, bitwise (`shiftleft`, `shiftright`, `shiftrightunsigned`, +`bit_get`, `bit_count`, `bitwise_not`), datetime (`add_months`, +`date_add`, `date_sub`, `date_diff`, `date_trunc`, `time_trunc`, `trunc`, +`next_day`, `from_utc_timestamp`, `to_utc_timestamp`, `unix_date`, +`unix_micros`/`millis`/`seconds`, `make_interval`, `make_dt_interval`), +hash (`crc32`, `sha1`, `sha2`, `xxhash64`), JSON (`json_tuple`), +map (`map_from_arrays`, `map_from_entries`, `str_to_map`), math +(`abs`, `ceil`, `floor`, `round`, `expm1`, `factorial`, `hex`, +`modulus`/`pmod`, `rint`, `unhex`, `width_bucket`, `csc`/`sec`, +`negative`, `bin`), string (`ascii`, `base64`/`unbase64`, `char`, +`concat`, `elt`, `like`/`ilike`, `length`, `luhn_check`, `format_string`, +`space`, `substring`, `soundex`, `is_valid_utf8`/`make_valid_utf8`), +URL (`parse_url`/`try_parse_url`, `url_decode`/`url_encode`, +`try_url_decode`), and conditional (`if_`, `spark_cast`). + +The full list is in the API reference; see +`python/datafusion/functions/spark.py`. 
+ +**Semantic divergences vs the default namespace.** Functions that exist in +both `functions` and `functions.spark` may behave differently: + +| Function | Default `functions` | `functions.spark` | +|---|---|---| +| `concat` | NULL inputs treated as empty | NULL inputs propagate to NULL | +| `round` | HALF_EVEN (banker's) | HALF_UP | +| `trunc` | Numeric truncation | Date truncation | +| `substring` | 1-indexed | 1-indexed (parity) | + +Pick the namespace whose semantics match your intent — both stay imported +side by side; `enable_spark_functions()` only affects SQL. + +**Parameter names match pyspark exactly.** The spark namespace uses +pyspark parameter names (`col`, `str`, `numBits`, `partToExtract`, ...) so +you can paste pyspark code and keep keyword arguments working. The default +namespace keeps DataFusion's parameter names. From 43ecf1ab9c2406dc310e9bd96d23f771e14a8b2a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 11:00:46 -0400 Subject: [PATCH 12/14] docs(skills): drop references to PYSPARK_ALIGNMENT_PLAN.md The plan file is a working document, not a committed artifact, so skills must not point at it. Inline the one substantive reference (the "deferred to follow-up PRs" callout in make-pythonic) and drop the cross-cutting pointer from check-upstream. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai/skills/check-upstream/SKILL.md | 4 ---- .ai/skills/make-pythonic/SKILL.md | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index a2ff373df..a3d82a670 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -244,10 +244,6 @@ namespace while pyspark's single-positional form continues to work. aliases another function's Rust binding. 4. Report functions that are missing from the Python spark namespace. 
-**Cross-cutting reference:** The longer-form roadmap for spark coverage -lives in `PYSPARK_ALIGNMENT_PLAN.md` (root of repo). Use it as the source -of truth for which gaps are intentionally deferred vs. ready to land. - ### 9. `__all__` Hygiene (functions.py and functions/spark.py) Independent of upstream parity, also flag public `def` symbols in diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 44bfebdb8..1a2fa9989 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -44,8 +44,8 @@ still on-brand: takes `format: str | Column`; the spark wrapper already auto-promotes a plain `str` to a literal — keep parity. - **Strictly additive optional kwargs** that pyspark also has (e.g. - `like(escapeChar=...)`). These belong in the [PYSPARK_ALIGNMENT_PLAN.md] - follow-up PRs, not in a make-pythonic pass. + `like(escapeChar=...)`). These belong in pyspark-alignment follow-up + PRs, not in a make-pythonic pass. If the user explicitly scopes to "spark", validate parity with pyspark rather than applying generic coercion. From 310df185f5d5757208ce55ad772d81daf639b2c9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 11:02:05 -0400 Subject: [PATCH 13/14] docs(skills): make-pythonic also targets functions.spark MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous guidance said to skip the spark namespace entirely. That was wrong: the spark namespace should also feel pythonic — it just carries the extra constraint that every signature must remain compatible with pyspark.sql.functions (parameter names, positional order, accepted input types). Pythonic widenings like `Expr → Expr | int` are on-brand there because pyspark itself accepts the int form. 
Rewrite the scope section to spell out the compatibility rules (keep parameter names/order; widen input types, never narrow; extra kwargs default to None) and extend "How to Identify Candidates" to include `functions/spark.py`. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 52 ++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 1a2fa9989..7d490ec03 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -31,28 +31,42 @@ You are improving the datafusion-python API to feel more natural to Python users ## Scope: `functions` vs `functions.spark` -This skill targets the **default `datafusion.functions` namespace** (file: -`python/datafusion/functions/__init__.py`). Do **not** apply pythonic -coercion to `python/datafusion/functions/spark.py` — that namespace is a -deliberate mirror of `pyspark.sql.functions`, so its parameter names, -order, and types must match pyspark exactly. Adding `Expr | int` style -unions there would diverge from the pyspark contract callers rely on. - -Two exceptions where pythonic-style additions in `functions.spark` are -still on-brand: -- **Pyspark itself accepts a native type.** Pyspark's `format_string` - takes `format: str | Column`; the spark wrapper already auto-promotes a - plain `str` to a literal — keep parity. -- **Strictly additive optional kwargs** that pyspark also has (e.g. - `like(escapeChar=...)`). These belong in pyspark-alignment follow-up - PRs, not in a make-pythonic pass. - -If the user explicitly scopes to "spark", validate parity with pyspark -rather than applying generic coercion. +Both `python/datafusion/functions/__init__.py` and +`python/datafusion/functions/spark.py` are in scope. 
We want both to feel +pythonic — accept native Python types where the argument is contextually +a literal — but `functions.spark` carries an additional constraint: +**every signature must remain compatible with `pyspark.sql.functions`**. + +Compatibility rules for the spark namespace: + +- **Parameter names must match pyspark exactly.** Pyspark callers pass by + keyword (`spark.shiftleft(col=..., numBits=...)`), so renames break + them. Do NOT rename a parameter just because it would be more pythonic + in the main namespace. +- **Positional order must match pyspark exactly.** Reordering breaks + positional pyspark calls. +- **Type unions may widen the input set, never narrow it.** Pyspark + accepts `Column` or `str` (column name) for most args; we accept + `Expr` already, and widening to `Expr | int` / `Expr | str` for + literal-friendly arguments is on-brand because the int/str case is + exactly what a pyspark caller would also try. Just verify the widened + set is a superset of what pyspark accepts for that arg. +- **Extra keyword arguments are allowed** as long as they default to + `None` and pyspark's positional/keyword form still works (e.g. the + spark `avg`/`try_sum`/`collect_list`/`collect_set` retain DataFusion's + `distinct`/`filter`/`order_by`/`null_treatment` kwargs). + +Practical effect: in `functions.spark`, apply Categories A and (where +pyspark exposes the same arg as a non-`Expr`) B normally, but cross-check +each proposed signature against `pyspark.sql.functions` before landing +it. When pyspark's own type hint is `Column | str` for a "column name" +arg, prefer leaving the spark wrapper at `Expr` — Category C +("`Expr | str` meaning column name") is unusual in `functions/__init__.py` and +should remain so in `functions.spark`. ## How to Identify Candidates -The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions/__init__.py`.
+The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions/__init__.py` **and** `python/datafusion/functions/spark.py`. When updating a spark-namespace function, apply the compatibility rules from "Scope" above on top of the standard analysis. For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**: From d9c8da9d82cff9496a8c613595d05b137f8a20dc Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 May 2026 18:53:14 -0400 Subject: [PATCH 14/14] docs(skill): point at spark __all__ instead of enumerating Enumerating spark functions in the user-facing skill duplicates the __all__ list in python/datafusion/functions/spark.py and will drift the moment a new function lands or is renamed. Replace the per-function listing with a category summary and a discovery snippet that queries the actual __all__ at runtime, which is the authoritative source. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- skills/datafusion_python/SKILL.md | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/skills/datafusion_python/SKILL.md b/skills/datafusion_python/SKILL.md index c2f0a10c3..006c028d1 100644 --- a/skills/datafusion_python/SKILL.md +++ b/skills/datafusion_python/SKILL.md @@ -782,25 +782,18 @@ ctx.enable_spark_functions() # makes Spark UDFs visible to SQL ctx.sql("SELECT sha2('hello', 256)").show() ``` -Coverage spans aggregate (`avg`, `try_sum`, `collect_list`, `collect_set`), -array (`array`, `array_contains`, `array_repeat`, `shuffle`, `slice`, -`size`), bitmap, bitwise (`shiftleft`, `shiftright`, `shiftrightunsigned`, -`bit_get`, `bit_count`, `bitwise_not`), datetime (`add_months`, -`date_add`, `date_sub`, `date_diff`, `date_trunc`, `time_trunc`, `trunc`, -`next_day`, `from_utc_timestamp`, `to_utc_timestamp`, `unix_date`, -`unix_micros`/`millis`/`seconds`, `make_interval`, `make_dt_interval`), -hash (`crc32`, `sha1`, `sha2`, `xxhash64`), JSON (`json_tuple`), -map (`map_from_arrays`, `map_from_entries`, `str_to_map`), math -(`abs`, `ceil`, `floor`, `round`, `expm1`, `factorial`, `hex`, -`modulus`/`pmod`, `rint`, `unhex`, `width_bucket`, `csc`/`sec`, -`negative`, `bin`), string (`ascii`, `base64`/`unbase64`, `char`, -`concat`, `elt`, `like`/`ilike`, `length`, `luhn_check`, `format_string`, -`space`, `substring`, `soundex`, `is_valid_utf8`/`make_valid_utf8`), -URL (`parse_url`/`try_parse_url`, `url_decode`/`url_encode`, -`try_url_decode`), and conditional (`if_`, `spark_cast`). - -The full list is in the API reference; see -`python/datafusion/functions/spark.py`. +Coverage spans aggregate, array, bitmap, bitwise, datetime, hash, JSON, +map, math, string, URL, and conditional categories. 
The authoritative +list of what is currently exposed is the `__all__` in +`python/datafusion/functions/spark.py`: + +```bash +python -c "from datafusion.functions import spark; print(sorted(spark.__all__))" +``` + +When you need to know whether a specific pyspark function is available, +check `__all__` rather than this skill — the list there moves with the +code; any enumeration here would drift. **Semantic divergences vs the default namespace.** Functions that exist in both `functions` and `functions.spark` may behave differently: