From ae5a877a618a02e666527c552851511cb86d9f47 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 09:42:29 -0600 Subject: [PATCH 1/2] feat: support max_by aggregate expression Add native support for Spark's max_by(x, y) aggregate, which returns the value of x associated with the maximum value of y. The expression is implemented as a native DataFusion aggregate (MaxMinBy) that compares the ordering column via Arrow's row format. Null orderings are ignored, the value paired with the maximum ordering is returned (and may itself be null), and an all-null-ordering group yields null. Both the value and ordering must be fixed-length types: a variable-length or nested type forces Spark's SortAggregate, which Comet does not run, so those cases fall back to Spark. --- .../expression-audits/agg_funcs.md | 7 + docs/source/user-guide/latest/expressions.md | 2 +- native/core/src/execution/planner.rs | 11 +- native/proto/src/proto/expr.proto | 8 + native/spark-expr/src/agg_funcs/max_min_by.rs | 338 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/aggregates.scala | 56 ++- .../expressions/aggregate/max_by.sql | 231 ++++++++++++ .../CometAggregateExpressionBenchmark.scala | 17 +- 10 files changed, 668 insertions(+), 5 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/max_min_by.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..1baf3115ee 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -39,6 +39,13 @@ - Spark 3.5.8 (2026-05-26) - Spark 4.0.1 (2026-05-26) +## max_by + +- Spark 3.4.3 (2026-07-03): `MaxBy` is a 2-argument `DeclarativeAggregate` registered as 
`expression[MaxBy]("max_by")`. Buffer is `(valueWithExtremumOrdering, extremumOrdering)`; null orderings are ignored, the value paired with the maximum ordering is returned (and may itself be null), and an all-null-ordering group yields null. Comet implements a native `max_by` aggregate. Only fixed-length value and ordering types are accelerated: a variable-length or nested type (string, binary, struct) forces Spark's `SortAggregate`, which Comet does not support, so those cases fall back to Spark. `max_by` is non-deterministic when several rows tie on the maximum ordering, matching Spark's documented behavior. +- Spark 3.5.8 (2026-07-03): aggregate logic identical to 3.4.3. +- Spark 4.0.1 (2026-07-03): aggregate logic identical to 3.4.3; only the `@ExpressionDescription` example and note text differ. +- Spark 4.1.1 (2026-07-03): aggregate logic identical to 3.4.3. The 3-argument top-k form `max_by(x, y, k)` (via `MaxByBuilder` / `MaxMinByK`) is only present on Spark master, not in any released 3.4 through 4.1 version, so Comet handles only the 2-argument form. + ## median - Spark 3.4.3 (audited 2026-06-24): `Median(child)` is a `RuntimeReplaceableAggregate` with `replacement = Percentile(child, Literal(0.5))`. Catalyst rewrites `median(x)` to `percentile(x, 0.5)` before Comet sees the plan, so it is served by `CometPercentile`. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 2ca3a13c62..c696d801ec 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -95,7 +95,7 @@ The tables below list every Spark built-in expression with its current status. 
| `last_value` | ✅ | | | `listagg` | 🔜 | String aggregation | | `max` | ✅ | | -| `max_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | +| `max_by` | ✅ | Value and ordering must be fixed-length types | | `mean` | ✅ | | | `median` | ✅ | Rewrites to `percentile(col, 0.5)`; falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `min` | ✅ | | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..82d3ed1984 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, IfExpr, ListExtract, MaxMinBy, NormalizeNaNAndZero, SparkCastOptions, Stddev, + SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2653,6 +2653,13 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::MaxBy(expr) => { + let value = self.create_expr(expr.value.as_ref().unwrap(), Arc::clone(&schema))?; + let ordering = + self.create_expr(expr.ordering.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(MaxMinBy::new_max_by()); + Self::create_aggr_func_expr("max_by", schema, vec![value, ordering], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 32adc16b72..d164c357c1 100644 --- 
a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -146,6 +146,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + MaxBy maxBy = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -277,6 +278,13 @@ message CollectSet { DataType datatype = 2; } +message MaxBy { + // The value returned by the aggregate (associated with the maximum ordering). + Expr value = 1; + // The ordering expression whose maximum selects the returned value. + Expr ordering = 2; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/max_min_by.rs b/native/spark-expr/src/agg_funcs/max_min_by.rs new file mode 100644 index 0000000000..b4de0e1e37 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/max_min_by.rs @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use arrow::array::{Array, ArrayRef};
+use arrow::compute::SortOptions;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use arrow::row::RowConverter;
+use arrow::row::SortField;
+use datafusion::common::{Result, ScalarValue};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion::physical_expr::expressions::format_state_name;
+use std::mem::size_of_val;
+use std::sync::Arc;
+
+/// Spark-compatible `max_by(value, ordering)` / `min_by(value, ordering)` aggregate.
+///
+/// Returns the `value` associated with the maximum (`max_by`) or minimum (`min_by`)
+/// non-null `ordering`. Rows with a null `ordering` are ignored. The returned value
+/// may itself be null when it is the value paired with the extremum ordering. If every
+/// `ordering` in the group is null, the result is null.
+///
+/// Spark's `MaxBy`/`MinBy` are `DeclarativeAggregate`s that keep a `(value, ordering)`
+/// buffer and, on a tie in the ordering, the later row wins. Because ties across
+/// partitions are processed in an unspecified order, Spark documents the function as
+/// non-deterministic when several rows share the extremum ordering.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct MaxMinBy {
+    name: String,
+    signature: Signature,
+    /// `true` for `max_by`, `false` for `min_by`.
+    is_max: bool,
+}
+
+impl std::hash::Hash for MaxMinBy {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.signature.hash(state);
+        self.is_max.hash(state);
+    }
+}
+
+impl MaxMinBy {
+    /// Create a `max_by` aggregate.
+    pub fn new_max_by() -> Self {
+        Self {
+            name: "max_by".to_string(),
+            signature: Signature::any(2, Volatility::Immutable),
+            is_max: true,
+        }
+    }
+
+    /// Create a `min_by` aggregate.
+    #[allow(dead_code)]
+    pub fn new_min_by() -> Self {
+        Self {
+            name: "min_by".to_string(),
+            signature: Signature::any(2, Volatility::Immutable),
+            is_max: false,
+        }
+    }
+}
+
+impl AggregateUDFImpl for MaxMinBy {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // The result has the same type as the `value` argument.
+        Ok(arg_types[0].clone())
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let value_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+        let ordering_type = acc_args.exprs[1].data_type(acc_args.schema)?;
+        Ok(Box::new(MaxMinByAccumulator::try_new(
+            value_type,
+            ordering_type,
+            self.is_max,
+        )?))
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        let value_type = args.input_fields[0].data_type().clone();
+        let ordering_type = args.input_fields[1].data_type().clone();
+        Ok(vec![
+            Arc::new(Field::new(
+                format_state_name(&self.name, "value"),
+                value_type,
+                true,
+            )),
+            Arc::new(Field::new(
+                format_state_name(&self.name, "ordering"),
+                ordering_type,
+                true,
+            )),
+        ])
+    }
+}
+
+/// Accumulator that tracks the running `(value, ordering)` pair for the extremum ordering.
+#[derive(Debug)]
+struct MaxMinByAccumulator {
+    /// The value paired with the current extremum ordering. May be null.
+    value: ScalarValue,
+    /// The current extremum ordering. Null means no non-null ordering has been seen yet.
+    ordering: ScalarValue,
+    /// `true` for `max_by`, `false` for `min_by`.
+    is_max: bool,
+}
+
+impl MaxMinByAccumulator {
+    fn try_new(value_type: DataType, ordering_type: DataType, is_max: bool) -> Result<Self> {
+        Ok(Self {
+            value: ScalarValue::try_from(&value_type)?,
+            ordering: ScalarValue::try_from(&ordering_type)?,
+            is_max,
+        })
+    }
+
+    fn sort_options(&self) -> SortOptions {
+        // Encode the ordering column into arrow's row format so that the extremum can be
+        // found for any orderable type with a single comparison. For `max_by` we sort
+        // ascending, so the largest ordering yields the largest row bytes. For `min_by`
+        // we sort descending, so the smallest ordering yields the largest row bytes; the
+        // same "argmax of the row bytes" scan then selects the minimum.
+        SortOptions {
+            descending: !self.is_max,
+            nulls_first: true,
+        }
+    }
+
+    /// Apply a batch of `(value, ordering)` columns, keeping the value paired with the
+    /// extremum ordering. Rows with a null ordering are ignored.
+    fn update_from(&mut self, value_arr: &ArrayRef, ordering_arr: &ArrayRef) -> Result<()> {
+        if ordering_arr.is_empty() {
+            return Ok(());
+        }
+
+        let converter = RowConverter::new(vec![SortField::new_with_options(
+            ordering_arr.data_type().clone(),
+            self.sort_options(),
+        )])?;
+        let rows = converter.convert_columns(&[Arc::clone(ordering_arr)])?;
+
+        // Find the index of the extremum ordering in this batch (last one wins on a tie,
+        // matching Spark's sequential row processing), ignoring null orderings.
+        let mut best: Option<usize> = None;
+        for i in 0..ordering_arr.len() {
+            if ordering_arr.is_null(i) {
+                continue;
+            }
+            best = match best {
+                None => Some(i),
+                Some(b) if rows.row(i) >= rows.row(b) => Some(i),
+                Some(b) => Some(b),
+            };
+        }
+
+        let Some(b) = best else {
+            return Ok(());
+        };
+
+        let candidate_ordering = ScalarValue::try_from_array(ordering_arr, b)?;
+        let take = if self.ordering.is_null() {
+            true
+        } else {
+            // Compare the batch's extremum ordering against the running extremum using the
+            // same row encoding. Build a two-row array [running, candidate] and compare.
+            let pair = ScalarValue::iter_to_array(vec![
+                self.ordering.clone(),
+                candidate_ordering.clone(),
+            ])?;
+            let pair_rows = converter.convert_columns(&[pair])?;
+            pair_rows.row(1) >= pair_rows.row(0)
+        };
+
+        if take {
+            self.value = ScalarValue::try_from_array(value_arr, b)?;
+            self.ordering = candidate_ordering;
+        }
+
+        Ok(())
+    }
+}
+
+impl Accumulator for MaxMinByAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update_from(&values[0], &values[1])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // State columns mirror the input columns: [value, ordering].
+        self.update_from(&states[0], &states[1])
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.value.clone(), self.ordering.clone()])
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(self.value.clone())
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self) + self.value.size() + self.ordering.size()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Float64Array, Int32Array, StringArray};
+
+    fn max_by_acc(value_type: DataType, ordering_type: DataType) -> MaxMinByAccumulator {
+        MaxMinByAccumulator::try_new(value_type, ordering_type, true).unwrap()
+    }
+
+    fn min_by_acc(value_type: DataType, ordering_type: DataType) -> MaxMinByAccumulator {
+        MaxMinByAccumulator::try_new(value_type, ordering_type, false).unwrap()
+    }
+
+    #[test]
+    fn max_by_basic() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+        let ordering: ArrayRef = Arc::new(Int32Array::from(vec![10, 50, 20]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("b"));
+    }
+
+    #[test]
+    fn min_by_basic() {
+        let mut acc = min_by_acc(DataType::Utf8, DataType::Int32);
+        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+        let ordering: ArrayRef = Arc::new(Int32Array::from(vec![10, 50, 20]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("a"));
+    }
+
+    #[test]
+    fn null_ordering_is_ignored() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+        let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));
+        let ordering: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(5)]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        // The row with ordering=None (value "b") is ignored; max ordering is 10 -> "a".
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("a"));
+    }
+
+    #[test]
+    fn all_null_ordering_yields_null() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+        let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
+        let ordering: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None));
+    }
+
+    #[test]
+    fn null_value_at_extremum_is_returned() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+        let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None]));
+        let ordering: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), Some(50)]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        // Max ordering 50 pairs with a null value.
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None));
+    }
+
+    #[test]
+    fn empty_group_yields_null() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Utf8(None));
+    }
+
+    #[test]
+    fn max_by_nan_is_largest() {
+        let mut acc = max_by_acc(DataType::Utf8, DataType::Float64);
+        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
+        let ordering: ArrayRef = Arc::new(Float64Array::from(vec![1.0, f64::NAN, 2.0]));
+        acc.update_batch(&[values, ordering]).unwrap();
+        // Spark treats NaN as the largest value, matching arrow's row ordering.
+        assert_eq!(acc.evaluate().unwrap(), ScalarValue::from("b"));
+    }
+
+    #[test]
+    fn merge_matches_single_shot() {
+        let single = {
+            let mut acc = max_by_acc(DataType::Utf8, DataType::Int32);
+            let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e", "f"]));
+            let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 6, 3, 2, 5, 4]));
+            acc.update_batch(&[values, ordering]).unwrap();
+            acc.evaluate().unwrap()
+        };
+
+        let mut left = max_by_acc(DataType::Utf8, DataType::Int32);
+        left.update_batch(&[
+            Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
+            Arc::new(Int32Array::from(vec![1, 6, 3])) as ArrayRef,
+        ])
+        .unwrap();
+        let mut right = max_by_acc(DataType::Utf8, DataType::Int32);
+        right
+            .update_batch(&[
+                Arc::new(StringArray::from(vec!["d", "e", "f"])) as ArrayRef,
+                Arc::new(Int32Array::from(vec![2, 5, 4])) as ArrayRef,
+            ])
+            .unwrap();
+
+        let mut merged = max_by_acc(DataType::Utf8, DataType::Int32);
+        for acc in [&mut left, &mut right] {
+            let state = acc.state().unwrap();
+            let value_arr = ScalarValue::iter_to_array(vec![state[0].clone()]).unwrap();
+            let ordering_arr = ScalarValue::iter_to_array(vec![state[1].clone()]).unwrap();
+            merged.merge_batch(&[value_arr, ordering_arr]).unwrap();
+        }
+        assert_eq!(merged.evaluate().unwrap(), single);
+    }
+}
diff --git 
a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..ce6aefedb5 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod max_min_by; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use max_min_by::MaxMinBy; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7146eaec9b..1785fa384e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -399,6 +399,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[First] -> CometFirst, classOf[Last] -> CometLast, classOf[Max] -> CometMax, + classOf[MaxBy] -> CometMaxBy, classOf[Min] -> CometMin, classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..39878c9c4c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, 
VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, MaxBy, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType}
 
@@ -105,6 +105,60 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
   }
 }
 
+/**
+ * Serializes Spark's 2-argument `max_by(value, ordering)` into the native `MaxBy` aggregate
+ * message. Both argument types are gated by `AggSerde.minMaxDataTypeSupported`.
+ */
+object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] {
+
+  override def getCompatibleNotes(): Seq[String] = Seq(
+    "This function is non-deterministic when multiple rows share the maximum ordering value." +
+      " Results may differ from Spark in that case.")
+
+  override def getUnsupportedReasons(): Seq[String] = Seq(
+    "The value and ordering must both be fixed-length types (boolean, integral, floating-point," +
+      " decimal, date, or timestamp). A variable-length or nested type such as string, binary, or" +
+      " struct forces Spark's `SortAggregate`, which Comet does not accelerate, so the aggregate" +
+      " falls back to Spark.")
+
+  override def getSupportLevel(expr: MaxBy): SupportLevel = {
+    // Both the value and ordering must be fixed-length types. Spark only uses HashAggregate
+    // (the aggregate operator Comet accelerates) when the aggregation buffer is mutable; the
+    // buffer holds both the running value and the running ordering, so a variable-length type
+    // such as StringType in either position forces SortAggregate and falls back to Spark.
+    // The native side compares the ordering column via Arrow's row format, which supports all
+    // of these fixed-length orderable types.
+    if (!AggSerde.minMaxDataTypeSupported(expr.valueExpr.dataType)) {
+      Unsupported(Some(s"Unsupported value data type: ${expr.valueExpr.dataType}"))
+    } else if (!AggSerde.minMaxDataTypeSupported(expr.orderingExpr.dataType)) {
+      Unsupported(Some(s"Unsupported ordering data type: ${expr.orderingExpr.dataType}"))
+    } else {
+      Compatible()
+    }
+  }
+
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: MaxBy,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    // The aggregate is only emitted when both child expressions serialize successfully.
+    val valueExpr = exprToProto(expr.valueExpr, inputs, binding)
+    val orderingExpr = exprToProto(expr.orderingExpr, inputs, binding)
+
+    (valueExpr, orderingExpr) match {
+      case (Some(value), Some(ordering)) =>
+        val builder = ExprOuterClass.MaxBy.newBuilder().setValue(value).setOrdering(ordering)
+        Some(ExprOuterClass.AggExpr.newBuilder().setMaxBy(builder).build())
+      case _ =>
+        // At least one child did not serialize: record the reason and fall back to Spark.
+        withFallbackReason(aggExpr, expr.valueExpr, expr.orderingExpr)
+        None
+    }
+  }
+}
+
 object CometCount extends CometAggregateExpressionSerde[Count] {
   override def convert(
       aggExpr: AggregateExpression,
diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql
new file mode 100644
index 0000000000..c0463d375e
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/aggregate/max_by.sql
@@ -0,0 +1,231 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- max_by(x, y) returns the value of x associated with the maximum value of y. +-- +-- Both the value (x) and the ordering (y) must be fixed-length types: Spark only uses +-- HashAggregate (the operator Comet accelerates) when the aggregation buffer is mutable, so a +-- variable-length type such as string in either position forces SortAggregate (Spark fallback). +-- +-- Ordering values are kept unique within each group so results are deterministic (max_by is +-- non-deterministic when several rows tie on the maximum ordering). + +-- ============================================================ +-- Setup: tables +-- ============================================================ + +statement +CREATE TABLE mb_src(v int, ord int, grp string) USING parquet + +statement +INSERT INTO mb_src VALUES + (10, 10, 'g1'), (20, 50, 'g1'), (30, 20, 'g1'), + (40, 40, 'g2'), (50, 5, 'g2'), (60, 30, 'g2'), + (70, 99, 'g3') + +-- ordering NULLs are ignored; a group of all-NULL orderings yields NULL +statement +CREATE TABLE mb_nulls(v int, ord int, grp string) USING parquet + +statement +INSERT INTO mb_nulls VALUES + (1, 10, 'g1'), (2, NULL, 'g1'), (3, 5, 'g1'), + (4, NULL, 'g2'), (5, NULL, 'g2'), + (6, 7, 'g3'), (NULL, 100, 'g3') + +statement +CREATE TABLE mb_empty(v int, ord int) USING parquet + +-- ============================================================ +-- Global aggregate (no GROUP BY) +-- ============================================================ + +query +SELECT max_by(v, ord) FROM mb_src + +-- ============================================================ +-- GROUP BY +-- 
============================================================ + +query +SELECT grp, max_by(v, ord) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- NULL handling: NULL orderings ignored; the value paired with the +-- maximum ordering may itself be NULL; all-NULL orderings yield NULL +-- ============================================================ + +query +SELECT grp, max_by(v, ord) FROM mb_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Empty table yields NULL +-- ============================================================ + +query +SELECT max_by(v, ord) FROM mb_empty + +-- ============================================================ +-- Literal arguments (evaluated natively; constant folding is disabled) +-- ============================================================ + +query +SELECT max_by(5, 10), max_by(CAST(NULL AS INT), 20) + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, max_by(v, ord), count(*), max(ord) +FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- BigInt value +-- ============================================================ + +statement +CREATE TABLE mb_long(val bigint, ord int, grp string) USING parquet + +statement +INSERT INTO mb_long VALUES + (1000000000000, 1, 'a'), (2000000000000, 3, 'a'), (3000000000000, 2, 'a'), + (4000000000000, 5, 'b'), (5000000000000, 4, 'b') + +query +SELECT grp, max_by(val, ord) FROM mb_long GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double value and double ordering (NaN is the maximum in Spark) +-- ============================================================ + +statement +CREATE TABLE mb_dbl(v double, ord double, grp string) USING parquet + +statement +INSERT 
INTO mb_dbl VALUES + (1.1, 1.5, 'g1'), (2.2, 2.5, 'g1'), (3.3, 0.5, 'g1'), + (4.4, 1.0, 'g2'), (5.5, CAST('NaN' AS DOUBLE), 'g2'), (6.6, 100.0, 'g2'), + (7.7, CAST('Infinity' AS DOUBLE), 'g3'), (8.8, 3.0, 'g3') + +query +SELECT grp, max_by(v, ord) FROM mb_dbl GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal value and decimal ordering +-- ============================================================ + +statement +CREATE TABLE mb_dec(v decimal(10,2), ord decimal(10,2), grp string) USING parquet + +statement +INSERT INTO mb_dec VALUES + (10.01, 1.50, 'g1'), (20.02, 9.99, 'g1'), (30.03, 5.00, 'g1'), + (40.04, 2.00, 'g2'), (50.05, 8.25, 'g2') + +query +SELECT grp, max_by(v, ord) FROM mb_dec GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp value and ordering +-- ============================================================ + +statement +CREATE TABLE mb_dt(d date, ts timestamp, ord int, grp string) USING parquet + +statement +INSERT INTO mb_dt VALUES + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 1, 'g1'), + (DATE '2024-06-15', TIMESTAMP '2024-06-15 12:30:00', 3, 'g1'), + (DATE '2023-12-31', TIMESTAMP '2023-12-31 23:59:59', 2, 'g1'), + (DATE '2024-03-01', TIMESTAMP '2024-03-01 08:00:00', 1, 'g2') + +query +SELECT grp, max_by(d, ord) FROM mb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, max_by(ts, ord) FROM mb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp as the ordering column +-- ============================================================ + +query +SELECT grp, max_by(ord, d) FROM mb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, max_by(ord, ts) FROM mb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Negative and boundary ordering values +-- ============================================================ + +statement 
+CREATE TABLE mb_bound(v int, ord bigint, grp string) USING parquet + +statement +INSERT INTO mb_bound VALUES + (1, -100, 'g1'), (2, -5, 'g1'), (3, -50, 'g1'), + (4, -9223372036854775808, 'g2'), (5, 9223372036854775807, 'g2'), (6, 0, 'g2'), + (7, -2147483648, 'g3'), (8, 2147483647, 'g3') + +query +SELECT grp, max_by(v, ord) FROM mb_bound GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double ordering with -Infinity and Infinity +-- ============================================================ + +statement +CREATE TABLE mb_inf(v int, ord double, grp string) USING parquet + +statement +INSERT INTO mb_inf VALUES + (1, CAST('-Infinity' AS DOUBLE), 'g1'), (2, -1.0, 'g1'), (3, CAST('Infinity' AS DOUBLE), 'g1'), + (4, CAST('-Infinity' AS DOUBLE), 'g2'), (5, -2.0, 'g2') + +query +SELECT grp, max_by(v, ord) FROM mb_inf GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Multiple max_by in one query +-- ============================================================ + +query +SELECT grp, max_by(v, ord), max_by(ord, v) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Value and ordering are the same column +-- ============================================================ + +query +SELECT grp, max_by(ord, ord) FROM mb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Boolean value +-- ============================================================ + +statement +CREATE TABLE mb_bool(v boolean, ord int, grp string) USING parquet + +statement +INSERT INTO mb_bool VALUES + (true, 1, 'a'), (false, 3, 'a'), (true, 2, 'a'), + (false, 5, 'b'), (true, 4, 'b') + +query +SELECT grp, max_by(v, ord) FROM mb_bool GROUP BY grp ORDER BY grp diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala 
b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index a9ee46802a..fca15fc65d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -125,6 +125,21 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "percentile_double_high_card", "SELECT percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + // max_by runs natively only when both the value and ordering are fixed-length types (a + // variable-length value or ordering forces Spark's SortAggregate, which Comet does not run). + private val maxByAggregates = List( + AggExprConfig("max_by_int", "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_double", + "SELECT max_by(c_double, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_decimal", + "SELECT max_by(c_decimal, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "max_by_high_card", + "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY high_card_grp"), + AggExprConfig("max_by_global", "SELECT max_by(c_int, c_long) FROM parquetV1Table")) + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024 @@ -148,7 +163,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { val allAggregates = basicAggregates ++ statisticalAggregates ++ bitwiseAggregates ++ multiKeyAggregates ++ multiAggregates ++ decimalAggregates ++ - highCardinalityAggregates ++ percentileAggregates + highCardinalityAggregates ++ percentileAggregates ++ maxByAggregates allAggregates.foreach { config => runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs) From 6aa656ccce507fdc363327199b16c6901e36d35d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 09:58:33 -0600 Subject: [PATCH 2/2] feat: add min_by, 
vectorized GroupsAccumulator, and benchmark Add min_by alongside max_by, both served by the shared native MaxMinBy aggregate. Refactor the Scala serde into a shared CometMaxMinBy base. Add a vectorized GroupsAccumulator that keeps each group's best (value, ordering) pair as Arrow row-format bytes, so grouped max_by/min_by avoids the per-group ScalarValue work of the generic GroupsAccumulatorAdapter. Extend the aggregate microbenchmark and SQL file tests to cover both max_by and min_by. --- .../expression-audits/agg_funcs.md | 7 + docs/source/user-guide/latest/expressions.md | 2 +- native/core/src/execution/planner.rs | 7 + native/proto/src/proto/expr.proto | 8 + native/spark-expr/src/agg_funcs/max_min_by.rs | 322 +++++++++++++++++- .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/aggregates.scala | 50 ++- .../expressions/aggregate/min_by.sql | 216 ++++++++++++ .../CometAggregateExpressionBenchmark.scala | 15 +- 9 files changed, 595 insertions(+), 33 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index 1baf3115ee..47764dd1a3 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -53,6 +53,13 @@ - Spark 4.0.1 (audited 2026-06-24): `replacement` becomes `lazy val`; semantics unchanged. - Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1. +## min_by + +- Spark 3.4.3 (2026-07-03): `MinBy` shares the abstract `MaxMinBy` `DeclarativeAggregate` with `MaxBy`, differing only in the comparison direction (`least` / `<` instead of `greatest` / `>`). Registered as `expression[MinBy]("min_by")`. Null orderings are ignored, the value paired with the minimum ordering is returned (and may itself be null), and an all-null-ordering group yields null. 
Comet serves it through the same native `MaxMinBy` aggregate as `max_by`, with the same fixed-length value and ordering restriction (variable-length or nested types fall back to Spark). Non-deterministic on ties, matching Spark. +- Spark 3.5.8 (2026-07-03): aggregate logic identical to 3.4.3. +- Spark 4.0.1 (2026-07-03): aggregate logic identical to 3.4.3; only the `@ExpressionDescription` example and note text differ. +- Spark 4.1.1 (2026-07-03): aggregate logic identical to 3.4.3. The 3-argument top-k form `min_by(x, y, k)` (via `MinByBuilder` / `MaxMinByK`) is only present on Spark master, so Comet handles only the 2-argument form. + ## percentile - Spark 3.4.3 (audited 2026-06-24): `Percentile(child, percentageExpression, frequencyExpression, ..., reverse)` over `PercentileBase`. Exact percentile using `index = p * (n - 1)` linear interpolation, NULL inputs skipped, empty/all-null group returns NULL. `CometPercentile` maps the single-literal-percentage, default-frequency, numeric-input, ascending form to DataFusion's `percentile_cont` (same interpolation). Array-of-percentages, a non-default frequency argument, descending order, and interval inputs fall back to Spark. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index c696d801ec..9a172f0473 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -99,7 +99,7 @@ The tables below list every Spark built-in expression with its current status. 
| `mean` | ✅ | | | `median` | ✅ | Rewrites to `percentile(col, 0.5)`; falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `min` | ✅ | | -| `min_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | +| `min_by` | ✅ | Value and ordering must be fixed-length types | | `mode` | 🔜 | [#3970](https://github.com/apache/datafusion-comet/issues/3970) | | `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 82d3ed1984..77ec85ac0e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -2660,6 +2660,13 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(MaxMinBy::new_max_by()); Self::create_aggr_func_expr("max_by", schema, vec![value, ordering], func) } + AggExprStruct::MinBy(expr) => { + let value = self.create_expr(expr.value.as_ref().unwrap(), Arc::clone(&schema))?; + let ordering = + self.create_expr(expr.ordering.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(MaxMinBy::new_min_by()); + Self::create_aggr_func_expr("min_by", schema, vec![value, ordering], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index d164c357c1..300183f88d 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -147,6 +147,7 @@ message AggExpr { CollectSet collectSet = 17; Percentile percentile = 18; MaxBy maxBy = 19; + MinBy 
minBy = 20; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -285,6 +286,13 @@ message MaxBy { Expr ordering = 2; } +message MinBy { + // The value returned by the aggregate (associated with the minimum ordering). + Expr value = 1; + // The ordering expression whose minimum selects the returned value. + Expr ordering = 2; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/max_min_by.rs b/native/spark-expr/src/agg_funcs/max_min_by.rs index b4de0e1e37..f3639675b6 100644 --- a/native/spark-expr/src/agg_funcs/max_min_by.rs +++ b/native/spark-expr/src/agg_funcs/max_min_by.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef}; +use arrow::array::{new_null_array, Array, ArrayRef, BooleanArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, FieldRef}; -use arrow::row::RowConverter; -use arrow::row::SortField; +use arrow::row::{OwnedRow, RowConverter, SortField}; use datafusion::common::{Result, ScalarValue}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, +}; use datafusion::physical_expr::expressions::format_state_name; use std::mem::size_of_val; use std::sync::Arc; @@ -65,7 +66,6 @@ impl MaxMinBy { } /// Create a `min_by` aggregate. 
- #[allow(dead_code)] pub fn new_min_by() -> Self { Self { name: "min_by".to_string(), @@ -115,6 +115,34 @@ impl AggregateUDFImpl for MaxMinBy { )), ]) } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let value_type = args.exprs[0].data_type(args.schema)?; + let ordering_type = args.exprs[1].data_type(args.schema)?; + Ok(Box::new(MaxMinByGroupsAccumulator::try_new( + value_type, + ordering_type, + self.is_max, + )?)) + } +} + +/// Sort options that make the wanted extremum encode to the largest row bytes: ascending for +/// `max_by` (largest ordering wins), descending for `min_by` (smallest ordering wins). Nulls +/// sort first (smallest) so they are never selected as the extremum; null orderings are also +/// skipped explicitly. +fn extremum_sort_options(is_max: bool) -> SortOptions { + SortOptions { + descending: !is_max, + nulls_first: true, + } } /// Accumulator that tracks the running `(value, ordering)` pair for the extremum ordering. @@ -139,14 +167,8 @@ impl MaxMinByAccumulator { fn sort_options(&self) -> SortOptions { // Encode the ordering column into arrow's row format so that the extremum can be - // found for any orderable type with a single comparison. For `max_by` we sort - // ascending, so the largest ordering yields the largest row bytes. For `min_by` - // we sort descending, so the smallest ordering yields the largest row bytes; the - // same "argmax of the row bytes" scan then selects the minimum. - SortOptions { - descending: !self.is_max, - nulls_first: true, - } + // found for any orderable type with a single comparison. + extremum_sort_options(self.is_max) } /// Apply a batch of `(value, ordering)` columns, keeping the value paired with the @@ -226,10 +248,166 @@ impl Accumulator for MaxMinByAccumulator { } } +/// Vectorized grouped accumulator for `max_by` / `min_by`. 
+/// +/// Each group keeps the best `(value, ordering)` pair as Arrow row-format bytes. The ordering +/// rows are byte-comparable, so selecting the extremum for a batch is a single row conversion +/// plus per-row byte comparisons, avoiding the per-group `ScalarValue` work of the generic +/// `GroupsAccumulatorAdapter`. +struct MaxMinByGroupsAccumulator { + /// Converts and compares the ordering column. Its sort options encode the wanted extremum + /// as the largest row bytes (see `extremum_sort_options`). + ordering_converter: RowConverter, + /// Converts the value column to and from row bytes. Sort options are irrelevant here since + /// values are only stored, never compared. + value_converter: RowConverter, + /// Row bytes for a null value, used for groups that have not been updated (or whose winning + /// value is null). + null_value_row: OwnedRow, + /// Row bytes for a null ordering, used for groups that have seen no non-null ordering. + null_ordering_row: OwnedRow, + /// Per-group winning value (row bytes). + best_value: Vec, + /// Per-group winning ordering (row bytes). + best_ordering: Vec, + /// Per-group flag: has a non-null ordering been seen yet? + has_ordering: Vec, +} + +impl MaxMinByGroupsAccumulator { + fn try_new(value_type: DataType, ordering_type: DataType, is_max: bool) -> Result { + let ordering_converter = RowConverter::new(vec![SortField::new_with_options( + ordering_type.clone(), + extremum_sort_options(is_max), + )])?; + let value_converter = RowConverter::new(vec![SortField::new(value_type.clone())])?; + let null_ordering_row = ordering_converter + .convert_columns(&[new_null_array(&ordering_type, 1)])? + .row(0) + .owned(); + let null_value_row = value_converter + .convert_columns(&[new_null_array(&value_type, 1)])? 
+ .row(0) + .owned(); + Ok(Self { + ordering_converter, + value_converter, + null_value_row, + null_ordering_row, + best_value: Vec::new(), + best_ordering: Vec::new(), + has_ordering: Vec::new(), + }) + } + + fn resize(&mut self, total_num_groups: usize) { + self.best_value + .resize(total_num_groups, self.null_value_row.clone()); + self.best_ordering + .resize(total_num_groups, self.null_ordering_row.clone()); + self.has_ordering.resize(total_num_groups, false); + } + + /// Shared update/merge logic: `values[0]` is the value column, `values[1]` the ordering + /// column. Rows with a null ordering are ignored; on a tie the later row wins. + fn update_groups( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.resize(total_num_groups); + let value_rows = self + .value_converter + .convert_columns(&[Arc::clone(&values[0])])?; + let ordering_arr = &values[1]; + let ordering_rows = self + .ordering_converter + .convert_columns(&[Arc::clone(ordering_arr)])?; + + for (idx, &group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.is_valid(idx) || !filter.value(idx) { + continue; + } + } + if ordering_arr.is_null(idx) { + continue; + } + let candidate = ordering_rows.row(idx); + let take = !self.has_ordering[group_index] + || candidate >= self.best_ordering[group_index].row(); + if take { + self.best_ordering[group_index] = candidate.owned(); + self.best_value[group_index] = value_rows.row(idx).owned(); + self.has_ordering[group_index] = true; + } + } + Ok(()) + } +} + +impl GroupsAccumulator for MaxMinByGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.update_groups(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + 
group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // State columns mirror the input columns: [value, ordering]. + self.update_groups(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let value_rows = emit_to.take_needed(&mut self.best_value); + let _ = emit_to.take_needed(&mut self.best_ordering); + let _ = emit_to.take_needed(&mut self.has_ordering); + let arrays = self + .value_converter + .convert_rows(value_rows.iter().map(|r| r.row()))?; + Ok(Arc::clone(&arrays[0])) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let value_rows = emit_to.take_needed(&mut self.best_value); + let ordering_rows = emit_to.take_needed(&mut self.best_ordering); + let _ = emit_to.take_needed(&mut self.has_ordering); + let value_arrays = self + .value_converter + .convert_rows(value_rows.iter().map(|r| r.row()))?; + let ordering_arrays = self + .ordering_converter + .convert_rows(ordering_rows.iter().map(|r| r.row()))?; + Ok(vec![ + Arc::clone(&value_arrays[0]), + Arc::clone(&ordering_arrays[0]), + ]) + } + + fn size(&self) -> usize { + size_of_val(self) + + (self.best_value.capacity() + self.best_ordering.capacity()) + * std::mem::size_of::() + + self.has_ordering.capacity() + } +} + #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, Int32Array, StringArray}; + use arrow::array::{AsArray, Float64Array, Int32Array, StringArray}; fn max_by_acc(value_type: DataType, ordering_type: DataType) -> MaxMinByAccumulator { MaxMinByAccumulator::try_new(value_type, ordering_type, true).unwrap() @@ -335,4 +513,120 @@ mod tests { } assert_eq!(merged.evaluate().unwrap(), single); } + + // ---- GroupsAccumulator tests ---- + + fn max_by_groups(value_type: DataType, ordering_type: DataType) -> MaxMinByGroupsAccumulator { + MaxMinByGroupsAccumulator::try_new(value_type, ordering_type, true).unwrap() + } + + fn min_by_groups(value_type: 
DataType, ordering_type: DataType) -> MaxMinByGroupsAccumulator { + MaxMinByGroupsAccumulator::try_new(value_type, ordering_type, false).unwrap() + } + + fn eval_int(acc: &mut MaxMinByGroupsAccumulator) -> Vec> { + acc.evaluate(EmitTo::All) + .unwrap() + .as_primitive::() + .iter() + .collect() + } + + #[test] + fn groups_max_by_multi_group() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + // group 0: values 10,30 orderings 1,2 -> 30 ; group 1: values 20,40 orderings 5,4 -> 20 + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 5, 2, 4])); + acc.update_batch(&[values, ordering], &[0, 1, 0, 1], None, 2) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30), Some(20)]); + } + + #[test] + fn groups_min_by_multi_group() { + let mut acc = min_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 5, 2, 4])); + acc.update_batch(&[values, ordering], &[0, 1, 0, 1], None, 2) + .unwrap(); + // group 0: min ordering 1 -> 10 ; group 1: min ordering 4 -> 40 + assert_eq!(eval_int(&mut acc), vec![Some(10), Some(40)]); + } + + #[test] + fn groups_null_ordering_and_empty_group() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + // group 0: (10,null),(30,5) -> 30 ; group 1: no rows -> null ; group 2: (40,null) -> null + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), Some(30), Some(40)])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(5), None])); + acc.update_batch(&[values, ordering], &[0, 0, 2], None, 3) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30), None, None]); + } + + #[test] + fn groups_null_value_at_extremum() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None])); + let ordering: 
ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(9)])); + acc.update_batch(&[values, ordering], &[0, 0], None, 1) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![None]); + } + + #[test] + fn groups_merge_matches_single_shot() { + let single = { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 6, 3, 2, 5, 4])); + acc.update_batch(&[values, ordering], &[0, 0, 0, 0, 0, 0], None, 1) + .unwrap(); + eval_int(&mut acc) + }; + + let mut left = max_by_groups(DataType::Int32, DataType::Int32); + left.update_batch( + &[ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + Arc::new(Int32Array::from(vec![1, 6, 3])) as ArrayRef, + ], + &[0, 0, 0], + None, + 1, + ) + .unwrap(); + let mut right = max_by_groups(DataType::Int32, DataType::Int32); + right + .update_batch( + &[ + Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef, + Arc::new(Int32Array::from(vec![2, 5, 4])) as ArrayRef, + ], + &[0, 0, 0], + None, + 1, + ) + .unwrap(); + + let mut merged = max_by_groups(DataType::Int32, DataType::Int32); + for acc in [&mut left, &mut right] { + let state = acc.state(EmitTo::All).unwrap(); + merged.merge_batch(&state, &[0], None, 1).unwrap(); + } + assert_eq!(eval_int(&mut merged), single); + } + + #[test] + fn groups_filter_is_respected() { + let mut acc = max_by_groups(DataType::Int32, DataType::Int32); + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + let ordering: ArrayRef = Arc::new(Int32Array::from(vec![1, 9, 2])); + // Filter out the row with ordering 9, so the max becomes ordering 2 -> value 30. 
+ let filter = BooleanArray::from(vec![true, false, true]); + acc.update_batch(&[values, ordering], &[0, 0, 0], Some(&filter), 1) + .unwrap(); + assert_eq!(eval_int(&mut acc), vec![Some(30)]); + } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1785fa384e..753417fe08 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -401,6 +401,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Max] -> CometMax, classOf[MaxBy] -> CometMaxBy, classOf[Min] -> CometMin, + classOf[MinBy] -> CometMinBy, classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 39878c9c4c..4f02329b5f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, MaxBy, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, MaxBy, MaxMinBy, Min, MinBy, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} @@ -105,10 +105,20 @@ object CometMax extends CometAggregateExpressionSerde[Max] { } } -object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] { +/** + * Shared serde for `max_by` and `min_by`. Both are 2-argument `MaxMinBy` `DeclarativeAggregate`s + * differing only in the comparison direction, so the value/ordering handling is identical. + */ +abstract class CometMaxMinBy[T <: MaxMinBy] extends CometAggregateExpressionSerde[T] { + + /** `true` for `max_by`, `false` for `min_by`. */ + protected def isMax: Boolean + + /** `"maximum"` or `"minimum"`, used in the non-determinism note. */ + private def extremum: String = if (isMax) "maximum" else "minimum" override def getCompatibleNotes(): Seq[String] = Seq( - "This function is non-deterministic when multiple rows share the maximum ordering value." + + s"This function is non-deterministic when multiple rows share the $extremum ordering value." + " Results may differ from Spark in that case.") override def getUnsupportedReasons(): Seq[String] = Seq( @@ -117,7 +127,7 @@ object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] { " struct forces Spark's `SortAggregate`, which Comet does not accelerate, so the aggregate" + " falls back to Spark.") - override def getSupportLevel(expr: MaxBy): SupportLevel = { + override def getSupportLevel(expr: T): SupportLevel = { // Both the value and ordering must be fixed-length types. 
Spark only uses HashAggregate // (the aggregate operator Comet accelerates) when the aggregation buffer is mutable; the // buffer holds both the running value and the running ordering, so a variable-length type @@ -135,7 +145,7 @@ object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] { override def convert( aggExpr: AggregateExpression, - expr: MaxBy, + expr: T, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { @@ -143,15 +153,19 @@ object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] { val orderingExpr = exprToProto(expr.orderingExpr, inputs, binding) if (valueExpr.isDefined && orderingExpr.isDefined) { - val builder = ExprOuterClass.MaxBy.newBuilder() - builder.setValue(valueExpr.get) - builder.setOrdering(orderingExpr.get) - - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setMaxBy(builder) - .build()) + val aggBuilder = ExprOuterClass.AggExpr.newBuilder() + if (isMax) { + val builder = ExprOuterClass.MaxBy.newBuilder() + builder.setValue(valueExpr.get) + builder.setOrdering(orderingExpr.get) + aggBuilder.setMaxBy(builder) + } else { + val builder = ExprOuterClass.MinBy.newBuilder() + builder.setValue(valueExpr.get) + builder.setOrdering(orderingExpr.get) + aggBuilder.setMinBy(builder) + } + Some(aggBuilder.build()) } else { withFallbackReason(aggExpr, expr.valueExpr, expr.orderingExpr) None @@ -159,6 +173,14 @@ object CometMaxBy extends CometAggregateExpressionSerde[MaxBy] { } } +object CometMaxBy extends CometMaxMinBy[MaxBy] { + override protected def isMax: Boolean = true +} + +object CometMinBy extends CometMaxMinBy[MinBy] { + override protected def isMax: Boolean = false +} + object CometCount extends CometAggregateExpressionSerde[Count] { override def convert( aggExpr: AggregateExpression, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql new file mode 100644 index 0000000000..0f0101e1e7 --- 
/dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/min_by.sql @@ -0,0 +1,216 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- min_by(x, y) returns the value of x associated with the minimum value of y. +-- +-- Both the value (x) and the ordering (y) must be fixed-length types: Spark only uses +-- HashAggregate (the aggregate operator Comet accelerates) when the aggregation buffer is +-- mutable, so variable-length types such as string force SortAggregate and fall back to Spark. +-- +-- Ordering values are kept unique within each group so results are deterministic (min_by is +-- non-deterministic when several rows tie on the minimum ordering).
+ +-- ============================================================ +-- Setup: tables +-- ============================================================ + +statement +CREATE TABLE nb_src(v int, ord int, grp string) USING parquet + +statement +INSERT INTO nb_src VALUES + (10, 10, 'g1'), (20, 50, 'g1'), (30, 20, 'g1'), + (40, 40, 'g2'), (50, 5, 'g2'), (60, 30, 'g2'), + (70, 99, 'g3') + +-- ordering NULLs are ignored; a group of all-NULL orderings yields NULL +statement +CREATE TABLE nb_nulls(v int, ord int, grp string) USING parquet + +statement +INSERT INTO nb_nulls VALUES + (1, 10, 'g1'), (2, NULL, 'g1'), (3, 5, 'g1'), + (4, NULL, 'g2'), (5, NULL, 'g2'), + (6, 7, 'g3'), (NULL, 1, 'g3') + +statement +CREATE TABLE nb_empty(v int, ord int) USING parquet + +-- ============================================================ +-- Global aggregate (no GROUP BY) +-- ============================================================ + +query +SELECT min_by(v, ord) FROM nb_src + +-- ============================================================ +-- GROUP BY +-- ============================================================ + +query +SELECT grp, min_by(v, ord) FROM nb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- NULL handling: NULL orderings ignored; the value paired with the +-- minimum ordering may itself be NULL; all-NULL orderings yield NULL +-- ============================================================ + +query +SELECT grp, min_by(v, ord) FROM nb_nulls GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Empty table yields NULL +-- ============================================================ + +query +SELECT min_by(v, ord) FROM nb_empty + +-- ============================================================ +-- Literal arguments (evaluated natively; constant folding is disabled) +-- ============================================================ + +query +SELECT min_by(5, 10), min_by(CAST(NULL 
AS INT), 20) + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, min_by(v, ord), count(*), min(ord) +FROM nb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- BigInt value +-- ============================================================ + +statement +CREATE TABLE nb_long(val bigint, ord int, grp string) USING parquet + +statement +INSERT INTO nb_long VALUES + (1000000000000, 1, 'a'), (2000000000000, 3, 'a'), (3000000000000, 2, 'a'), + (4000000000000, 5, 'b'), (5000000000000, 4, 'b') + +query +SELECT grp, min_by(val, ord) FROM nb_long GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Double value and double ordering (NaN is the maximum in Spark, so it is +-- never selected by min_by unless it is the only non-null ordering) +-- ============================================================ + +statement +CREATE TABLE nb_dbl(v double, ord double, grp string) USING parquet + +statement +INSERT INTO nb_dbl VALUES + (1.1, 1.5, 'g1'), (2.2, 2.5, 'g1'), (3.3, 0.5, 'g1'), + (4.4, 1.0, 'g2'), (5.5, CAST('NaN' AS DOUBLE), 'g2'), (6.6, 100.0, 'g2'), + (7.7, CAST('-Infinity' AS DOUBLE), 'g3'), (8.8, 3.0, 'g3') + +query +SELECT grp, min_by(v, ord) FROM nb_dbl GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal value and decimal ordering +-- ============================================================ + +statement +CREATE TABLE nb_dec(v decimal(10,2), ord decimal(10,2), grp string) USING parquet + +statement +INSERT INTO nb_dec VALUES + (10.01, 1.50, 'g1'), (20.02, 9.99, 'g1'), (30.03, 5.00, 'g1'), + (40.04, 2.00, 'g2'), (50.05, 8.25, 'g2') + +query +SELECT grp, min_by(v, ord) FROM nb_dec GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / timestamp 
value and ordering +-- ============================================================ + +statement +CREATE TABLE nb_dt(d date, ts timestamp, ord int, grp string) USING parquet + +statement +INSERT INTO nb_dt VALUES + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 1, 'g1'), + (DATE '2024-06-15', TIMESTAMP '2024-06-15 12:30:00', 3, 'g1'), + (DATE '2023-12-31', TIMESTAMP '2023-12-31 23:59:59', 2, 'g1'), + (DATE '2024-03-01', TIMESTAMP '2024-03-01 08:00:00', 1, 'g2') + +query +SELECT grp, min_by(d, ord) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ts, ord) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, d) FROM nb_dt GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, ts) FROM nb_dt GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Negative and boundary ordering values +-- ============================================================ + +statement +CREATE TABLE nb_bound(v int, ord bigint, grp string) USING parquet + +statement +INSERT INTO nb_bound VALUES + (1, -100, 'g1'), (2, -5, 'g1'), (3, -50, 'g1'), + (4, -9223372036854775808, 'g2'), (5, 9223372036854775807, 'g2'), (6, 0, 'g2'), + (7, -2147483648, 'g3'), (8, 2147483647, 'g3') + +query +SELECT grp, min_by(v, ord) FROM nb_bound GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Multiple min_by and value = ordering column +-- ============================================================ + +query +SELECT grp, min_by(v, ord), min_by(ord, v) FROM nb_src GROUP BY grp ORDER BY grp + +query +SELECT grp, min_by(ord, ord) FROM nb_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Boolean value +-- ============================================================ + +statement +CREATE TABLE nb_bool(v boolean, ord int, grp string) USING parquet + +statement +INSERT INTO nb_bool VALUES + (true, 1, 'a'), (false, 3, 'a'), (true, 2, 'a'), + 
(false, 5, 'b'), (true, 4, 'b') + +query +SELECT grp, min_by(v, ord) FROM nb_bool GROUP BY grp ORDER BY grp + +-- ============================================================ +-- max_by and min_by in the same query +-- ============================================================ + +query +SELECT grp, max_by(v, ord), min_by(v, ord) FROM nb_src GROUP BY grp ORDER BY grp diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index fca15fc65d..07d6446857 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -125,9 +125,9 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "percentile_double_high_card", "SELECT percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) - // max_by runs natively only when both the value and ordering are fixed-length types (a + // max_by / min_by run natively only when both the value and ordering are fixed-length types (a // variable-length value or ordering forces Spark's SortAggregate, which Comet does not run). 
- private val maxByAggregates = List( + private val maxMinByAggregates = List( AggExprConfig("max_by_int", "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY grp"), AggExprConfig( "max_by_double", @@ -138,7 +138,14 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { AggExprConfig( "max_by_high_card", "SELECT max_by(c_int, c_long) FROM parquetV1Table GROUP BY high_card_grp"), - AggExprConfig("max_by_global", "SELECT max_by(c_int, c_long) FROM parquetV1Table")) + AggExprConfig("max_by_global", "SELECT max_by(c_int, c_long) FROM parquetV1Table"), + AggExprConfig("min_by_int", "SELECT min_by(c_int, c_long) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "min_by_double", + "SELECT min_by(c_double, c_int) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "min_by_high_card", + "SELECT min_by(c_int, c_long) FROM parquetV1Table GROUP BY high_card_grp")) override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024 @@ -163,7 +170,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { val allAggregates = basicAggregates ++ statisticalAggregates ++ bitwiseAggregates ++ multiKeyAggregates ++ multiAggregates ++ decimalAggregates ++ - highCardinalityAggregates ++ percentileAggregates ++ maxByAggregates + highCardinalityAggregates ++ percentileAggregates ++ maxMinByAggregates allAggregates.foreach { config => runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs)