diff --git a/.ai/skills/audit-skill-md/SKILL.md b/.ai/skills/audit-skill-md/SKILL.md index 30e1a90fd..ba5255a59 100644 --- a/.ai/skills/audit-skill-md/SKILL.md +++ b/.ai/skills/audit-skill-md/SKILL.md @@ -48,7 +48,8 @@ exposed at the package root), include it. | `SessionContext` | `python/datafusion/context.py` | "Data Loading" | | `DataFrame` | `python/datafusion/dataframe.py` | "DataFrame Operations Quick Reference", "Executing and Collecting Results", "Idiomatic Patterns" | | `Expr` | `python/datafusion/expr.py` | "Expression Building", "Common Pitfalls" | -| `functions` | `python/datafusion/functions.py` | "Available Functions (Categorized)", scattered uses throughout | +| `functions` | `python/datafusion/functions/__init__.py` | "Available Functions (Categorized)", scattered uses throughout | +| `functions.spark` | `python/datafusion/functions/spark.py` | "Available Functions (Categorized)" → "Spark-Compatible Functions" subsection | | Top-level helpers (`col`, `lit`, `WindowFrame`, ...) | `python/datafusion/__init__.py` | "Import Conventions", "Core Abstractions" | ## Scope argument @@ -61,7 +62,8 @@ is given or `all` is specified, audit every area. 
| `session-context` | `SessionContext` methods and the "Data Loading" section | | `dataframe` | `DataFrame` methods and the operations / executing / patterns sections | | `expr` | `Expr` methods/operators and the "Expression Building" section | -| `functions` | `functions.py` `__all__` and the "Available Functions (Categorized)" section | +| `functions` | `functions/__init__.py` `__all__` and the "Available Functions (Categorized)" section | +| `spark-functions` | `functions/spark.py` `__all__`, the "Spark-Compatible Functions" subsection, and the divergent-semantics table | | `patterns` | "Idiomatic Patterns" section — confirm patterns still match recommended style | | `pitfalls` | "Common Pitfalls" — confirm each pitfall still reproduces, drop ones fixed upstream | | `version-notes` | Cross-check version annotations (see below) | @@ -123,7 +125,11 @@ For each function name, method name, or import shown in `SKILL.md`, verify it still exists in the current API: - Function names mentioned in prose or in the categorized list should appear - in `python/datafusion/functions.py`'s `__all__`. + in `python/datafusion/functions/__init__.py`'s `__all__`. +- Spark function names mentioned in the "Spark-Compatible Functions" + subsection should appear in `python/datafusion/functions/spark.py`'s + `__all__`. Also confirm the divergent-semantics table still matches the + current spark vs. main signatures. - Method calls in code blocks should resolve against the current class. - Imports (`from datafusion import ...`) should succeed against the current `__init__.py`. diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index 24b4e1bb1..a3d82a670 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -209,18 +209,54 @@ These upstream FFI types have been reviewed and do not need to be independently - FFI example in `examples/datafusion-ffi-example/` - Type appears in union type hints where accepted -### 8. 
`__all__` Hygiene (functions.py) +### 8. Spark-Compatible Functions (`datafusion-spark` crate) + +**Upstream source of truth:** +- Crate source: https://github.com/apache/datafusion/tree/main/datafusion/spark/src +- Rust docs: https://docs.rs/datafusion-spark/latest/datafusion_spark/ + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions/spark.py` — each function wraps + a call to `datafusion._internal.functions.spark`; the public surface is + the module's `__all__` list. +- Rust bindings: `crates/core/src/spark_functions.rs` — `#[pyfunction]` + definitions registered via `init_module()` and re-exported under + `datafusion._internal.functions.spark`. + +**Coverage policy:** The spark namespace mirrors +`pyspark.sql.functions` parameter names and shapes exactly so pyspark +callers can paste code unchanged. Extras over pyspark are permitted as +long as positional pyspark calls still work — for example, the spark +`avg` / `try_sum` / `collect_list` / `collect_set` retain the +`distinct`/`filter`/`order_by`/`null_treatment` kwargs from the main +namespace while pyspark's single-positional form continues to work. + +**How to check:** +1. Fetch the upstream `datafusion-spark` function list from the crate + source under `datafusion/spark/src/function/` (each subdirectory is a + category: `string/`, `math/`, `datetime/`, etc.). The crate's + `function.rs` collects all `ScalarUDF` factories. +2. Cross-reference against `pyspark.sql.functions` for the public-facing + shape — pyspark is the contract this namespace is matching. +3. Compare against the functions listed in + `python/datafusion/functions/spark.py`'s `__all__`. A function is + covered if it exists in the Python `spark` namespace, even if it + aliases another function's Rust binding. +4. Report functions that are missing from the Python spark namespace. + +### 9. 
`__all__` Hygiene (functions.py and functions/spark.py) Independent of upstream parity, also flag public `def` symbols in -`python/datafusion/functions.py` that are missing from the module's -`__all__`. These are functions a user can call but that do not show up in +`python/datafusion/functions.py` **and** `python/datafusion/functions/spark.py` +that are missing from that file's `__all__`. These are functions a user +can call but that do not show up in `from datafusion.functions import *`, in tab-completion against the namespace, or in generated API docs — typically an oversight rather than an intentional omission. **How to check:** -1. Grep for `^def ([a-z_][a-z0-9_]*)\(` in `python/datafusion/functions.py` - to enumerate every public function definition. +1. Grep for `^def ([a-z_][a-z0-9_]*)\(` in each file to enumerate every + public function definition. 2. Read the `__all__` list at the top of the same file. 3. Report any function in (1) that is not in (2). Skip private helpers (names starting with `_`). diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md index 57145ac6c..7d490ec03 100644 --- a/.ai/skills/make-pythonic/SKILL.md +++ b/.ai/skills/make-pythonic/SKILL.md @@ -29,9 +29,44 @@ You are improving the datafusion-python API to feel more natural to Python users **Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals. +## Scope: `functions` vs `functions.spark` + +Both `python/datafusion/functions/__init__.py` and +`python/datafusion/functions/spark.py` are in scope. We want both to feel +pythonic — accept native Python types where the argument is contextually +a literal — but `functions.spark` carries an additional constraint: +**every signature must remain compatible with `pyspark.sql.functions`**. 
+ +Compatibility rules for the spark namespace: + +- **Parameter names must match pyspark exactly.** Pyspark callers pass by + keyword (`spark.shiftleft(col=..., numBits=...)`), so renames break + them. Do NOT rename a parameter just because it would be more pythonic + in the main namespace. +- **Positional order must match pyspark exactly.** Reordering breaks + positional pyspark calls. +- **Type unions may widen the input set, never narrow it.** Pyspark + accepts `Column` or `str` (column name) for most args; we accept + `Expr` already, and widening to `Expr | int` / `Expr | str` for + literal-friendly arguments is on-brand because the int/str case is + exactly what a pyspark caller would also try. Just verify the widened + set is a superset of what pyspark accepts for that arg. +- **Extra keyword arguments are allowed** as long as they default to + `None` and pyspark's positional/keyword form still works (e.g. the + spark `avg`/`try_sum`/`collect_list`/`collect_set` retain DataFusion's + `distinct`/`filter`/`order_by`/`null_treatment` kwargs). + +Practical effect: in `functions.spark`, apply Categories A and (where +pyspark exposes the same arg as a non-`Expr`) B normally, but cross-check +each proposed signature against `pyspark.sql.functions` before landing +it. When pyspark's own type hint is `Column | str` for a "column name" +arg, prefer leaving the spark wrapper at `Expr` — Category C +("`Expr | str` meaning column name") is unusual in `functions.py` and +should remain so in `functions.spark`. + ## How to Identify Candidates -The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`. +The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions/__init__.py` **and** `python/datafusion/functions/spark.py`. 
When updating a spark-namespace function, apply the compatibility rules from "Scope" above on top of the standard analysis. For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**: @@ -309,7 +344,7 @@ For each function being updated: ### Step 1: Analyze the Function -1. Read the current Python function signature in `python/datafusion/functions.py` +1. Read the current Python function signature in `python/datafusion/functions/__init__.py` 2. Read the Rust binding in `crates/core/src/functions.rs` 3. Optionally check the upstream DataFusion docs for the function 4. Determine which category (A, B, or C) applies to each parameter @@ -346,7 +381,7 @@ dfn.functions.left(dfn.col("a"), 3) After making changes, run the doctests to verify: ```bash -python -m pytest --doctest-modules python/datafusion/functions.py -v +python -m pytest --doctest-modules python/datafusion/functions/__init__.py -v ``` ## Coercion Helper Pattern diff --git a/Cargo.lock b/Cargo.lock index d3cedb628..7f5f63e96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1505,6 +1505,7 @@ dependencies = [ "datafusion-ffi", "datafusion-proto", "datafusion-python-util", + "datafusion-spark", "datafusion-substrait", "futures", "log", @@ -1549,6 +1550,34 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "datafusion-spark" +version = "54.0.0" +source = "git+https://github.com/apache/datafusion?rev=1321d60cc37ee487d1e7ce7f501357c3236b2542#1321d60cc37ee487d1e7ce7f501357c3236b2542" +dependencies = [ + "arrow", + "bigdecimal", + "chrono", + "crc32fast", + "datafusion-catalog", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", + "datafusion-functions-nested", + "log", + "num-traits", + "percent-encoding", + "rand 0.9.4", + "serde_json", + "sha1", + "sha2", + "twox-hash", + "url", +] + [[package]] name = "datafusion-sql" version = 
"54.0.0" @@ -3515,6 +3544,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.11.3", +] + [[package]] name = "sha2" version = "0.11.0" @@ -4009,6 +4049,9 @@ name = "twox-hash" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" +dependencies = [ + "rand 0.9.4", +] [[package]] name = "typenum" diff --git a/Cargo.toml b/Cargo.toml index e72c22368..a8e8559cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ datafusion-catalog = { version = "54", default-features = false } datafusion-common = { version = "54", default-features = false } datafusion-functions-aggregate = { version = "54" } datafusion-functions-window = { version = "54" } +datafusion-spark = { version = "54" } datafusion-expr = { version = "54" } prost = "0.14.3" serde_json = "1" @@ -79,4 +80,5 @@ datafusion-catalog = { git = "https://github.com/apache/datafusion", rev = "1321 datafusion-common = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-functions-window = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } +datafusion-spark = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } datafusion-expr = { git = "https://github.com/apache/datafusion", rev = "1321d60cc37ee487d1e7ce7f501357c3236b2542" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 1f5b4e305..2e8cf6c92 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -53,6 +53,7 @@ 
datafusion = { workspace = true, features = ["avro", "unicode_expressions"] }
 datafusion-substrait = { workspace = true, optional = true }
 datafusion-proto = { workspace = true }
 datafusion-ffi = { workspace = true }
+datafusion-spark = { workspace = true }
 prost = { workspace = true } # keep in line with `datafusion-substrait`
 serde_json = { workspace = true }
 uuid = { workspace = true, features = ["v4"] }
diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs
index d714861a6..16a12b6bc 100644
--- a/crates/core/src/context.rs
+++ b/crates/core/src/context.rs
@@ -1060,6 +1060,21 @@ impl PySessionContext {
         Ok(())
     }
 
+    /// Register all `datafusion-spark` UDFs/UDAFs/UDWFs, overriding any built-in
+    /// DataFusion functions of the same name with their Spark-semantics version.
+    pub fn enable_spark_functions(&self) -> PyResult<()> {
+        for udf in datafusion_spark::all_default_scalar_functions() {
+            self.ctx.register_udf((*udf).clone());
+        }
+        for udaf in datafusion_spark::all_default_aggregate_functions() {
+            self.ctx.register_udaf((*udaf).clone());
+        }
+        for udwf in datafusion_spark::all_default_window_functions() {
+            self.ctx.register_udwf((*udwf).clone());
+        }
+        Ok(())
+    }
+
     pub fn deregister_udaf(&self, name: &str) {
         self.ctx.deregister_udaf(name);
     }
@@ -1562,10 +1577,9 @@ impl PySessionContext {
 pub fn parse_file_compression_type(
     file_compression_type: Option<String>,
 ) -> Result<FileCompressionType, PyErr> {
-    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
-        .map_err(|_| {
-            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
-        })
+    FileCompressionType::from_str(&file_compression_type.unwrap_or_default()).map_err(|_| {
+        PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
+    })
 }
 
 impl From<PySessionContext> for SessionContext {
diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs
index 395d5ebfd..e861994e0 100644
--- a/crates/core/src/functions.rs
+++ b/crates/core/src/functions.rs
@@ -31,7 +31,7 @@ use crate::expr::conditional_expr::PyCaseBuilder;
 use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
 use crate::expr::window::PyWindowFrame;
 
-fn add_builder_fns_to_aggregate(
+pub(crate) fn add_builder_fns_to_aggregate(
     agg_fn: Expr,
     distinct: Option<bool>,
     filter: Option<PyExpr>,
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index 48abcedc9..7f0f9cb39 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -26,22 +26,17 @@ pub use datafusion_substrait;
 use mimalloc::MiMalloc;
 use pyo3::prelude::*;
 
-#[allow(clippy::borrow_deref_ref)]
 pub mod analyzer;
 pub mod catalog;
 pub mod codec;
 pub mod common;
-#[allow(clippy::borrow_deref_ref)]
 pub mod context;
-#[allow(clippy::borrow_deref_ref)]
 pub mod dataframe;
 mod dataset;
 mod dataset_exec;
 pub mod errors;
-#[allow(clippy::borrow_deref_ref)]
 pub mod expr;
-#[allow(clippy::borrow_deref_ref)]
 mod functions;
 pub mod metrics;
 mod options;
@@ -49,6 +44,7 @@ pub mod physical_plan;
 mod pyarrow_filter_expression;
 pub mod pyarrow_util;
 mod record_batch;
+mod spark_functions;
 pub mod sql;
 pub mod store;
 pub mod table;
@@ -57,9 +53,7 @@ pub mod unparser;
 mod array;
 #[cfg(feature = "substrait")]
 pub mod substrait;
-#[allow(clippy::borrow_deref_ref)]
 mod udaf;
-#[allow(clippy::borrow_deref_ref)]
 mod udf;
 pub mod udtf;
 mod udwf;
@@ -124,6 +118,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     // Register the functions as a submodule
     let funcs = PyModule::new(py, "functions")?;
     functions::init_module(&funcs)?;
+    // Spark-compatible functions live under `functions.spark`.
+ let spark_funcs = PyModule::new(py, "spark")?; + spark_functions::init_module(&spark_funcs)?; + funcs.add_submodule(&spark_funcs)?; m.add_submodule(&funcs)?; let store = PyModule::new(py, "object_store")?; diff --git a/crates/core/src/spark_functions.rs b/crates/core/src/spark_functions.rs new file mode 100644 index 000000000..e7cb94f8c --- /dev/null +++ b/crates/core/src/spark_functions.rs @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! PyO3 wrappers for the [`datafusion-spark`] crate. +//! +//! Exposes Spark-compatible scalar and aggregate function builders for use +//! from Python under `datafusion.functions.spark`. + +use datafusion::logical_expr::Expr; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion_spark::{expr_fn, function as udf}; +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; + +use crate::common::data_type::NullTreatment; +use crate::errors::PyDataFusionResult; +use crate::expr::PyExpr; +use crate::expr::sort_expr::PySortExpr; +use crate::functions::add_builder_fns_to_aggregate; + +/// Generates a [pyo3] wrapper for [datafusion_spark::expr_fn]. +/// +/// These functions have explicit named arguments and mirror the upstream +/// `expr_fn::$FUNC` signature. 
+macro_rules! spark_expr_fn {
+    ($FUNC:ident) => {
+        spark_expr_fn!($FUNC,);
+    };
+    ($FUNC:ident, $($arg:ident)*) => {
+        #[pyfunction]
+        fn $FUNC($($arg: PyExpr),*) -> PyExpr {
+            expr_fn::$FUNC($($arg.into()),*).into()
+        }
+    };
+}
+
+/// Generates a variadic [pyo3] wrapper that calls the [`ScalarUDF`] factory
+/// directly. Required for functions whose upstream `expr_fn` wrapper accepts
+/// a single `Expr` instead of `Vec<Expr>` (an upstream `export_functions!`
+/// macro-arm quirk), so we bypass it to get true Python `*args` semantics.
+macro_rules! spark_udf_vec {
+    ($PY_NAME:ident, $UDF_PATH:path) => {
+        #[pyfunction]
+        #[pyo3(signature = (*args))]
+        fn $PY_NAME(args: Vec<PyExpr>) -> PyExpr {
+            let udf = $UDF_PATH();
+            let args: Vec<Expr> = args.into_iter().map(Into::into).collect();
+            Expr::ScalarFunction(ScalarFunction::new_udf(udf, args)).into()
+        }
+    };
+}
+
+/// Generates a [pyo3] wrapper for Spark aggregate functions. Mirrors
+/// [`crate::functions::aggregate_function`] but points at
+/// [`datafusion_spark::expr_fn`].
+macro_rules! spark_aggregate {
+    ($NAME:ident) => {
+        spark_aggregate!($NAME, expr);
+    };
+    ($NAME:ident, $($arg:ident)*) => {
+        #[pyfunction]
+        #[pyo3(signature = ($($arg),*, distinct=None, filter=None, order_by=None, null_treatment=None))]
+        fn $NAME(
+            $($arg: PyExpr),*,
+            distinct: Option<bool>,
+            filter: Option<PyExpr>,
+            order_by: Option<Vec<PySortExpr>>,
+            null_treatment: Option<NullTreatment>,
+        ) -> PyDataFusionResult<PyExpr> {
+            let agg_fn = expr_fn::$NAME($($arg.into()),*);
+            add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
+        }
+    };
+}
+
+// ---------------------------------------------------------------------------
+// Aggregate functions
+// ---------------------------------------------------------------------------
+
+spark_aggregate!(avg, arg1);
+spark_aggregate!(try_sum, arg1);
+spark_aggregate!(collect_list, arg1);
+spark_aggregate!(collect_set, arg1);
+
+// ---------------------------------------------------------------------------
+// Array functions
+// ---------------------------------------------------------------------------
+
+// Upstream factory is `spark_array_contains`; expose under the Spark SQL
+// name `array_contains` on the Python side.
+#[pyfunction] +fn array_contains(arr: PyExpr, element: PyExpr) -> PyExpr { + expr_fn::spark_array_contains(arr.into(), element.into()).into() +} +spark_udf_vec!(array, udf::array::array); +spark_expr_fn!(shuffle, arg1); +spark_expr_fn!(array_repeat, element count); +spark_expr_fn!(slice, arr start length); + +// --------------------------------------------------------------------------- +// Bitmap functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(bitmap_count, arg1); +spark_expr_fn!(bitmap_bit_position, arg1); +spark_expr_fn!(bitmap_bucket_number, arg1); + +// --------------------------------------------------------------------------- +// Bitwise functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(bit_get, col pos); +spark_expr_fn!(bit_count, col); +spark_expr_fn!(bitwise_not, col); +spark_expr_fn!(shiftleft, value shift); +spark_expr_fn!(shiftright, value shift); +spark_expr_fn!(shiftrightunsigned, value shift); + +// --------------------------------------------------------------------------- +// Collection / Conditional / Conversion +// --------------------------------------------------------------------------- + +spark_expr_fn!(size, arg1); + +// Python keyword `if` → exposed as `if_`. Upstream Rust ident is `r#if`. +#[pyfunction] +fn if_(condition: PyExpr, if_true: PyExpr, if_false: PyExpr) -> PyExpr { + expr_fn::r#if(condition.into(), if_true.into(), if_false.into()).into() +} + +// `spark_cast` is config-injected by the upstream `expr_fn` helper; defaults +// applied automatically there. 
+spark_expr_fn!(spark_cast, arg1 arg2); + +// --------------------------------------------------------------------------- +// Datetime functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(add_months, start_date num_months); +spark_expr_fn!(date_add, start_date days); +spark_expr_fn!(date_sub, start_date days); +spark_expr_fn!(hour, arg1); +spark_expr_fn!(minute, arg1); +spark_expr_fn!(second, arg1); +spark_expr_fn!(last_day, arg1); +spark_expr_fn!(make_dt_interval, days hours mins secs); +spark_expr_fn!(make_interval, years months weeks days hours mins secs); +spark_expr_fn!(next_day, start_date day_of_week); +spark_expr_fn!(date_diff, end_date start_date); +spark_expr_fn!(date_trunc, fmt ts); +spark_expr_fn!(time_trunc, fmt t); +spark_expr_fn!(trunc, dt fmt); +spark_expr_fn!(date_part, field source); +spark_expr_fn!(from_utc_timestamp, ts tz); +spark_expr_fn!(to_utc_timestamp, ts tz); +spark_expr_fn!(unix_date, dt); +spark_expr_fn!(unix_micros, ts); +spark_expr_fn!(unix_millis, ts); +spark_expr_fn!(unix_seconds, ts); + +// --------------------------------------------------------------------------- +// Hash functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(crc32, arg1); +spark_expr_fn!(sha1, arg1); +spark_expr_fn!(sha2, arg1 bit_length); +spark_udf_vec!(xxhash64, udf::hash::xxhash64); + +// --------------------------------------------------------------------------- +// JSON functions +// --------------------------------------------------------------------------- + +spark_udf_vec!(json_tuple, udf::json::json_tuple); + +// --------------------------------------------------------------------------- +// Map functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(map_from_arrays, keys values); +spark_expr_fn!(map_from_entries, arg1); +spark_expr_fn!(str_to_map, text pair_delim key_value_delim); + +// 
--------------------------------------------------------------------------- +// Math functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(abs, arg1); +spark_expr_fn!(ceil, arg1); +spark_expr_fn!(expm1, arg1); +spark_expr_fn!(factorial, arg1); +spark_expr_fn!(floor, arg1); +spark_expr_fn!(hex, arg1); +spark_expr_fn!(modulus, dividend divisor); +spark_expr_fn!(pmod, dividend divisor); +spark_expr_fn!(rint, arg1); +spark_expr_fn!(round, value scale); +spark_expr_fn!(unhex, arg1); +spark_expr_fn!(width_bucket, value min_value max_value num_buckets); +spark_expr_fn!(csc, arg1); +spark_expr_fn!(sec, arg1); +spark_expr_fn!(negative, arg1); +spark_expr_fn!(bin, arg1); + +// --------------------------------------------------------------------------- +// String functions +// --------------------------------------------------------------------------- + +spark_expr_fn!(ascii, arg1); +spark_expr_fn!(base64, bin_input); +// `char` collides with the Rust primitive type in macro hygiene; rename the +// Rust ident and re-expose under the original name to Python. 
+#[pyfunction] +#[pyo3(name = "char")] +fn char_fn(arg1: PyExpr) -> PyExpr { + expr_fn::char(arg1.into()).into() +} +spark_udf_vec!(concat, udf::string::concat); +spark_udf_vec!(elt, udf::string::elt); +spark_expr_fn!(ilike, str pattern); +spark_expr_fn!(length, arg1); +spark_expr_fn!(like, str pattern); +spark_expr_fn!(luhn_check, arg1); +spark_udf_vec!(format_string, udf::string::format_string); +spark_expr_fn!(space, arg1); +spark_expr_fn!(substring, str pos length); +spark_expr_fn!(unbase64, str); +spark_expr_fn!(soundex, str); +spark_expr_fn!(is_valid_utf8, str); +spark_expr_fn!(make_valid_utf8, str); + +// --------------------------------------------------------------------------- +// URL functions +// --------------------------------------------------------------------------- + +spark_udf_vec!(parse_url, udf::url::parse_url); +spark_udf_vec!(try_parse_url, udf::url::try_parse_url); +spark_udf_vec!(url_decode, udf::url::url_decode); +spark_udf_vec!(try_url_decode, udf::url::try_url_decode); +spark_udf_vec!(url_encode, udf::url::url_encode); + +// --------------------------------------------------------------------------- +// Module init +// --------------------------------------------------------------------------- + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + // Aggregate + m.add_wrapped(wrap_pyfunction!(avg))?; + m.add_wrapped(wrap_pyfunction!(try_sum))?; + m.add_wrapped(wrap_pyfunction!(collect_list))?; + m.add_wrapped(wrap_pyfunction!(collect_set))?; + // Array + m.add_wrapped(wrap_pyfunction!(array_contains))?; + m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(shuffle))?; + m.add_wrapped(wrap_pyfunction!(array_repeat))?; + m.add_wrapped(wrap_pyfunction!(slice))?; + // Bitmap + m.add_wrapped(wrap_pyfunction!(bitmap_count))?; + m.add_wrapped(wrap_pyfunction!(bitmap_bit_position))?; + m.add_wrapped(wrap_pyfunction!(bitmap_bucket_number))?; + // Bitwise + m.add_wrapped(wrap_pyfunction!(bit_get))?; + 
m.add_wrapped(wrap_pyfunction!(bit_count))?; + m.add_wrapped(wrap_pyfunction!(bitwise_not))?; + m.add_wrapped(wrap_pyfunction!(shiftleft))?; + m.add_wrapped(wrap_pyfunction!(shiftright))?; + m.add_wrapped(wrap_pyfunction!(shiftrightunsigned))?; + // Collection + m.add_wrapped(wrap_pyfunction!(size))?; + // Conditional + m.add_wrapped(wrap_pyfunction!(if_))?; + // Conversion + m.add_wrapped(wrap_pyfunction!(spark_cast))?; + // Datetime + m.add_wrapped(wrap_pyfunction!(add_months))?; + m.add_wrapped(wrap_pyfunction!(date_add))?; + m.add_wrapped(wrap_pyfunction!(date_sub))?; + m.add_wrapped(wrap_pyfunction!(hour))?; + m.add_wrapped(wrap_pyfunction!(minute))?; + m.add_wrapped(wrap_pyfunction!(second))?; + m.add_wrapped(wrap_pyfunction!(last_day))?; + m.add_wrapped(wrap_pyfunction!(make_dt_interval))?; + m.add_wrapped(wrap_pyfunction!(make_interval))?; + m.add_wrapped(wrap_pyfunction!(next_day))?; + m.add_wrapped(wrap_pyfunction!(date_diff))?; + m.add_wrapped(wrap_pyfunction!(date_trunc))?; + m.add_wrapped(wrap_pyfunction!(time_trunc))?; + m.add_wrapped(wrap_pyfunction!(trunc))?; + m.add_wrapped(wrap_pyfunction!(date_part))?; + m.add_wrapped(wrap_pyfunction!(from_utc_timestamp))?; + m.add_wrapped(wrap_pyfunction!(to_utc_timestamp))?; + m.add_wrapped(wrap_pyfunction!(unix_date))?; + m.add_wrapped(wrap_pyfunction!(unix_micros))?; + m.add_wrapped(wrap_pyfunction!(unix_millis))?; + m.add_wrapped(wrap_pyfunction!(unix_seconds))?; + // Hash + m.add_wrapped(wrap_pyfunction!(crc32))?; + m.add_wrapped(wrap_pyfunction!(sha1))?; + m.add_wrapped(wrap_pyfunction!(sha2))?; + m.add_wrapped(wrap_pyfunction!(xxhash64))?; + // JSON + m.add_wrapped(wrap_pyfunction!(json_tuple))?; + // Map + m.add_wrapped(wrap_pyfunction!(map_from_arrays))?; + m.add_wrapped(wrap_pyfunction!(map_from_entries))?; + m.add_wrapped(wrap_pyfunction!(str_to_map))?; + // Math + m.add_wrapped(wrap_pyfunction!(abs))?; + m.add_wrapped(wrap_pyfunction!(ceil))?; + m.add_wrapped(wrap_pyfunction!(expm1))?; + 
m.add_wrapped(wrap_pyfunction!(factorial))?; + m.add_wrapped(wrap_pyfunction!(floor))?; + m.add_wrapped(wrap_pyfunction!(hex))?; + m.add_wrapped(wrap_pyfunction!(modulus))?; + m.add_wrapped(wrap_pyfunction!(pmod))?; + m.add_wrapped(wrap_pyfunction!(rint))?; + m.add_wrapped(wrap_pyfunction!(round))?; + m.add_wrapped(wrap_pyfunction!(unhex))?; + m.add_wrapped(wrap_pyfunction!(width_bucket))?; + m.add_wrapped(wrap_pyfunction!(csc))?; + m.add_wrapped(wrap_pyfunction!(sec))?; + m.add_wrapped(wrap_pyfunction!(negative))?; + m.add_wrapped(wrap_pyfunction!(bin))?; + // String + m.add_wrapped(wrap_pyfunction!(ascii))?; + m.add_wrapped(wrap_pyfunction!(base64))?; + m.add_wrapped(wrap_pyfunction!(char_fn))?; + m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(elt))?; + m.add_wrapped(wrap_pyfunction!(ilike))?; + m.add_wrapped(wrap_pyfunction!(length))?; + m.add_wrapped(wrap_pyfunction!(like))?; + m.add_wrapped(wrap_pyfunction!(luhn_check))?; + m.add_wrapped(wrap_pyfunction!(format_string))?; + m.add_wrapped(wrap_pyfunction!(space))?; + m.add_wrapped(wrap_pyfunction!(substring))?; + m.add_wrapped(wrap_pyfunction!(unbase64))?; + m.add_wrapped(wrap_pyfunction!(soundex))?; + m.add_wrapped(wrap_pyfunction!(is_valid_utf8))?; + m.add_wrapped(wrap_pyfunction!(make_valid_utf8))?; + // URL + m.add_wrapped(wrap_pyfunction!(parse_url))?; + m.add_wrapped(wrap_pyfunction!(try_parse_url))?; + m.add_wrapped(wrap_pyfunction!(url_decode))?; + m.add_wrapped(wrap_pyfunction!(try_url_decode))?; + m.add_wrapped(wrap_pyfunction!(url_encode))?; + Ok(()) +} diff --git a/docs/source/user-guide/common-operations/functions.rst b/docs/source/user-guide/common-operations/functions.rst index ccb47a4e7..f656f8af7 100644 --- a/docs/source/user-guide/common-operations/functions.rst +++ b/docs/source/user-guide/common-operations/functions.rst @@ -21,6 +21,12 @@ Functions DataFusion provides a large number of built-in functions for performing complex queries without requiring 
user-defined functions. In here we will cover some of the more popular use cases. If you want to view all the functions go to the :py:mod:`Functions <datafusion.functions>` API Reference.
 
+.. note::
+
+    For Apache Spark-compatible versions of these functions (with Spark
+    NULL-propagation, 1-indexed substrings, HALF_UP rounding, etc.), see
+    :doc:`spark-functions`.
+
 We'll use the pokemon dataset in the following examples.
 
 .. ipython:: python
diff --git a/docs/source/user-guide/common-operations/index.rst b/docs/source/user-guide/common-operations/index.rst
index 7abd1f138..8b65c6de6 100644
--- a/docs/source/user-guide/common-operations/index.rst
+++ b/docs/source/user-guide/common-operations/index.rst
@@ -29,6 +29,7 @@ The contents of this section are designed to guide a new user through how to use
    expressions
    joins
    functions
+   spark-functions
    aggregations
    windows
    udf-and-udfa
diff --git a/docs/source/user-guide/common-operations/spark-functions.rst b/docs/source/user-guide/common-operations/spark-functions.rst
new file mode 100644
index 000000000..cb056bfc5
--- /dev/null
+++ b/docs/source/user-guide/common-operations/spark-functions.rst
@@ -0,0 +1,92 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+ +Spark-Compatible Functions +========================== + +DataFusion ships Spark-compatible versions of a wide set of functions +(string, math, datetime, hash, array, aggregate) through the upstream +``datafusion-spark`` crate. ``datafusion-python`` exposes these under +``datafusion.functions.spark`` for use from the DataFrame API, and via +:py:meth:`~datafusion.SessionContext.enable_spark_functions` for use from +SQL. + +Why a Separate Namespace? +------------------------- + +Several Spark functions share names with DataFusion built-ins but differ in +semantics. The most common divergences: + +- ``concat`` propagates NULL. ``concat('a', NULL, 'b')`` returns NULL under + Spark semantics, whereas the DataFusion default returns ``'ab'``. +- ``substring`` is 1-indexed and supports negative positions counting from + the end of the string. +- ``round`` uses HALF_UP rounding mode (``round(2.5, 0) == 3``). +- Numeric functions (``floor``, ``ceil``, ``mod``) follow Spark's edge-case + handling for negative values and decimals. + +Enabling Spark functions does not affect the DataFrame API: you choose which +implementation to call by which module you import from. + +DataFrame API +------------- + +Import ``spark`` and use it like any other functions module. The Spark +functions can go anywhere you'd put a DataFusion expression — inside +``select``, ``filter``, ``with_column``, ``aggregate``, and so on. + +.. code-block:: python + + from datafusion import SessionContext, col, lit + from datafusion.functions import spark + + ctx = SessionContext() + df = ctx.from_pydict({"s": ["hello", "world"]}) + + # SHA-256 hash with Spark semantics + df.select(spark.sha2(col("s"), lit(256)).alias("h")).show() + + # 1-indexed substring + df.select(spark.substring(col("s"), lit(1), lit(3)).alias("p")).show() + +SQL +--- + +To use Spark functions in SQL queries, call +:py:meth:`~datafusion.SessionContext.enable_spark_functions` on the context. 
+This registers every Spark UDF/UDAF/UDWF, overriding any DataFusion built-in +of the same name. + +.. code-block:: python + + from datafusion import SessionContext + + ctx = SessionContext() + ctx.enable_spark_functions() + + ctx.sql("SELECT sha2('hello', 256)").show() + ctx.sql("SELECT concat('a', NULL, 'b')").show() # -> NULL, not 'ab' + +The override applies for the lifetime of the session. To call DataFusion's +built-in versions afterwards, create a fresh ``SessionContext``. + +Function Reference +------------------ + +The full, up-to-date list of available Spark functions — with signatures +and per-function docstrings — lives in the +:py:mod:`datafusion.functions.spark` API reference. diff --git a/pyproject.toml b/pyproject.toml index 2b6a976db..d33d38e7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] [tool.codespell] skip = [ "*/tests/test_functions.py", + "*/tests/test_spark_functions.py", "*/target", "./uv.lock", "uv.lock", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 52bd600c3..71b78c0c8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1297,6 +1297,28 @@ def register_udaf(self, udaf: AggregateUDF) -> None: """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf._udaf) + def enable_spark_functions(self) -> None: + """Register all Spark-compatible functions for SQL access. + + Registers every UDF/UDAF/UDWF from the ``datafusion-spark`` crate, + overriding any DataFusion built-ins of the same name with their + Spark-semantics version (e.g. ``substring`` becomes 1-indexed, + ``concat`` propagates NULL, ``round`` uses HALF_UP rounding). + + For DataFrame use, import the typed wrappers from + :py:mod:`datafusion.functions.spark` directly; this method is only + needed for SQL queries. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> ctx.enable_spark_functions() + >>> ctx.sql( + ... "SELECT sha2('hello', 256) AS h" + ... ).collect_column("h")[0].as_py() + '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824' + """ + self.ctx.enable_spark_functions() + def deregister_udaf(self, name: str) -> None: """Remove a user-defined aggregate function from the session. diff --git a/python/datafusion/functions.py b/python/datafusion/functions/__init__.py similarity index 99% rename from python/datafusion/functions.py rename to python/datafusion/functions/__init__.py index c8f07497d..2adb56fe4 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions/__init__.py @@ -59,6 +59,7 @@ sort_list_to_raw_sort_list, sort_or_default, ) +from datafusion.functions import spark __all__ = [ "abs", @@ -325,6 +326,7 @@ "signum", "sin", "sinh", + "spark", "split_part", "sqrt", "starts_with", diff --git a/python/datafusion/functions/spark.py b/python/datafusion/functions/spark.py new file mode 100644 index 000000000..4c7127344 --- /dev/null +++ b/python/datafusion/functions/spark.py @@ -0,0 +1,1762 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Spark-compatible function bindings. 
+ +These functions mirror the semantics of their Apache Spark counterparts +exactly. Some override DataFusion built-ins (``substring`` is 1-indexed, +``concat`` propagates NULL, ``round`` uses HALF_UP rounding, etc.), which is +why they live in a separate namespace rather than replacing the defaults. + +For DataFrame use, import this module and call functions directly. For SQL +use, call :py:meth:`datafusion.SessionContext.enable_spark_functions` to +register the Spark UDFs by name (overriding any built-ins with matching +names) before issuing SQL queries. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyarrow as pa + +from datafusion._internal import functions as _functions +from datafusion.expr import Expr, sort_list_to_raw_sort_list + +if TYPE_CHECKING: + from datafusion.common import NullTreatment + from datafusion.expr import SortKey + +_f = _functions.spark + +# Reused int32 literal so optional-arg defaults don't rebuild it per call. +_ZERO_I32 = Expr.literal(pa.scalar(0, type=pa.int32())) + + +def _filter_raw(filter: Expr | None) -> Any: + return filter.expr if filter is not None else None + + +# --------------------------------------------------------------------------- +# Aggregate functions +# --------------------------------------------------------------------------- + + +def avg( + col: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``avg``: returns the mean of a numeric column. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + >>> r = df.aggregate( + ... 
[], [dfn.functions.spark.avg(dfn.col("a")).alias("v")]) + >>> r.collect_column("v")[0].as_py() + 2.0 + """ + return Expr( + _f.avg( + col.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def try_sum( + col: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``try_sum``: sum that returns NULL on overflow. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.try_sum(dfn.col("a")).alias("v")]) + >>> r.collect_column("v")[0].as_py() + 6 + """ + return Expr( + _f.try_sum( + col.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def collect_list( + col: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``collect_list``: collect values into an array (preserves dups). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 2]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.collect_list(dfn.col("a")).alias("v")]) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 2] + """ + return Expr( + _f.collect_list( + col.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +def collect_set( + col: Expr, + distinct: bool | None = None, + filter: Expr | None = None, + order_by: list[SortKey] | SortKey | None = None, + null_treatment: NullTreatment | None = None, +) -> Expr: + """Spark ``collect_set``: collect distinct values into an array. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 2, 3]}) + >>> r = df.aggregate( + ... [], [dfn.functions.spark.collect_set(dfn.col("a")).alias("v")]) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 3] + """ + return Expr( + _f.collect_set( + col.expr, + distinct=distinct, + filter=_filter_raw(filter), + order_by=sort_list_to_raw_sort_list(order_by), + null_treatment=null_treatment, + ) + ) + + +# --------------------------------------------------------------------------- +# Array functions +# --------------------------------------------------------------------------- + + +def array_contains(col: Expr, value: Expr) -> Expr: + """Spark ``array_contains``: true if the array contains the element. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array_contains( + ... dfn.functions.spark.array(dfn.lit(1), dfn.lit(2)), + ... dfn.lit(1), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.array_contains(col.expr, value.expr)) + + +def array(*cols: Expr) -> Expr: + """Spark ``array``: builds an array from the given elements. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array( + ... dfn.lit(1), dfn.lit(2), dfn.lit(3) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [1, 2, 3] + """ + return Expr(_f.array(*[c.expr for c in cols])) + + +def shuffle(col: Expr, seed: int | None = None) -> Expr: + """Spark ``shuffle``: returns a random permutation of the input array. + + ``seed`` is accepted for pyspark parity but is not yet wired through the + Rust binding; passing a non-``None`` value raises ``NotImplementedError``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shuffle( + ... 
dfn.functions.spark.array(dfn.lit(1), dfn.lit(2), dfn.lit(3)) + ... ).alias("v") + ... ) + >>> sorted(r.collect_column("v")[0].as_py()) + [1, 2, 3] + """ + if seed is not None: + msg = "shuffle(seed=...) is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) + return Expr(_f.shuffle(col.expr)) + + +def array_repeat(col: Expr, count: Expr) -> Expr: + """Spark ``array_repeat``: array of ``element`` repeated ``count`` times. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.array_repeat(dfn.lit("a"), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + ['a', 'a', 'a'] + """ + return Expr(_f.array_repeat(col.expr, count.expr)) + + +def slice(x: Expr, start: Expr, length: Expr) -> Expr: + """Spark ``slice``: subset of the array from 1-indexed ``start`` with ``length``. + + Negative ``start`` counts from the end. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.slice( + ... dfn.functions.spark.array( + ... dfn.lit(1), dfn.lit(2), dfn.lit(3), dfn.lit(4)), + ... dfn.lit(2), dfn.lit(2), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [2, 3] + """ + return Expr(_f.slice(x.expr, start.expr, length.expr)) + + +# --------------------------------------------------------------------------- +# Bitmap functions +# --------------------------------------------------------------------------- + + +def bitmap_count(col: Expr) -> Expr: + r"""Spark ``bitmap_count``: number of set bits in a bitmap. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.bitmap_count(dfn.lit(b"\xff")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 8 + """ + return Expr(_f.bitmap_count(col.expr)) + + +def bitmap_bit_position(col: Expr) -> Expr: + """Spark ``bitmap_bit_position``: bit position for a child expression. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bitmap_bit_position(dfn.lit(15)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 14 + """ + return Expr(_f.bitmap_bit_position(col.expr)) + + +def bitmap_bucket_number(col: Expr) -> Expr: + """Spark ``bitmap_bucket_number``: bucket number for a child expression. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bitmap_bucket_number(dfn.lit(15)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.bitmap_bucket_number(col.expr)) + + +# --------------------------------------------------------------------------- +# Bitwise functions +# --------------------------------------------------------------------------- + + +def bit_get(col: Expr, pos: Expr) -> Expr: + """Spark ``bit_get``: returns the bit (0 or 1) at ``pos``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.bit_get(dfn.lit(5), dfn.lit(0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.bit_get(col.expr, pos.expr)) + + +def bit_count(col: Expr) -> Expr: + """Spark ``bit_count``: number of bits set in the integer's binary form. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bit_count(dfn.lit(7)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr(_f.bit_count(col.expr)) + + +def bitwise_not(col: Expr) -> Expr: + """Spark ``~``: bitwise NOT. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bitwise_not(dfn.lit(0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + -1 + """ + return Expr(_f.bitwise_not(col.expr)) + + +def shiftleft(col: Expr, numBits: Expr) -> Expr: # noqa: N803 + """Spark ``shiftleft``: ``col`` shifted left by ``numBits`` bits. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftleft(dfn.lit(1), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 8 + """ + return Expr(_f.shiftleft(col.expr, numBits.expr)) + + +def shiftright(col: Expr, numBits: Expr) -> Expr: # noqa: N803 + """Spark ``shiftright``: arithmetic right shift. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftright(dfn.lit(8), dfn.lit(2)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.shiftright(col.expr, numBits.expr)) + + +def shiftrightunsigned(col: Expr, numBits: Expr) -> Expr: # noqa: N803 + """Spark ``shiftrightunsigned``: logical (unsigned) right shift. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.shiftrightunsigned( + ... dfn.lit(8), dfn.lit(2) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.shiftrightunsigned(col.expr, numBits.expr)) + + +# --------------------------------------------------------------------------- +# Collection / Conditional / Conversion +# --------------------------------------------------------------------------- + + +def size(col: Expr) -> Expr: + """Spark ``size``: length of an array or map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.size( + ...
dfn.functions.spark.array(dfn.lit(1), dfn.lit(2), dfn.lit(3)) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr(_f.size(col.expr)) + + +def if_(condition: Expr, if_true: Expr, if_false: Expr) -> Expr: + """Spark ``if``: returns ``if_true`` when ``condition`` is true, else ``if_false``. + + Exposed as ``if_`` because ``if`` is a Python keyword. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2]}) + >>> r = df.select( + ... dfn.functions.spark.if_( + ... dfn.col("a") > dfn.lit(1), dfn.lit("big"), dfn.lit("small") + ... ).alias("v") + ... ) + >>> r.collect_column("v").to_pylist() + ['small', 'big'] + """ + return Expr(_f.if_(condition.expr, if_true.expr, if_false.expr)) + + +def spark_cast(arg: Expr, type_str: Expr) -> Expr: + """Spark ``cast``: cast ``arg`` to the type named by ``type_str``. + + Uses Spark cast semantics (e.g. overflow returns NULL, not error). + + Currently only supports casting numeric values to ``"timestamp"``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.spark_cast( + ... dfn.lit(1579098645), dfn.lit("timestamp") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45, tzinfo=) + """ + return Expr(_f.spark_cast(arg.expr, type_str.expr)) + + +# --------------------------------------------------------------------------- +# Datetime functions +# --------------------------------------------------------------------------- + + +def add_months(start: Expr, months: Expr) -> Expr: + """Spark ``add_months``: date + N months. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.add_months( + ... d, dfn.lit(pa.scalar(2, type=pa.int32())) + ... 
).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 3, 15) + """ + return Expr(_f.add_months(start.expr, months.expr)) + + +def date_add(start: Expr, days: Expr) -> Expr: + """Spark ``date_add``: date + N days. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.date_add( + ... d, dfn.lit(pa.scalar(5, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 20) + """ + return Expr(_f.date_add(start.expr, days.expr)) + + +def date_sub(start: Expr, days: Expr) -> Expr: + """Spark ``date_sub``: date - N days. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.date_sub( + ... d, dfn.lit(pa.scalar(5, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 10) + """ + return Expr(_f.date_sub(start.expr, days.expr)) + + +def hour(col: Expr) -> Expr: + """Spark ``hour``: extract hour component of a timestamp. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.hour(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 14 + """ + return Expr(_f.hour(col.expr)) + + +def minute(col: Expr) -> Expr: + """Spark ``minute``: extract minute component of a timestamp. 
+ + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.minute(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 30 + """ + return Expr(_f.minute(col.expr)) + + +def second(col: Expr) -> Expr: + """Spark ``second``: extract second component of a timestamp. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.second(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 45 + """ + return Expr(_f.second(col.expr)) + + +def last_day(col: Expr) -> Expr: + """Spark ``last_day``: last day of the month containing the date. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.last_day(d).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 31) + """ + return Expr(_f.last_day(col.expr)) + + +def make_dt_interval( + days: Expr | None = None, + hours: Expr | None = None, + mins: Expr | None = None, + secs: Expr | None = None, +) -> Expr: + """Spark ``make_dt_interval``: day-time interval from components. + + All parts are optional; omitted parts default to zero, matching pyspark. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.make_dt_interval().alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.timedelta(0) + + >>> import pyarrow as pa + >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) + >>> r = df.select( + ... dfn.functions.spark.make_dt_interval( + ... days=i32(1), hours=i32(2), mins=i32(3), secs=dfn.lit(4.5) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.timedelta(days=1, seconds=7384, microseconds=500000) + """ + return Expr( + _f.make_dt_interval( + (days if days is not None else _ZERO_I32).expr, + (hours if hours is not None else _ZERO_I32).expr, + (mins if mins is not None else _ZERO_I32).expr, + (secs if secs is not None else Expr.literal(0.0)).expr, + ) + ) + + +def make_interval( + years: Expr | None = None, + months: Expr | None = None, + weeks: Expr | None = None, + days: Expr | None = None, + hours: Expr | None = None, + mins: Expr | None = None, + secs: Expr | None = None, +) -> Expr: + """Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts. + + All parts are optional; omitted parts default to zero, matching pyspark. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.make_interval().alias("v")) + >>> r.collect_column("v")[0].as_py().months + 0 + + >>> import pyarrow as pa + >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32())) + >>> r = df.select(dfn.functions.spark.make_interval(years=i32(1)).alias("v")) + >>> r.collect_column("v")[0].as_py().months + 12 + """ + return Expr( + _f.make_interval( + (years if years is not None else _ZERO_I32).expr, + (months if months is not None else _ZERO_I32).expr, + (weeks if weeks is not None else _ZERO_I32).expr, + (days if days is not None else _ZERO_I32).expr, + (hours if hours is not None else _ZERO_I32).expr, + (mins if mins is not None else _ZERO_I32).expr, + (secs if secs is not None else Expr.literal(0.0)).expr, + ) + ) + + +def next_day(date: Expr, dayOfWeek: Expr) -> Expr: # noqa: N803 + """Spark ``next_day``: first date after ``date`` that falls on ``dayOfWeek``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.next_day(d, dfn.lit("Mon")).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 20) + """ + return Expr(_f.next_day(date.expr, dayOfWeek.expr)) + + +def date_diff(end: Expr, start: Expr) -> Expr: + """Spark ``date_diff``: number of days from ``start`` to ``end``.
+ + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> end = dfn.lit(pa.scalar(date(2020, 1, 20), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.date_diff(end, d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 + """ + return Expr(_f.date_diff(end.expr, start.expr)) + + +def date_trunc(format: Expr, timestamp: Expr) -> Expr: + """Spark ``date_trunc``: truncate ``timestamp`` to the unit named by ``format``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... dfn.functions.spark.date_trunc(dfn.lit("month"), ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 1, 0, 0) + """ + return Expr(_f.date_trunc(format.expr, timestamp.expr)) + + +def time_trunc(unit: Expr, time: Expr) -> Expr: + """Spark ``time_trunc``: truncate ``time`` to the given ``unit``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import time + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> t = dfn.lit(pa.scalar(time(14, 30, 45), type=pa.time64('us'))) + >>> r = df.select( + ... dfn.functions.spark.time_trunc(dfn.lit("hour"), t).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.time(14, 0) + """ + return Expr(_f.time_trunc(unit.expr, time.expr)) + + +def trunc(date: Expr, format: Expr) -> Expr: + """Spark ``trunc``: truncate ``date`` to the unit named by ``format``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ...
dfn.functions.spark.trunc(d, dfn.lit("YEAR")).alias("v")) + >>> r.collect_column("v")[0].as_py() + datetime.date(2020, 1, 1) + """ + return Expr(_f.trunc(date.expr, format.expr)) + + +def date_part(field: Expr, source: Expr) -> Expr: + """Spark ``date_part``: extract ``field`` from a date/time/timestamp. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select( + ... dfn.functions.spark.date_part(dfn.lit("year"), d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2020 + """ + return Expr(_f.date_part(field.expr, source.expr)) + + +def from_utc_timestamp(timestamp: Expr, tz: Expr) -> Expr: + """Spark ``from_utc_timestamp``: interpret ``ts`` as UTC, convert to ``tz``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... dfn.functions.spark.from_utc_timestamp( + ... ts, dfn.lit("UTC") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45) + """ + return Expr(_f.from_utc_timestamp(timestamp.expr, tz.expr)) + + +def to_utc_timestamp(timestamp: Expr, tz: Expr) -> Expr: + """Spark ``to_utc_timestamp``: interpret ``ts`` as ``tz``, convert to UTC. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select( + ... dfn.functions.spark.to_utc_timestamp( + ... ts, dfn.lit("UTC") + ... ).alias("v") + ... 
) + >>> r.collect_column("v")[0].as_py() + datetime.datetime(2020, 1, 15, 14, 30, 45) + """ + return Expr(_f.to_utc_timestamp(timestamp.expr, tz.expr)) + + +def unix_date(col: Expr) -> Expr: + """Spark ``unix_date``: days since 1970-01-01 for ``dt``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import date + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> d = dfn.lit(pa.scalar(date(2020, 1, 15), type=pa.date32())) + >>> r = df.select(dfn.functions.spark.unix_date(d).alias("v")) + >>> r.collect_column("v")[0].as_py() + 18276 + """ + return Expr(_f.unix_date(col.expr)) + + +def unix_micros(col: Expr) -> Expr: + """Spark ``unix_micros``: microseconds since epoch for ``ts``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_micros(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645000000 + """ + return Expr(_f.unix_micros(col.expr)) + + +def unix_millis(col: Expr) -> Expr: + """Spark ``unix_millis``: milliseconds since epoch for ``ts``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_millis(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645000 + """ + return Expr(_f.unix_millis(col.expr)) + + +def unix_seconds(col: Expr) -> Expr: + """Spark ``unix_seconds``: seconds since epoch for ``ts``. + + Examples: + >>> import pyarrow as pa + >>> from datetime import datetime + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> ts = dfn.lit( + ... 
pa.scalar(datetime(2020, 1, 15, 14, 30, 45), + ... type=pa.timestamp('us'))) + >>> r = df.select(dfn.functions.spark.unix_seconds(ts).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1579098645 + """ + return Expr(_f.unix_seconds(col.expr)) + + +# --------------------------------------------------------------------------- +# Hash functions +# --------------------------------------------------------------------------- + + +def crc32(col: Expr) -> Expr: + """Spark ``crc32``: cyclic redundancy check value as a bigint. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["ABC"]}) + >>> r = df.select(dfn.functions.spark.crc32(dfn.col("s")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2743272264 + """ + return Expr(_f.crc32(col.expr)) + + +def sha1(col: Expr) -> Expr: + """Spark ``sha1``: SHA-1 hash as a hex string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["hello"]}) + >>> r = df.select(dfn.functions.spark.sha1(dfn.col("s")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d' + """ + return Expr(_f.sha1(col.expr)) + + +def sha2(col: Expr, numBits: Expr) -> Expr: # noqa: N803 + """Spark ``sha2``: SHA-2 family hash (224, 256, 384, 512). Bit length 0 = 256. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"s": ["hello"]}) + >>> r = df.select( + ... dfn.functions.spark.sha2(dfn.col("s"), dfn.lit(256)).alias("v")) + >>> r.collect_column("v")[0].as_py() + '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824' + """ + return Expr(_f.sha2(col.expr, numBits.expr)) + + +def xxhash64(*cols: Expr) -> Expr: + """Spark ``xxhash64``: 64-bit xxHash of the arguments. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.xxhash64(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + -4367754540140381902 + """ + return Expr(_f.xxhash64(*[c.expr for c in cols])) + + +# --------------------------------------------------------------------------- +# JSON functions +# --------------------------------------------------------------------------- + + +def json_tuple(col: Expr, *fields: Expr) -> Expr: + """Spark ``json_tuple``: extract top-level fields from a JSON string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.json_tuple( + ... dfn.lit('{"a":1,"b":"x"}'), dfn.lit("a"), dfn.lit("b") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + {'c0': '1', 'c1': 'x'} + """ + return Expr(_f.json_tuple(col.expr, *[f.expr for f in fields])) + + +# --------------------------------------------------------------------------- +# Map functions +# --------------------------------------------------------------------------- + + +def map_from_arrays(col1: Expr, col2: Expr) -> Expr: + """Spark ``map_from_arrays``: build a map from parallel key/value arrays. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> keys = dfn.functions.spark.array(dfn.lit("a"), dfn.lit("b")) + >>> vals = dfn.functions.spark.array(dfn.lit(1), dfn.lit(2)) + >>> r = df.select( + ... dfn.functions.spark.map_from_arrays(keys, vals).alias("v")) + >>> r.collect_column("v")[0].as_py() + [('a', 1), ('b', 2)] + """ + return Expr(_f.map_from_arrays(col1.expr, col2.expr)) + + +def map_from_entries(col: Expr) -> Expr: + """Spark ``map_from_entries``: build a map from an array of key/value structs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.map_from_arrays( + ... dfn.functions.spark.array(dfn.lit("a")), + ... dfn.functions.spark.array(dfn.lit(1)), + ... ).alias("v") + ... 
) + >>> r.collect_column("v")[0].as_py() + [('a', 1)] + """ + return Expr(_f.map_from_entries(col.expr)) + + +def str_to_map( + text: Expr, + pair_delim: Expr | None = None, + key_value_delim: Expr | None = None, +) -> Expr: + """Spark ``str_to_map``: split text into key/value pairs using delimiters. + + Delimiters default to ``","`` and ``":"`` when omitted, matching pyspark. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.str_to_map(dfn.lit("a:1,b:2")).alias("v")) + >>> r.collect_column("v")[0].as_py() + [('a', '1'), ('b', '2')] + + >>> r = df.select( + ... dfn.functions.spark.str_to_map( + ... dfn.lit("a=1;b=2"), + ... pair_delim=dfn.lit(";"), + ... key_value_delim=dfn.lit("="), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + [('a', '1'), ('b', '2')] + """ + pd = pair_delim if pair_delim is not None else Expr.literal(",") + kvd = key_value_delim if key_value_delim is not None else Expr.literal(":") + return Expr(_f.str_to_map(text.expr, pd.expr, kvd.expr)) + + +# --------------------------------------------------------------------------- +# Math functions +# --------------------------------------------------------------------------- + + +def abs(col: Expr) -> Expr: + """Spark ``abs``: absolute value. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.abs(dfn.lit(-5)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 + """ + return Expr(_f.abs(col.expr)) + + +def ceil(col: Expr) -> Expr: + """Spark ``ceil``: smallest integer ≥ arg. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.ceil(dfn.lit(1.2)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.ceil(col.expr)) + + +def expm1(col: Expr) -> Expr: + """Spark ``expm1``: exp(arg) - 1. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.expm1(dfn.lit(0.0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 0.0 + """ + return Expr(_f.expm1(col.expr)) + + +def factorial(col: Expr) -> Expr: + """Spark ``factorial``: n! for n in [0..20], else NULL. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.factorial( + ... dfn.lit(pa.scalar(5, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 120 + """ + return Expr(_f.factorial(col.expr)) + + +def floor(col: Expr) -> Expr: + """Spark ``floor``: largest integer ≤ arg. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.floor(dfn.lit(1.8)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.floor(col.expr)) + + +def hex(col: Expr) -> Expr: + """Spark ``hex``: hexadecimal representation. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.hex(dfn.lit(255)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'FF' + """ + return Expr(_f.hex(col.expr)) + + +def modulus(dividend: Expr, divisor: Expr) -> Expr: + """Spark ``mod``: remainder of ``dividend / divisor`` (sign follows dividend). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.modulus(dfn.lit(10), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1 + """ + return Expr(_f.modulus(dividend.expr, divisor.expr)) + + +def pmod(dividend: Expr, divisor: Expr) -> Expr: + """Spark ``pmod``: positive remainder of division. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.pmod(dfn.lit(-1), dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2 + """ + return Expr(_f.pmod(dividend.expr, divisor.expr)) + + +def rint(col: Expr) -> Expr: + """Spark ``rint``: round to nearest mathematical integer (as double). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.rint(dfn.lit(2.5)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 2.0 + """ + return Expr(_f.rint(col.expr)) + + +def round(col: Expr, scale: Expr | None = None) -> Expr: + """Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding. + + ``scale`` defaults to zero when omitted, matching pyspark. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.round(dfn.lit(2.5)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 3.0 + + >>> r = df.select( + ... dfn.functions.spark.round( + ... dfn.lit(2.345), scale=dfn.lit(2) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 2.35 + """ + scale_expr = scale if scale is not None else _ZERO_I32 + return Expr(_f.round(col.expr, scale_expr.expr)) + + +def unhex(col: Expr) -> Expr: + r"""Spark ``unhex``: convert hexadecimal string to binary. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.unhex(dfn.lit("FF")).alias("v")) + >>> r.collect_column("v")[0].as_py() + b'\xff' + """ + return Expr(_f.unhex(col.expr)) + + +def width_bucket( + v: Expr, + min: Expr, + max: Expr, + numBucket: Expr, # noqa: N803 +) -> Expr: + """Spark ``width_bucket``: bucket number for ``v`` in equi-width histogram. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.width_bucket( + ... dfn.lit(5.0), dfn.lit(0.0), dfn.lit(10.0), dfn.lit(5) + ... ).alias("v") + ... 
) + >>> r.collect_column("v")[0].as_py() + 3 + """ + return Expr(_f.width_bucket(v.expr, min.expr, max.expr, numBucket.expr)) + + +def csc(col: Expr) -> Expr: + """Spark ``csc``: cosecant. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.csc(dfn.lit(1.5708)).alias("v")) + >>> f"{r.collect_column('v')[0].as_py():.4f}" + '1.0000' + """ + return Expr(_f.csc(col.expr)) + + +def sec(col: Expr) -> Expr: + """Spark ``sec``: secant. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.sec(dfn.lit(0.0)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 1.0 + """ + return Expr(_f.sec(col.expr)) + + +def negative(col: Expr) -> Expr: + """Spark ``negative``: unary minus. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.negative(dfn.lit(3)).alias("v")) + >>> r.collect_column("v")[0].as_py() + -3 + """ + return Expr(_f.negative(col.expr)) + + +def bin(col: Expr) -> Expr: + """Spark ``bin``: binary string representation of a long. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.bin(dfn.lit(7)).alias("v")) + >>> r.collect_column("v")[0].as_py() + '111' + """ + return Expr(_f.bin(col.expr)) + + +# --------------------------------------------------------------------------- +# String functions +# --------------------------------------------------------------------------- + + +def ascii(col: Expr) -> Expr: + """Spark ``ascii``: code point of the first character. 
+ + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.ascii(dfn.lit("A")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 65 + """ + return Expr(_f.ascii(col.expr)) + + +def base64(col: Expr) -> Expr: + """Spark ``base64``: encode binary as a base64 string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.base64(dfn.lit(b"hi")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'aGk=' + """ + return Expr(_f.base64(col.expr)) + + +def char(col: Expr) -> Expr: + """Spark ``char``: ASCII character for a code point (mod 256). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.char(dfn.lit(65)).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'A' + """ + return Expr(_f.char(col.expr)) + + +def concat(*cols: Expr) -> Expr: + """Spark ``concat``: concatenates strings; NULL if any input is NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.concat(dfn.lit("a"), dfn.lit("b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'ab' + """ + return Expr(_f.concat(*[c.expr for c in cols])) + + +def elt(*inputs: Expr) -> Expr: + """Spark ``elt``: returns the n-th input (1-indexed). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.elt( + ... dfn.lit(2), dfn.lit("a"), dfn.lit("b") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'b' + """ + return Expr(_f.elt(*[i.expr for i in inputs])) + + +def ilike( + str: Expr, + pattern: Expr, + escapeChar: str | None = None, # noqa: N803 +) -> Expr: + """Spark ``ilike``: case-insensitive pattern match. 
+ + ``escapeChar`` is accepted for pyspark parity but is not yet wired through + the Rust binding; passing a non-``None`` value raises ``NotImplementedError``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.ilike(dfn.lit("HELLO"), dfn.lit("h%")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + if escapeChar is not None: + msg = "ilike(escapeChar=...) is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) + return Expr(_f.ilike(str.expr, pattern.expr)) + + +def length(col: Expr) -> Expr: + """Spark ``length``: character length of a string, or byte length of binary. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.length(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 5 + """ + return Expr(_f.length(col.expr)) + + +def like( + str: Expr, + pattern: Expr, + escapeChar: str | None = None, # noqa: N803 +) -> Expr: + """Spark ``like``: case-sensitive pattern match. + + ``escapeChar`` is accepted for pyspark parity but is not yet wired through + the Rust binding; passing a non-``None`` value raises ``NotImplementedError``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.like(dfn.lit("hello"), dfn.lit("h%")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + if escapeChar is not None: + msg = "like(escapeChar=...) is not yet supported by the Spark UDF binding" + raise NotImplementedError(msg) + return Expr(_f.like(str.expr, pattern.expr)) + + +def luhn_check(col: Expr) -> Expr: + """Spark ``luhn_check``: true if the digit string passes the Luhn check. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.luhn_check( + ... dfn.lit("4111111111111111") + ... ).alias("v") + ... 
) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.luhn_check(col.expr)) + + +def format_string(format: str | Expr, *cols: Expr) -> Expr: + """Spark ``format_string``: printf-style format string. + + ``format`` is the printf-style template (a plain ``str`` is auto-promoted + to a literal expression); remaining args are values to substitute. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.format_string( + ... "%d-%s", dfn.lit(42), dfn.lit("hi") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + '42-hi' + """ + fmt_expr = format if isinstance(format, Expr) else Expr.literal(format) + return Expr(_f.format_string(fmt_expr.expr, *[c.expr for c in cols])) + + +def space(col: Expr) -> Expr: + """Spark ``space``: string of n spaces. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.space( + ... dfn.lit(pa.scalar(3, type=pa.int32())) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + ' ' + """ + return Expr(_f.space(col.expr)) + + +def substring(str: Expr, pos: Expr, len: Expr) -> Expr: + """Spark ``substring``: 1-indexed substring starting at ``pos`` of length ``len``. + + Negative ``pos`` counts from the end. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.substring( + ... dfn.lit("hello"), dfn.lit(1), dfn.lit(3) + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'hel' + """ + return Expr(_f.substring(str.expr, pos.expr, len.expr)) + + +def unbase64(col: Expr) -> Expr: + """Spark ``unbase64``: decode a base64 string to binary. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.unbase64(dfn.lit("aGk=")).alias("v")) + >>> r.collect_column("v")[0].as_py() + b'hi' + """ + return Expr(_f.unbase64(col.expr)) + + +def soundex(col: Expr) -> Expr: + """Spark ``soundex``: Soundex phonetic code. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select(dfn.functions.spark.soundex(dfn.lit("Robert")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'R163' + """ + return Expr(_f.soundex(col.expr)) + + +def is_valid_utf8(str: Expr) -> Expr: + """Spark ``is_valid_utf8``: true if the string is valid UTF-8. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.is_valid_utf8(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + True + """ + return Expr(_f.is_valid_utf8(str.expr)) + + +def make_valid_utf8(str: Expr) -> Expr: + """Spark ``make_valid_utf8``: replace invalid UTF-8 bytes with U+FFFD. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.make_valid_utf8(dfn.lit("hello")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'hello' + """ + return Expr(_f.make_valid_utf8(str.expr)) + + +# --------------------------------------------------------------------------- +# URL functions +# --------------------------------------------------------------------------- + + +def parse_url( + url: Expr, + partToExtract: Expr, # noqa: N803 + key: Expr | None = None, +) -> Expr: + """Spark ``parse_url``: extract a part from a URL; errors on invalid URLs. + + ``partToExtract`` is one of ``"HOST"``, ``"PATH"``, ``"QUERY"``, + ``"REF"``, ``"PROTOCOL"``, ``"FILE"``, ``"AUTHORITY"``, ``"USERINFO"``. + Pass ``key`` only with ``"QUERY"`` to extract a single parameter. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.parse_url( + ... 
dfn.lit("http://example.com/path?q=1"), dfn.lit("HOST") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'example.com' + + >>> r = df.select( + ... dfn.functions.spark.parse_url( + ... dfn.lit("http://example.com/path?q=1"), + ... dfn.lit("QUERY"), + ... key=dfn.lit("q"), + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + '1' + """ + if key is None: + return Expr(_f.parse_url(url.expr, partToExtract.expr)) + return Expr(_f.parse_url(url.expr, partToExtract.expr, key.expr)) + + +def try_parse_url( + url: Expr, + partToExtract: Expr, # noqa: N803 + key: Expr | None = None, +) -> Expr: + """Spark ``try_parse_url``: like ``parse_url`` but returns NULL on invalid URLs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.try_parse_url( + ... dfn.lit("http://example.com/"), dfn.lit("HOST") + ... ).alias("v") + ... ) + >>> r.collect_column("v")[0].as_py() + 'example.com' + """ + if key is None: + return Expr(_f.try_parse_url(url.expr, partToExtract.expr)) + return Expr(_f.try_parse_url(url.expr, partToExtract.expr, key.expr)) + + +def url_decode(str: Expr) -> Expr: + """Spark ``url_decode``: decode an application/x-www-form-urlencoded string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.url_decode(dfn.lit("a%20b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a b' + """ + return Expr(_f.url_decode(str.expr)) + + +def try_url_decode(str: Expr) -> Expr: + """Spark ``try_url_decode``: like ``url_decode``; returns NULL on invalid input. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... 
dfn.functions.spark.try_url_decode(dfn.lit("a%20b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a b' + """ + return Expr(_f.try_url_decode(str.expr)) + + +def url_encode(str: Expr) -> Expr: + """Spark ``url_encode``: encode a string in application/x-www-form-urlencoded. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"x": [1]}) + >>> r = df.select( + ... dfn.functions.spark.url_encode(dfn.lit("a b")).alias("v")) + >>> r.collect_column("v")[0].as_py() + 'a+b' + """ + return Expr(_f.url_encode(str.expr)) + + +__all__ = [ + # Math + "abs", + # Datetime + "add_months", + "array", + # Array + "array_contains", + "array_repeat", + # String + "ascii", + # Aggregate + "avg", + "base64", + "bin", + "bit_count", + # Bitwise + "bit_get", + "bitmap_bit_position", + "bitmap_bucket_number", + # Bitmap + "bitmap_count", + "bitwise_not", + "ceil", + "char", + "collect_list", + "collect_set", + "concat", + # Hash + "crc32", + "csc", + "date_add", + "date_diff", + "date_part", + "date_sub", + "date_trunc", + "elt", + "expm1", + "factorial", + "floor", + "format_string", + "from_utc_timestamp", + "hex", + "hour", + "if_", + "ilike", + "is_valid_utf8", + # JSON + "json_tuple", + "last_day", + "length", + "like", + "luhn_check", + "make_dt_interval", + "make_interval", + "make_valid_utf8", + # Map + "map_from_arrays", + "map_from_entries", + "minute", + "modulus", + "negative", + "next_day", + # URL + "parse_url", + "pmod", + "rint", + "round", + "sec", + "second", + "sha1", + "sha2", + "shiftleft", + "shiftright", + "shiftrightunsigned", + "shuffle", + # Collection / Conditional / Conversion + "size", + "slice", + "soundex", + "space", + "spark_cast", + "str_to_map", + "substring", + "time_trunc", + "to_utc_timestamp", + "trunc", + "try_parse_url", + "try_sum", + "try_url_decode", + "unbase64", + "unhex", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "url_decode", + "url_encode", + "width_bucket", + "xxhash64", +] diff --git 
a/python/tests/test_spark_functions.py b/python/tests/test_spark_functions.py new file mode 100644 index 000000000..193d21b2c --- /dev/null +++ b/python/tests/test_spark_functions.py @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the Spark-compatible function bindings.""" + +import pyarrow as pa +import pytest +from datafusion import SessionContext, col, lit +from datafusion import functions as f +from datafusion.functions import spark + + +@pytest.fixture +def df(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays( + [ + pa.array(["hello", "WORLD", "abc"]), + pa.array([1, 2, 3]), + pa.array([-5, 0, 7]), + pa.array([1.0, 2.5, 3.0]), + pa.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ], + names=["s", "i", "n", "f", "a"], + ) + return ctx.create_dataframe([[batch]]) + + +def _val(df, expr): + return df.select(expr.alias("v")).collect_column("v")[0].as_py() + + +# --------------------------------------------------------------------------- +# Math +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.abs(lit(-5)), 5), + (lambda: spark.ceil(lit(1.2)), 2), + (lambda: 
spark.floor(lit(1.8)), 1), + (lambda: spark.bin(lit(7)), "111"), + (lambda: spark.hex(lit(255)), "FF"), + (lambda: spark.modulus(lit(10), lit(3)), 1), + (lambda: spark.pmod(lit(-1), lit(3)), 2), + (lambda: spark.rint(lit(2.5)), 2.0), + (lambda: spark.round(lit(2.5), lit(0)), 3.0), + (lambda: spark.negative(lit(3)), -3), + ], +) +def test_math(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +def test_factorial(df): + # factorial wants Int32; lit(int) is Int64 by default. + expr = spark.factorial(lit(pa.scalar(5, type=pa.int32()))) + assert _val(df, expr) == 120 + + +# --------------------------------------------------------------------------- +# String +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.ascii(lit("A")), 65), + (lambda: spark.char(lit(65)), "A"), + (lambda: spark.length(lit("hello")), 5), + (lambda: spark.like(lit("hello"), lit("h%")), True), + (lambda: spark.ilike(lit("HELLO"), lit("h%")), True), + (lambda: spark.substring(lit("hello"), lit(1), lit(3)), "hel"), + (lambda: spark.soundex(lit("Robert")), "R163"), + (lambda: spark.is_valid_utf8(lit("hi")), True), + (lambda: spark.concat(lit("a"), lit("b")), "ab"), + (lambda: spark.elt(lit(2), lit("a"), lit("b")), "b"), + ], +) +def test_string(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +def test_space(df): + expr = spark.space(lit(pa.scalar(3, type=pa.int32()))) + assert _val(df, expr) == " " + + +# --------------------------------------------------------------------------- +# Hash +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + ( + lambda: spark.sha1(lit("hello")), + "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", + ), + ( + lambda: spark.sha2(lit("hello"), lit(256)), + 
"2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + ), + (lambda: spark.crc32(lit("ABC")), 2743272264), + ], +) +def test_hash(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +# --------------------------------------------------------------------------- +# Bitwise +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("expr_factory", "expected"), + [ + (lambda: spark.bit_count(lit(7)), 3), + (lambda: spark.bit_get(lit(5), lit(0)), 1), + (lambda: spark.bitwise_not(lit(0)), -1), + (lambda: spark.shiftleft(lit(1), lit(3)), 8), + (lambda: spark.shiftright(lit(8), lit(2)), 2), + (lambda: spark.shiftrightunsigned(lit(8), lit(2)), 2), + ], +) +def test_bitwise(df, expr_factory, expected): + assert _val(df, expr_factory()) == expected + + +# --------------------------------------------------------------------------- +# Array / collection / conditional +# --------------------------------------------------------------------------- + + +def test_array_and_size(df): + arr = spark.array(lit(1), lit(2), lit(3)) + assert _val(df, arr) == [1, 2, 3] + assert _val(df, spark.size(arr)) == 3 + assert _val(df, spark.array_contains(arr, lit(2))) is True + + +def test_slice(df): + arr = spark.array(lit(1), lit(2), lit(3), lit(4)) + assert _val(df, spark.slice(arr, lit(2), lit(2))) == [2, 3] + + +def test_array_repeat(df): + assert _val(df, spark.array_repeat(lit("a"), lit(3))) == ["a", "a", "a"] + + +def test_if(df): + assert _val(df, spark.if_(lit(True), lit("yes"), lit("no"))) == "yes" + assert _val(df, spark.if_(lit(False), lit("yes"), lit("no"))) == "no" + + +# --------------------------------------------------------------------------- +# Aggregate +# --------------------------------------------------------------------------- + + +def test_aggregate(df): + r = df.aggregate( + [], + [ + spark.avg(col("f")).alias("avg"), + spark.try_sum(col("i")).alias("sum"), + ], + ).collect() + 
rec = pa.Table.from_batches(r) + assert rec.column("avg")[0].as_py() == pytest.approx(2.1666666, rel=1e-3) + assert rec.column("sum")[0].as_py() == 6 + + +def test_collect_list_set(df): + r = df.aggregate( + [], + [ + spark.collect_list(col("i")).alias("cl"), + spark.collect_set(col("i")).alias("cs"), + ], + ).collect() + rec = pa.Table.from_batches(r) + assert sorted(rec.column("cl")[0].as_py()) == [1, 2, 3] + assert sorted(rec.column("cs")[0].as_py()) == [1, 2, 3] + + +# --------------------------------------------------------------------------- +# Spark-semantics conflicts vs DataFusion defaults +# --------------------------------------------------------------------------- + + +def test_concat_null_propagates(): + """Spark concat returns NULL on any NULL input; default skips NULLs.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + default_out = ( + df.select(f.concat(lit("a"), lit(None), lit("b")).alias("v")) + .collect_column("v")[0] + .as_py() + ) + spark_out = _val(df, spark.concat(lit("a"), lit(None), lit("b"))) + assert default_out == "ab" + assert spark_out is None + + +def test_round_half_up(): + """Spark round uses HALF_UP rounding (2.5 → 3).""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.round(lit(2.5), lit(0))) == 3.0 + + +# --------------------------------------------------------------------------- +# Optional parameter defaults / NotImplementedError +# --------------------------------------------------------------------------- + + +def test_round_scale_default(): + """spark.round defaults scale to 0.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.round(lit(2.5))) == 3.0 + + +def test_make_dt_interval_defaults(): + """spark.make_dt_interval with no args returns a zero day-time interval.""" + import datetime as dt + + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.make_dt_interval()) == dt.timedelta(0) + + +def 
test_make_interval_defaults(): + """spark.make_interval with no args returns a zero interval.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.make_interval()).months == 0 + + +def test_str_to_map_defaults(): + """spark.str_to_map defaults delimiters to ',' and ':'.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.str_to_map(lit("a:1,b:2"))) == [("a", "1"), ("b", "2")] + + +def test_shuffle_seed_raises(): + """spark.shuffle(seed=...) raises NotImplementedError until Rust supports it.""" + with pytest.raises(NotImplementedError, match="seed"): + spark.shuffle(spark.array(lit(1), lit(2)), seed=1) + + +def test_like_escape_raises(): + """spark.like/ilike escapeChar raises NotImplementedError until Rust supports.""" + with pytest.raises(NotImplementedError, match="escapeChar"): + spark.like(lit("a"), lit("a"), escapeChar="\\") + with pytest.raises(NotImplementedError, match="escapeChar"): + spark.ilike(lit("a"), lit("a"), escapeChar="\\") + + +def test_parse_url_three_arg(): + """parse_url(url, partToExtract, key=...) 
extracts query parameters.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + url = lit("http://example.com/p?q=hello&n=1") + assert _val(df, spark.parse_url(url, lit("QUERY"), key=lit("q"))) == "hello" + assert _val(df, spark.try_parse_url(url, lit("QUERY"), key=lit("n"))) == "1" + + +def test_format_string_plain_str_format(): + """format_string accepts a plain str format that is auto-promoted to lit.""" + ctx = SessionContext() + df = ctx.from_pydict({"x": [1]}) + assert _val(df, spark.format_string("%d-%s", lit(42), lit("hi"))) == "42-hi" + + +def test_aggregate_positional_compat(): + """Pyspark-style positional calls still work after the rename to ``col``.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + out = df.aggregate( + [], + [ + spark.avg(col("a")).alias("av"), + spark.try_sum(col("a")).alias("ts"), + spark.collect_list(col("a")).alias("cl"), + spark.collect_set(col("a")).alias("cs"), + ], + ).collect() + rec = pa.Table.from_batches(out) + assert rec.column("av")[0].as_py() == 2.0 + assert rec.column("ts")[0].as_py() == 6.0 + + +# --------------------------------------------------------------------------- +# SQL path via enable_spark_functions +# --------------------------------------------------------------------------- + + +def test_sql_requires_enable(): + """Spark-only function (xxhash64) is not in SQL namespace by default.""" + ctx = SessionContext() + with pytest.raises(Exception, match=r"(?i)xxhash64|Invalid function"): + ctx.sql("SELECT xxhash64('hello')").collect() + + +def test_sql_after_enable(): + ctx = SessionContext() + ctx.enable_spark_functions() + out = ctx.sql("SELECT sha2('hello', 256) AS h").collect_column("h")[0].as_py() + assert out == ("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824") + + +def test_sql_concat_semantics_override(): + """After enable, SQL concat propagates NULL like Spark.""" + ctx = SessionContext() + default_out = ( + ctx.sql("SELECT concat('a', NULL, 'b') 
AS c").collect_column("c")[0].as_py()
+    )
+    assert default_out == "ab"
+
+    ctx2 = SessionContext()
+    ctx2.enable_spark_functions()
+    spark_out = (
+        ctx2.sql("SELECT concat('a', NULL, 'b') AS c").collect_column("c")[0].as_py()
+    )
+    assert spark_out is None
diff --git a/skills/datafusion_python/SKILL.md b/skills/datafusion_python/SKILL.md
index 1aeb78777..006c028d1 100644
--- a/skills/datafusion_python/SKILL.md
+++ b/skills/datafusion_python/SKILL.md
@@ -26,12 +26,14 @@ can interoperate with DataFusion.
 | `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods |
 | `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` |
 | `functions` | 290+ built-in scalar, aggregate, and window functions. | `from datafusion import functions as F` |
+| `functions.spark` | PySpark-compatible function surface (parameter names match `pyspark.sql.functions`). | `from datafusion.functions import spark` |
 
 ## Import Conventions
 
 ```python
 from datafusion import SessionContext, col, lit
 from datafusion import functions as F
+from datafusion.functions import spark # only when porting pyspark code
 ```
 
 ## Data Loading
@@ -762,3 +764,51 @@ F.left(col("c_phone"), lit(2)) # prefix shortcut
 
 **Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, `to_hex`,
 `to_char`, `uuid`, `version`, `bit_length`, `octet_length`
+
+### Spark-Compatible Functions
+
+A separate `datafusion.functions.spark` namespace mirrors the
+`pyspark.sql.functions` API for callers porting code from PySpark.
+
+```python
+from datafusion.functions import spark
+```
+
+Use it for DataFrame work; for SQL, register the Spark UDFs first:
+
+```python
+ctx = SessionContext()
+ctx.enable_spark_functions() # makes Spark UDFs visible to SQL
+ctx.sql("SELECT sha2('hello', 256)").show()
+```
+
+Coverage spans aggregate, array, bitmap, bitwise, datetime, hash, JSON,
+map, math, string, URL, and conditional categories.
The authoritative
+list of what is currently exposed is the `__all__` in
+`python/datafusion/functions/spark.py`:
+
+```bash
+python -c "from datafusion.functions import spark; print(sorted(spark.__all__))"
+```
+
+When you need to know whether a specific pyspark function is available,
+check `__all__` rather than this skill: the list there moves with the
+code, while any enumeration here would drift.
+
+**Semantic divergences vs the default namespace.** Functions that exist in
+both `functions` and `functions.spark` may behave differently:
+
+| Function | Default `functions` | `functions.spark` |
+|---|---|---|
+| `concat` | NULL inputs treated as empty | NULL inputs propagate to NULL |
+| `round` | HALF_EVEN (banker's) | HALF_UP |
+| `trunc` | Numeric truncation | Date truncation |
+| `substring` | 1-indexed | 1-indexed (parity) |
+
+Pick the namespace whose semantics match your intent. Both can stay imported
+side by side; `enable_spark_functions()` only affects SQL.
+
+**Parameter names match pyspark exactly.** The spark namespace uses
+pyspark parameter names (`col`, `str`, `numBits`, `partToExtract`, ...) so
+you can paste pyspark code and keep keyword arguments working. The default
+namespace keeps DataFusion's parameter names.
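
The `concat` and `round` rows of the divergence table can be illustrated without DataFusion at all. The sketch below is a plain-Python model of the two behaviours, not the library's implementation: the helper names (`concat_skip_nulls`, `concat_propagate_nulls`, `round_with`) are hypothetical, `None` stands in for SQL NULL, and it assumes the table's `HALF_EVEN` and `HALF_UP` labels refer to the standard rounding modes of the same names in Python's `decimal` module.

```python
from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Decimal
from typing import Optional


def concat_skip_nulls(*parts: Optional[str]) -> str:
    """Model of the default row: NULL (None) inputs are treated as empty."""
    return "".join(p for p in parts if p is not None)


def concat_propagate_nulls(*parts: Optional[str]) -> Optional[str]:
    """Model of the spark row: any NULL (None) input makes the result NULL."""
    if any(p is None for p in parts):
        return None
    return "".join(parts)


def round_with(value: str, mode: str) -> Decimal:
    """Round a decimal string to an integer with the given rounding mode."""
    return Decimal(value).quantize(Decimal("1"), rounding=mode)


print(concat_skip_nulls("a", None, "b"))       # ab
print(concat_propagate_nulls("a", None, "b"))  # None
print(round_with("2.5", ROUND_HALF_EVEN))      # 2
print(round_with("2.5", ROUND_HALF_UP))        # 3
```

The two `concat` results mirror what `test_sql_concat_semantics_override` asserts for the SQL path (`"ab"` by default, `None` after `enable_spark_functions()`). The rounding lines only show how the two modes differ on a tie, so confirm the real functions against the current release before relying on them.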