diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..47764dd1a3 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -39,6 +39,13 @@ - Spark 3.5.8 (2026-05-26) - Spark 4.0.1 (2026-05-26) +## max_by + +- Spark 3.4.3 (2026-07-03): `MaxBy` is a 2-argument `DeclarativeAggregate` registered as `expression[MaxBy]("max_by")`. Buffer is `(valueWithExtremumOrdering, extremumOrdering)`; null orderings are ignored, the value paired with the maximum ordering is returned (and may itself be null), and an all-null-ordering group yields null. Comet implements a native `max_by` aggregate. Only fixed-length value and ordering types are accelerated: a variable-length or nested type (string, binary, struct) forces Spark's `SortAggregate`, which Comet does not support, so those cases fall back to Spark. `max_by` is non-deterministic when several rows tie on the maximum ordering, matching Spark's documented behavior. +- Spark 3.5.8 (2026-07-03): aggregate logic identical to 3.4.3. +- Spark 4.0.1 (2026-07-03): aggregate logic identical to 3.4.3; only the `@ExpressionDescription` example and note text differ. +- Spark 4.1.1 (2026-07-03): aggregate logic identical to 3.4.3. The 3-argument top-k form `max_by(x, y, k)` (via `MaxByBuilder` / `MaxMinByK`) is only present on Spark master, not in any released 3.4 through 4.1 version, so Comet handles only the 2-argument form. + ## median - Spark 3.4.3 (audited 2026-06-24): `Median(child)` is a `RuntimeReplaceableAggregate` with `replacement = Percentile(child, Literal(0.5))`. Catalyst rewrites `median(x)` to `percentile(x, 0.5)` before Comet sees the plan, so it is served by `CometPercentile`. @@ -46,6 +53,13 @@ - Spark 4.0.1 (audited 2026-06-24): `replacement` becomes `lazy val`; semantics unchanged. - Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1. 
+## min_by + +- Spark 3.4.3 (2026-07-03): `MinBy` shares the abstract `MaxMinBy` `DeclarativeAggregate` with `MaxBy`, differing only in the comparison direction (`least` / `<` instead of `greatest` / `>`). Registered as `expression[MinBy]("min_by")`. Null orderings are ignored, the value paired with the minimum ordering is returned (and may itself be null), and an all-null-ordering group yields null. Comet serves it through the same native `MaxMinBy` aggregate as `max_by`, with the same fixed-length value and ordering restriction (variable-length or nested types fall back to Spark). Non-deterministic on ties, matching Spark. +- Spark 3.5.8 (2026-07-03): aggregate logic identical to 3.4.3. +- Spark 4.0.1 (2026-07-03): aggregate logic identical to 3.4.3; only the `@ExpressionDescription` example and note text differ. +- Spark 4.1.1 (2026-07-03): aggregate logic identical to 3.4.3. The 3-argument top-k form `min_by(x, y, k)` (via `MinByBuilder` / `MaxMinByK`) is only present on Spark master, not in any released 3.4 through 4.1 version, so Comet handles only the 2-argument form. + ## percentile - Spark 3.4.3 (audited 2026-06-24): `Percentile(child, percentageExpression, frequencyExpression, ..., reverse)` over `PercentileBase`. Exact percentile using `index = p * (n - 1)` linear interpolation, NULL inputs skipped, empty/all-null group returns NULL. `CometPercentile` maps the single-literal-percentage, default-frequency, numeric-input, ascending form to DataFusion's `percentile_cont` (same interpolation). Array-of-percentages, a non-default frequency argument, descending order, and interval inputs fall back to Spark. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 2ca3a13c62..9a172f0473 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -95,11 +95,11 @@ The tables below list every Spark built-in expression with its current status. 
| `last_value` | ✅ | | | `listagg` | 🔜 | String aggregation | | `max` | ✅ | | -| `max_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | +| `max_by` | ✅ | Value and ordering must be fixed-length types | | `mean` | ✅ | | | `median` | ✅ | Rewrites to `percentile(col, 0.5)`; falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `min` | ✅ | | -| `min_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | +| `min_by` | ✅ | Value and ordering must be fixed-length types | | `mode` | 🔜 | [#3970](https://github.com/apache/datafusion-comet/issues/3970) | | `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. 
Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..77ec85ac0e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, IfExpr, ListExtract, MaxMinBy, NormalizeNaNAndZero, SparkCastOptions, Stddev, + SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2653,6 +2653,20 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::MaxBy(expr) => { + let value = self.create_expr(expr.value.as_ref().unwrap(), Arc::clone(&schema))?; + let ordering = + self.create_expr(expr.ordering.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(MaxMinBy::new_max_by()); + Self::create_aggr_func_expr("max_by", schema, vec![value, ordering], func) + } + AggExprStruct::MinBy(expr) => { + let value = self.create_expr(expr.value.as_ref().unwrap(), Arc::clone(&schema))?; + let ordering = + self.create_expr(expr.ordering.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(MaxMinBy::new_min_by()); + Self::create_aggr_func_expr("min_by", schema, vec![value, ordering], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 
32adc16b72..300183f88d 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -146,6 +146,8 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + MaxBy maxBy = 19; + MinBy minBy = 20; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -277,6 +279,20 @@ message CollectSet { DataType datatype = 2; } +message MaxBy { + // The value returned by the aggregate (associated with the maximum ordering). + Expr value = 1; + // The ordering expression whose maximum selects the returned value. + Expr ordering = 2; +} + +message MinBy { + // The value returned by the aggregate (associated with the minimum ordering). + Expr value = 1; + // The ordering expression whose minimum selects the returned value. + Expr ordering = 2; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/max_min_by.rs b/native/spark-expr/src/agg_funcs/max_min_by.rs new file mode 100644 index 0000000000..f3639675b6 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/max_min_by.rs @@ -0,0 +1,632 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{new_null_array, Array, ArrayRef, BooleanArray}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::row::{OwnedRow, RowConverter, SortField}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion::physical_expr::expressions::format_state_name; +use std::mem::size_of_val; +use std::sync::Arc; + +/// Spark-compatible `max_by(value, ordering)` / `min_by(value, ordering)` aggregate. +/// +/// Returns the `value` associated with the maximum (`max_by`) or minimum (`min_by`) +/// non-null `ordering`. Rows with a null `ordering` are ignored. The returned value +/// may itself be null when it is the value paired with the extremum ordering. If every +/// `ordering` in the group is null, the result is null. +/// +/// Spark's `MaxBy`/`MinBy` are `DeclarativeAggregate`s that keep a `(value, ordering)` +/// buffer and, on a tie in the ordering, the later row wins. Because ties across +/// partitions are processed in an unspecified order, Spark documents the function as +/// non-deterministic when several rows share the extremum ordering. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MaxMinBy { + name: String, + signature: Signature, + /// `true` for `max_by`, `false` for `min_by`. + is_max: bool, +} + +impl std::hash::Hash for MaxMinBy { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + self.is_max.hash(state); + } +} + +impl MaxMinBy { + /// Create a `max_by` aggregate. + pub fn new_max_by() -> Self { + Self { + name: "max_by".to_string(), + signature: Signature::any(2, Volatility::Immutable), + is_max: true, + } + } + + /// Create a `min_by` aggregate. 
+ pub fn new_min_by() -> Self { + Self { + name: "min_by".to_string(), + signature: Signature::any(2, Volatility::Immutable), + is_max: false, + } + } +} + +impl AggregateUDFImpl for MaxMinBy { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // The result has the same type as the `value` argument. + Ok(arg_types[0].clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let value_type = acc_args.exprs[0].data_type(acc_args.schema)?; + let ordering_type = acc_args.exprs[1].data_type(acc_args.schema)?; + Ok(Box::new(MaxMinByAccumulator::try_new( + value_type, + ordering_type, + self.is_max, + )?)) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let value_type = args.input_fields[0].data_type().clone(); + let ordering_type = args.input_fields[1].data_type().clone(); + Ok(vec![ + Arc::new(Field::new( + format_state_name(&self.name, "value"), + value_type, + true, + )), + Arc::new(Field::new( + format_state_name(&self.name, "ordering"), + ordering_type, + true, + )), + ]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let value_type = args.exprs[0].data_type(args.schema)?; + let ordering_type = args.exprs[1].data_type(args.schema)?; + Ok(Box::new(MaxMinByGroupsAccumulator::try_new( + value_type, + ordering_type, + self.is_max, + )?)) + } +} + +/// Sort options that make the wanted extremum encode to the largest row bytes: ascending for +/// `max_by` (largest ordering wins), descending for `min_by` (smallest ordering wins). Nulls +/// sort first (smallest) so they are never selected as the extremum; null orderings are also +/// skipped explicitly. 
+fn extremum_sort_options(is_max: bool) -> SortOptions { + SortOptions { + descending: !is_max, + nulls_first: true, + } +} + +/// Accumulator that tracks the running `(value, ordering)` pair for the extremum ordering. +#[derive(Debug)] +struct MaxMinByAccumulator { + /// The value paired with the current extremum ordering. May be null. + value: ScalarValue, + /// The current extremum ordering. Null means no non-null ordering has been seen yet. + ordering: ScalarValue, + /// `true` for `max_by`, `false` for `min_by`. + is_max: bool, +} + +impl MaxMinByAccumulator { + fn try_new(value_type: DataType, ordering_type: DataType, is_max: bool) -> Result { + Ok(Self { + value: ScalarValue::try_from(&value_type)?, + ordering: ScalarValue::try_from(&ordering_type)?, + is_max, + }) + } + + fn sort_options(&self) -> SortOptions { + // Encode the ordering column into arrow's row format so that the extremum can be + // found for any orderable type with a single comparison. + extremum_sort_options(self.is_max) + } + + /// Apply a batch of `(value, ordering)` columns, keeping the value paired with the + /// extremum ordering. Rows with a null ordering are ignored. + fn update_from(&mut self, value_arr: &ArrayRef, ordering_arr: &ArrayRef) -> Result<()> { + if ordering_arr.is_empty() { + return Ok(()); + } + + let converter = RowConverter::new(vec![SortField::new_with_options( + ordering_arr.data_type().clone(), + self.sort_options(), + )])?; + let rows = converter.convert_columns(&[Arc::clone(ordering_arr)])?; + + // Find the index of the extremum ordering in this batch (last one wins on a tie, + // matching Spark's sequential row processing), ignoring null orderings. 
+ let mut best: Option = None; + for i in 0..ordering_arr.len() { + if ordering_arr.is_null(i) { + continue; + } + best = match best { + None => Some(i), + Some(b) if rows.row(i) >= rows.row(b) => Some(i), + Some(b) => Some(b), + }; + } + + let Some(b) = best else { + return Ok(()); + }; + + let candidate_ordering = ScalarValue::try_from_array(ordering_arr, b)?; + let take = if self.ordering.is_null() { + true + } else { + // Compare the batch's extremum ordering against the running extremum using the + // same row encoding. Build a two-row array [running, candidate] and compare. + let pair = ScalarValue::iter_to_array(vec![ + self.ordering.clone(), + candidate_ordering.clone(), + ])?; + let pair_rows = converter.convert_columns(&[pair])?; + pair_rows.row(1) >= pair_rows.row(0) + }; + + if take { + self.value = ScalarValue::try_from_array(value_arr, b)?; + self.ordering = candidate_ordering; + } + + Ok(()) + } +} + +impl Accumulator for MaxMinByAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_from(&values[0], &values[1]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // State columns mirror the input columns: [value, ordering]. + self.update_from(&states[0], &states[1]) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.value.clone(), self.ordering.clone()]) + } + + fn evaluate(&mut self) -> Result { + Ok(self.value.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) + self.value.size() + self.ordering.size() + } +} + +/// Vectorized grouped accumulator for `max_by` / `min_by`. +/// +/// Each group keeps the best `(value, ordering)` pair as Arrow row-format bytes. The ordering +/// rows are byte-comparable, so selecting the extremum for a batch is a single row conversion +/// plus per-row byte comparisons, avoiding the per-group `ScalarValue` work of the generic +/// `GroupsAccumulatorAdapter`. +struct MaxMinByGroupsAccumulator { + /// Converts and compares the ordering column. 
Its sort options encode the wanted extremum + /// as the largest row bytes (see `extremum_sort_options`). + ordering_converter: RowConverter, + /// Converts the value column to and from row bytes. Sort options are irrelevant here since + /// values are only stored, never compared. + value_converter: RowConverter, + /// Row bytes for a null value, used for groups that have not been updated (or whose winning + /// value is null). + null_value_row: OwnedRow, + /// Row bytes for a null ordering, used for groups that have seen no non-null ordering. + null_ordering_row: OwnedRow, + /// Per-group winning value (row bytes). + best_value: Vec, + /// Per-group winning ordering (row bytes). + best_ordering: Vec, + /// Per-group flag: has a non-null ordering been seen yet? + has_ordering: Vec, +} + +impl MaxMinByGroupsAccumulator { + fn try_new(value_type: DataType, ordering_type: DataType, is_max: bool) -> Result { + let ordering_converter = RowConverter::new(vec![SortField::new_with_options( + ordering_type.clone(), + extremum_sort_options(is_max), + )])?; + let value_converter = RowConverter::new(vec![SortField::new(value_type.clone())])?; + let null_ordering_row = ordering_converter + .convert_columns(&[new_null_array(&ordering_type, 1)])? + .row(0) + .owned(); + let null_value_row = value_converter + .convert_columns(&[new_null_array(&value_type, 1)])? + .row(0) + .owned(); + Ok(Self { + ordering_converter, + value_converter, + null_value_row, + null_ordering_row, + best_value: Vec::new(), + best_ordering: Vec::new(), + has_ordering: Vec::new(), + }) + } + + fn resize(&mut self, total_num_groups: usize) { + self.best_value + .resize(total_num_groups, self.null_value_row.clone()); + self.best_ordering + .resize(total_num_groups, self.null_ordering_row.clone()); + self.has_ordering.resize(total_num_groups, false); + } + + /// Shared update/merge logic: `values[0]` is the value column, `values[1]` the ordering + /// column. 
Rows with a null ordering are ignored; on a tie the later row wins. + fn update_groups( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.resize(total_num_groups); + let value_rows = self + .value_converter + .convert_columns(&[Arc::clone(&values[0])])?; + let ordering_arr = &values[1]; + let ordering_rows = self + .ordering_converter + .convert_columns(&[Arc::clone(ordering_arr)])?; + + for (idx, &group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.is_valid(idx) || !filter.value(idx) { + continue; + } + } + if ordering_arr.is_null(idx) { + continue; + } + let candidate = ordering_rows.row(idx); + let take = !self.has_ordering[group_index] + || candidate >= self.best_ordering[group_index].row(); + if take { + self.best_ordering[group_index] = candidate.owned(); + self.best_value[group_index] = value_rows.row(idx).owned(); + self.has_ordering[group_index] = true; + } + } + Ok(()) + } +} + +impl GroupsAccumulator for MaxMinByGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.update_groups(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // State columns mirror the input columns: [value, ordering]. 
+ self.update_groups(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let value_rows = emit_to.take_needed(&mut self.best_value); + let _ = emit_to.take_needed(&mut self.best_ordering); + let _ = emit_to.take_needed(&mut self.has_ordering); + let arrays = self + .value_converter + .convert_rows(value_rows.iter().map(|r| r.row()))?; + Ok(Arc::clone(&arrays[0])) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let value_rows = emit_to.take_needed(&mut self.best_value); + let ordering_rows = emit_to.take_needed(&mut self.best_ordering); + let _ = emit_to.take_needed(&mut self.has_ordering); + let value_arrays = self + .value_converter + .convert_rows(value_rows.iter().map(|r| r.row()))?; + let ordering_arrays = self + .ordering_converter + .convert_rows(ordering_rows.iter().map(|r| r.row()))?; + Ok(vec![ + Arc::clone(&value_arrays[0]), + Arc::clone(&ordering_arrays[0]), + ]) + } + + fn size(&self) -> usize { + size_of_val(self) + + (self.best_value.capacity() + self.best_ordering.capacity()) + * std::mem::size_of::() + + self.has_ordering.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Float64Array, Int32Array, StringArray}; + + fn max_by_acc(value_type: DataType, ordering_type: DataType) -> MaxMinByAccumulator { + MaxMinByAccumulator::try_new(value_type, ordering_type, true).unwrap() + } + + fn min_by_acc(value_type: DataType, ordering_type: DataType) -> MaxMinByAccumulator { + MaxMinByAccumulator::try_new(value_type, ordering_type, false).unwrap() + } + + #[test] + fn max_by_basic() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![10, 50, 20])); + acc.update_batch(&[values, ordering]).unwrap(); + assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("b")); + } + + #[test] + fn min_by_basic() { + let 
mut acc = min_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![10, 50, 20])); + acc.update_batch(&[values, ordering]).unwrap(); + assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("a")); + } + + #[test] + fn null_ordering_is_ignored() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(5)])); + acc.update_batch(&[values, ordering]).unwrap(); + // The row with ordering=None (value "b") is ignored; max ordering is 10 -> "a". + assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("a")); + } + + #[test] + fn all_null_ordering_yields_null() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + acc.update_batch(&[values, ordering]).unwrap(); + assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None)); + } + + #[test] + fn null_value_at_extremum_is_returned() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), Some(50)])); + acc.update_batch(&[values, ordering]).unwrap(); + // Max ordering 50 pairs with a null value. 
+ assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None)); + } + + #[test] + fn empty_group_yields_null() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None)); + } + + #[test] + fn max_by_nan_is_largest() { + let mut acc = max_by_acc(DataType::Utf8, DataType::Float64); + let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let ordering: ArrayRef = Arc::new(Float64Array::from(vec![1.0, f64::NAN, 2.0])); + acc.update_batch(&[values, ordering]).unwrap(); + // Spark treats NaN as the largest value, matching arrow's row ordering. + assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("b")); + } + + #[test] + fn merge_matches_single_shot() { + let single = { + let mut acc = max_by_acc(DataType::Utf8, DataType::Int32); + let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e", "f"])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 6, 3, 2, 5, 4])); + acc.update_batch(&[values, ordering]).unwrap(); + acc.evaluate().unwrap() + }; + + let mut left = max_by_acc(DataType::Utf8, DataType::Int32); + left.update_batch(&[ + Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef, + Arc::new(Int32Array::from(vec![1, 6, 3])) as ArrayRef, + ]) + .unwrap(); + let mut right = max_by_acc(DataType::Utf8, DataType::Int32); + right + .update_batch(&[ + Arc::new(StringArray::from(vec!["d", "e", "f"])) as ArrayRef, + Arc::new(Int32Array::from(vec![2, 5, 4])) as ArrayRef, + ]) + .unwrap(); + + let mut merged = max_by_acc(DataType::Utf8, DataType::Int32); + for acc in [&mut left, &mut right] { + let state = acc.state().unwrap(); + let value_arr = ScalarValue::iter_to_array(vec![state[0].clone()]).unwrap(); + let ordering_arr = ScalarValue::iter_to_array(vec![state[1].clone()]).unwrap(); + merged.merge_batch(&[value_arr, ordering_arr]).unwrap(); + } + assert_eq!(merged.evaluate().unwrap(), single); + } + + // ---- GroupsAccumulator tests ---- + + fn 
max_by_groups(value_type: DataType, ordering_type: DataType) -> MaxMinByGroupsAccumulator { + MaxMinByGroupsAccumulator::try_new(value_type, ordering_type, true).unwrap() + } + + fn min_by_groups(value_type: DataType, ordering_type: DataType) -> MaxMinByGroupsAccumulator { + MaxMinByGroupsAccumulator::try_new(value_type, ordering_type, false).unwrap() + } + + fn eval_int(acc: &mut MaxMinByGroupsAccumulator) -> Vec> { + acc.evaluate(EmitTo::All) + .unwrap() + .as_primitive::() + .iter() + .collect() + } + + #[test] + fn groups_max_by_multi_group() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + // group 0: values 10,30 orderings 1,2 -> 30 ; group 1: values 20,40 orderings 5,4 -> 20 + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 5, 2, 4])); + acc.update_batch(&[values, ordering], &[0, 1, 0, 1], None, 2) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30), Some(20)]); + } + + #[test] + fn groups_min_by_multi_group() { + let mut acc = min_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 5, 2, 4])); + acc.update_batch(&[values, ordering], &[0, 1, 0, 1], None, 2) + .unwrap(); + // group 0: min ordering 1 -> 10 ; group 1: min ordering 4 -> 40 + assert_eq!(eval_int(&mut acc), vec![Some(10), Some(40)]); + } + + #[test] + fn groups_null_ordering_and_empty_group() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + // group 0: (10,null),(30,5) -> 30 ; group 1: no rows -> null ; group 2: (40,null) -> null + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), Some(30), Some(40)])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(5), None])); + acc.update_batch(&[values, ordering], &[0, 0, 2], None, 3) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30), None, None]); + } + + 
#[test] + fn groups_null_value_at_extremum() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(9)])); + acc.update_batch(&[values, ordering], &[0, 0], None, 1) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![None]); + } + + #[test] + fn groups_merge_matches_single_shot() { + let single = { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 6, 3, 2, 5, 4])); + acc.update_batch(&[values, ordering], &[0, 0, 0, 0, 0, 0], None, 1) + .unwrap(); + eval_int(&mut acc) + }; + + let mut left = max_by_groups(DataType::Int32, DataType::Int32); + left.update_batch( + &[ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + Arc::new(Int32Array::from(vec![1, 6, 3])) as ArrayRef, + ], + &[0, 0, 0], + None, + 1, + ) + .unwrap(); + let mut right = max_by_groups(DataType::Int32, DataType::Int32); + right + .update_batch( + &[ + Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef, + Arc::new(Int32Array::from(vec![2, 5, 4])) as ArrayRef, + ], + &[0, 0, 0], + None, + 1, + ) + .unwrap(); + + let mut merged = max_by_groups(DataType::Int32, DataType::Int32); + for acc in [&mut left, &mut right] { + let state = acc.state(EmitTo::All).unwrap(); + merged.merge_batch(&state, &[0], None, 1).unwrap(); + } + assert_eq!(eval_int(&mut merged), single); + } + + #[test] + fn groups_filter_is_respected() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 9, 2])); + // Filter out the row with ordering 9, so the max becomes ordering 2 -> value 30. 
+ let filter = BooleanArray::from(vec![true, false, true]); + acc.update_batch(&[values, ordering], &[0, 0, 0], Some(&filter), 1) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30)]); + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..ce6aefedb5 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod max_min_by; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use max_min_by::MaxMinBy; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7146eaec9b..753417fe08 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -399,7 +399,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[First] -> CometFirst, classOf[Last] -> CometLast, classOf[Max] -> CometMax, + classOf[MaxBy] -> CometMaxBy, classOf[Min] -> CometMin, + classOf[MinBy] -> CometMinBy, classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..4f02329b5f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, MaxBy, MaxMinBy, Min, MinBy, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} @@ -105,6 +105,82 @@ object CometMax extends CometAggregateExpressionSerde[Max] { } } +/** + * Shared serde for `max_by` and `min_by`. Both are 2-argument `MaxMinBy` `DeclarativeAggregate`s + * differing only in the comparison direction, so the value/ordering handling is identical. + */ +abstract class CometMaxMinBy[T <: MaxMinBy] extends CometAggregateExpressionSerde[T] { + + /** `true` for `max_by`, `false` for `min_by`. */ + protected def isMax: Boolean + + /** `"maximum"` or `"minimum"`, used in the non-determinism note. */ + private def extremum: String = if (isMax) "maximum" else "minimum" + + override def getCompatibleNotes(): Seq[String] = Seq( + s"This function is non-deterministic when multiple rows share the $extremum ordering value." + + " Results may differ from Spark in that case.") + + override def getUnsupportedReasons(): Seq[String] = Seq( + "The value and ordering must both be fixed-length types (boolean, integral, floating-point," + + " decimal, date, or timestamp). 
A variable-length or nested type such as string, binary, or" + + " struct forces Spark's `SortAggregate`, which Comet does not accelerate, so the aggregate" + + " falls back to Spark.") + + override def getSupportLevel(expr: T): SupportLevel = { + // Both the value and ordering must be fixed-length types. Spark only uses HashAggregate + // (the aggregate operator Comet accelerates) when the aggregation buffer is mutable; the + // buffer holds both the running value and the running ordering, so a variable-length type + // such as StringType in either position forces SortAggregate and falls back to Spark. + // The native side compares the ordering column via Arrow's row format, which supports all + // of these fixed-length orderable types. + if (!AggSerde.minMaxDataTypeSupported(expr.valueExpr.dataType)) { + Unsupported(Some(s"Unsupported value data type: ${expr.valueExpr.dataType}")) + } else if (!AggSerde.minMaxDataTypeSupported(expr.orderingExpr.dataType)) { + Unsupported(Some(s"Unsupported ordering data type: ${expr.orderingExpr.dataType}")) + } else { + Compatible() + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: T, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val valueExpr = exprToProto(expr.valueExpr, inputs, binding) + val orderingExpr = exprToProto(expr.orderingExpr, inputs, binding) + + if (valueExpr.isDefined && orderingExpr.isDefined) { + val aggBuilder = ExprOuterClass.AggExpr.newBuilder() + if (isMax) { + val builder = ExprOuterClass.MaxBy.newBuilder() + builder.setValue(valueExpr.get) + builder.setOrdering(orderingExpr.get) + aggBuilder.setMaxBy(builder) + } else { + val builder = ExprOuterClass.MinBy.newBuilder() + builder.setValue(valueExpr.get) + builder.setOrdering(orderingExpr.get) + aggBuilder.setMinBy(builder) + } + Some(aggBuilder.build()) + } else { + withFallbackReason(aggExpr, expr.valueExpr, expr.orderingExpr) + None + } + } +} + +object CometMaxBy extends 
CometMaxMinBy[MaxBy] { + override protected def isMax: Boolean = true +} + +object CometMinBy extends CometMaxMinBy[MinBy] { + override protected def isMax: Boolean = false +} + object CometCount extends CometAggregateExpressionSerde[Count] { override def convert( aggExpr: AggregateExpression, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql new file mode 100644 index 0000000000..c0463d375e --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql @@ -0,0 +1,231 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- max_by(x, y) returns the value of x associated with the maximum value of y. +-- +-- Both the value (x) and the ordering (y) must be fixed-length types: Spark only uses +-- HashAggregate (the operator Comet accelerates) when the aggregation buffer is mutable, so a +-- variable-length x or y (such as string) forces SortAggregate and falls back to Spark. +-- +-- Ordering values are kept unique within each group so results are deterministic (max_by is +-- non-deterministic when several rows tie on the maximum ordering). 
+ +-- ============================================================ +-- Setup: tables +-- ============================================================ + +statement +CREATE TABLE mb_src(v int, ord int, grp string) USING parquet + +statement +INSERT INTO mb_src VALUES + (10, 10, 'g1'), (20, 50, 'g1'), (30, 20, 'g1'), + (40, 40, 'g2'), (50, 5, 'g2'), (60, 30, 'g2'), + (70, 99, 'g3') + +-- ordering NULLs are ignored; a group of all-NULL orderings yields NULL +statement +CREATE TABLE mb_nulls(v int, ord int, grp string) USING parquet + +statement +INSERT INTO mb_nulls VALUES + (1, 10, 'g1'), (2, NULL, 'g1'), (3, 5, 'g1'), + (4, NULL, 'g2'), (5, NULL, 'g2'), + (6, 7, 'g3'), (NULL, 100, 'g3') + +statement +CREATE TABLE mb_empty(v int, ord int) USING parquet + +-- ============================================================ +-- Global aggregate (no GROUP BY) +-- ============================================================ + +query +SELECT max_by(v, ord) FROM mb_src + +-- ============================================================ +-- GROUP BY +-- ============================================================ + +query +SELECT grp, max_by(v, ord) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- NULL handling: NULL orderings ignored; the value paired with the +-- maximum ordering may itself be NULL; all-NULL orderings yield NULL +-- ============================================================ + +query +SELECT grp, max_by(v, ord) FROM mb_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Empty table yields NULL +-- ============================================================ + +query +SELECT max_by(v, ord) FROM mb_empty + +-- ============================================================ +-- Literal arguments (evaluated natively; constant folding is disabled) +-- ============================================================ + +query +SELECT max_by(5, 10), max_by(CAST(NULL 
AS INT), 20) + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, max_by(v, ord), count(*), max(ord) +FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- BigInt value +-- ============================================================ + +statement +CREATE TABLE mb_long(val bigint, ord int, grp string) USING parquet + +statement +INSERT INTO mb_long VALUES + (1000000000000, 1, 'a'), (2000000000000, 3, 'a'), (3000000000000, 2, 'a'), + (4000000000000, 5, 'b'), (5000000000000, 4, 'b') + +query +SELECT grp, max_by(val, ord) FROM mb_long GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double value and double ordering (NaN is the maximum in Spark) +-- ============================================================ + +statement +CREATE TABLE mb_dbl(v double, ord double, grp string) USING parquet + +statement +INSERT INTO mb_dbl VALUES + (1.1, 1.5, 'g1'), (2.2, 2.5, 'g1'), (3.3, 0.5, 'g1'), + (4.4, 1.0, 'g2'), (5.5, CAST('NaN' AS DOUBLE), 'g2'), (6.6, 100.0, 'g2'), + (7.7, CAST('Infinity' AS DOUBLE), 'g3'), (8.8, 3.0, 'g3') + +query +SELECT grp, max_by(v, ord) FROM mb_dbl GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal value and decimal ordering +-- ============================================================ + +statement +CREATE TABLE mb_dec(v decimal(10,2), ord decimal(10,2), grp string) USING parquet + +statement +INSERT INTO mb_dec VALUES + (10.01, 1.50, 'g1'), (20.02, 9.99, 'g1'), (30.03, 5.00, 'g1'), + (40.04, 2.00, 'g2'), (50.05, 8.25, 'g2') + +query +SELECT grp, max_by(v, ord) FROM mb_dec GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp value and ordering +-- 
============================================================ + +statement +CREATE TABLE mb_dt(d date, ts timestamp, ord int, grp string) USING parquet + +statement +INSERT INTO mb_dt VALUES + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 1, 'g1'), + (DATE '2024-06-15', TIMESTAMP '2024-06-15 12:30:00', 3, 'g1'), + (DATE '2023-12-31', TIMESTAMP '2023-12-31 23:59:59', 2, 'g1'), + (DATE '2024-03-01', TIMESTAMP '2024-03-01 08:00:00', 1, 'g2') + +query +SELECT grp, max_by(d, ord) FROM mb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, max_by(ts, ord) FROM mb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp as the ordering column +-- ============================================================ + +query +SELECT grp, max_by(ord, d) FROM mb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, max_by(ord, ts) FROM mb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Negative and boundary ordering values +-- ============================================================ + +statement +CREATE TABLE mb_bound(v int, ord bigint, grp string) USING parquet + +statement +INSERT INTO mb_bound VALUES + (1, -100, 'g1'), (2, -5, 'g1'), (3, -50, 'g1'), + (4, -9223372036854775808, 'g2'), (5, 9223372036854775807, 'g2'), (6, 0, 'g2'), + (7, -2147483648, 'g3'), (8, 2147483647, 'g3') + +query +SELECT grp, max_by(v, ord) FROM mb_bound GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double ordering with -Infinity and Infinity +-- ============================================================ + +statement +CREATE TABLE mb_inf(v int, ord double, grp string) USING parquet + +statement +INSERT INTO mb_inf VALUES + (1, CAST('-Infinity' AS DOUBLE), 'g1'), (2, -1.0, 'g1'), (3, CAST('Infinity' AS DOUBLE), 'g1'), + (4, CAST('-Infinity' AS DOUBLE), 'g2'), (5, -2.0, 'g2') + +query +SELECT grp, max_by(v, ord) FROM mb_inf GROUP BY grp ORDER 
BY grp + +-- ============================================================ +-- Multiple max_by in one query +-- ============================================================ + +query +SELECT grp, max_by(v, ord), max_by(ord, v) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Value and ordering are the same column +-- ============================================================ + +query +SELECT grp, max_by(ord, ord) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Boolean value +-- ============================================================ + +statement +CREATE TABLE mb_bool(v boolean, ord int, grp string) USING parquet + +statement +INSERT INTO mb_bool VALUES + (true, 1, 'a'), (false, 3, 'a'), (true, 2, 'a'), + (false, 5, 'b'), (true, 4, 'b') + +query +SELECT grp, max_by(v, ord) FROM mb_bool GROUP BY grp ORDER BY grp diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql new file mode 100644 index 0000000000..0f0101e1e7 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql @@ -0,0 +1,216 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. 
See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- min_by(x, y) returns the value of x associated with the minimum value of y. +-- +-- Both the value (x) and the ordering (y) must be fixed-length types: Spark only uses +-- HashAggregate (the operator Comet accelerates) when the aggregation buffer is mutable, so a +-- variable-length x or y (such as string) forces SortAggregate and falls back to Spark. +-- +-- Ordering values are kept unique within each group so results are deterministic (min_by is +-- non-deterministic when several rows tie on the minimum ordering). 
orderings yield NULL +-- ============================================================ + +query +SELECT grp, min_by(v, ord) FROM nb_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Empty table yields NULL +-- ============================================================ + +query +SELECT min_by(v, ord) FROM nb_empty + +-- ============================================================ +-- Literal arguments (evaluated natively; constant folding is disabled) +-- ============================================================ + +query +SELECT min_by(5, 10), min_by(CAST(NULL AS INT), 20) + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, min_by(v, ord), count(*), min(ord) +FROM nb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- BigInt value +-- ============================================================ + +statement +CREATE TABLE nb_long(val bigint, ord int, grp string) USING parquet + +statement +INSERT INTO nb_long VALUES + (1000000000000, 1, 'a'), (2000000000000, 3, 'a'), (3000000000000, 2, 'a'), + (4000000000000, 5, 'b'), (5000000000000, 4, 'b') + +query +SELECT grp, min_by(val, ord) FROM nb_long GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double value and double ordering (NaN is the maximum in Spark, so it is +-- never selected by min_by unless it is the only non-null ordering) +-- ============================================================ + +statement +CREATE TABLE nb_dbl(v double, ord double, grp string) USING parquet + +statement +INSERT INTO nb_dbl VALUES + (1.1, 1.5, 'g1'), (2.2, 2.5, 'g1'), (3.3, 0.5, 'g1'), + (4.4, 1.0, 'g2'), (5.5, CAST('NaN' AS DOUBLE), 'g2'), (6.6, 100.0, 'g2'), + (7.7, CAST('-Infinity' AS DOUBLE), 'g3'), (8.8, 3.0, 'g3') + +query +SELECT grp, min_by(v, 
ord) FROM nb_dbl GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal value and decimal ordering +-- ============================================================ + +statement +CREATE TABLE nb_dec(v decimal(10,2), ord decimal(10,2), grp string) USING parquet + +statement +INSERT INTO nb_dec VALUES + (10.01, 1.50, 'g1'), (20.02, 9.99, 'g1'), (30.03, 5.00, 'g1'), + (40.04, 2.00, 'g2'), (50.05, 8.25, 'g2') + +query +SELECT grp, min_by(v, ord) FROM nb_dec GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp value and ordering +-- ============================================================ + +statement +CREATE TABLE nb_dt(d date, ts timestamp, ord int, grp string) USING parquet + +statement +INSERT INTO nb_dt VALUES + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 1, 'g1'), + (DATE '2024-06-15', TIMESTAMP '2024-06-15 12:30:00', 3, 'g1'), + (DATE '2023-12-31', TIMESTAMP '2023-12-31 23:59:59', 2, 'g1'), + (DATE '2024-03-01', TIMESTAMP '2024-03-01 08:00:00', 1, 'g2') + +query +SELECT grp, min_by(d, ord) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ts, ord) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, d) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, ts) FROM nb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Negative and boundary ordering values +-- ============================================================ + +statement +CREATE TABLE nb_bound(v int, ord bigint, grp string) USING parquet + +statement +INSERT INTO nb_bound VALUES + (1, -100, 'g1'), (2, -5, 'g1'), (3, -50, 'g1'), + (4, -9223372036854775808, 'g2'), (5, 9223372036854775807, 'g2'), (6, 0, 'g2'), + (7, -2147483648, 'g3'), (8, 2147483647, 'g3') + +query +SELECT grp, min_by(v, ord) FROM nb_bound GROUP BY grp ORDER BY grp + +-- 
============================================================ +-- Multiple min_by and value = ordering column +-- ============================================================ + +query +SELECT grp, min_by(v, ord), min_by(ord, v) FROM nb_src GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, ord) FROM nb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Boolean value +-- ============================================================ + +statement +CREATE TABLE nb_bool(v boolean, ord int, grp string) USING parquet + +statement +INSERT INTO nb_bool VALUES + (true, 1, 'a'), (false, 3, 'a'), (true, 2, 'a'), + (false, 5, 'b'), (true, 4, 'b') + +query +SELECT grp, min_by(v, ord) FROM nb_bool GROUP BY grp ORDER BY grp + +-- ============================================================ +-- max_by and min_by in the same query +-- ============================================================ + +query +SELECT grp, max_by(v, ord), min_by(v, ord) FROM nb_src GROUP BY grp ORDER BY grp diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index a9ee46802a..07d6446857 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -125,6 +125,28 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "percentile_double_high_card", "SELECT percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + // max_by / min_by run natively only when both the value and ordering are fixed-length types (a + // variable-length value or ordering forces Spark's SortAggregate, which Comet does not run). 
+ private val maxMinByAggregates = List( + AggExprConfig("max_by_int", "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_double", + "SELECT max_by(c_double, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_decimal", + "SELECT max_by(c_decimal, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_high_card", + "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY high_card_grp"), + AggExprConfig("max_by_global", "SELECT max_by(c_int, c_long) FROM parquetV1Table"), + AggExprConfig("min_by_int", "SELECT min_by(c_int, c_long) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "min_by_double", + "SELECT min_by(c_double, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "min_by_high_card", + "SELECT min_by(c_int, c_long) FROM parquetV1Table GROUP BY high_card_grp")) + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024 @@ -148,7 +170,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { val allAggregates = basicAggregates ++ statisticalAggregates ++ bitwiseAggregates ++ multiKeyAggregates ++ multiAggregates ++ decimalAggregates ++ - highCardinalityAggregates ++ percentileAggregates + highCardinalityAggregates ++ percentileAggregates ++ maxMinByAggregates allAggregates.foreach { config => runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs)