diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index ae9530d8ec..46ea5314fb 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -100,7 +100,7 @@ The tables below list every Spark built-in expression with its current status. | `median` | ✅ | Rewrites to `percentile(col, 0.5)`; falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `min` | ✅ | | | `min_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | -| `mode` | 🔜 | [#3970](https://github.com/apache/datafusion-comet/issues/3970) | +| `mode` | ✅ | `mode(col)` only; Spark breaks ties non-deterministically, so Comet returns the smallest tied value and falls back by default, opt-in via allowIncompatible ([#3970](https://github.com/apache/datafusion-comet/issues/3970)) | | `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_disc` | 🔜 | Percentile aggregate | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 049c8c13e9..3c56088877 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, IfExpr, ListExtract, Mode, NormalizeNaNAndZero, SparkCastOptions, Stddev, + SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2643,6 +2643,12 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::Mode(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + let func = AggregateUDF::new_from_impl(Mode::new(datatype)); + Self::create_aggr_func_expr("mode", schema, vec![child], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 103af9bf10..532a58e7d6 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -144,6 +144,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + Mode mode = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -275,6 +276,11 @@ message CollectSet { DataType datatype = 2; } +message Mode { + Expr child = 1; + DataType datatype = 2; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..aaa80b2287 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod mode; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use mode::Mode; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/mode.rs b/native/spark-expr/src/agg_funcs/mode.rs new file mode 100644 index 0000000000..4576fc3203 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/mode.rs @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, StructArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, Int64Type}; +use datafusion::common::{internal_datafusion_err, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion::physical_expr::expressions::format_state_name; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::mem::size_of; +use std::sync::Arc; + +/// Spark's `mode` aggregate: returns the most frequent value within a group, ignoring NULLs. +/// +/// Spark breaks ties on the default `mode(col)` form non-deterministically (the value is chosen +/// by JVM `OpenHashMap` iteration order), which a native hash map cannot reproduce bit-for-bit. +/// Comet resolves ties deterministically by returning the smallest value, so this function is +/// registered as `Incompatible` on the Scala side and is opt-in via `allowIncompatible`. +/// +/// Float keys are normalized before counting (`-0.0` becomes `0.0` and every `NaN` becomes a +/// canonical `NaN`) to match Spark's `NormalizeFloatingNumbers` behaviour so that counts agree. +/// +/// Spark's `Mode` is a `TypedImperativeAggregate` with a single aggregation-buffer attribute, so +/// the intermediate state is a single struct field `{ values: list, counts: list }` (a +/// parallel-array encoding of the frequency map) to keep the partial/final buffer schemas aligned +/// with Spark. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Mode { + name: String, + signature: Signature, + data_type: DataType, +} + +impl Mode { + pub fn new(data_type: DataType) -> Self { + Self { + name: "mode".to_string(), + signature: Signature::any(1, Volatility::Immutable), + data_type, + } + } +} + +/// Fields of the single struct state column `{values: list, counts: list}`. +fn state_struct_fields(data_type: &DataType) -> Fields { + let values_list = DataType::List(Arc::new(Field::new_list_field(data_type.clone(), true))); + let counts_list = DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))); + Fields::from(vec![ + Field::new("values", values_list, false), + Field::new("counts", counts_list, false), + ]) +} + +/// Build the single-column struct state array holding one `{values, counts}` row per map. +fn build_state(data_type: &DataType, maps: &[&HashMap]) -> Result { + let mut value_lists = Vec::with_capacity(maps.len()); + let mut count_lists = Vec::with_capacity(maps.len()); + for map in maps { + let mut values = Vec::with_capacity(map.len()); + let mut counts = Vec::with_capacity(map.len()); + for (value, &count) in map.iter() { + values.push(value.clone()); + counts.push(ScalarValue::Int64(Some(count))); + } + value_lists.push(ScalarValue::List(ScalarValue::new_list( + &values, data_type, true, + ))); + count_lists.push(ScalarValue::List(ScalarValue::new_list( + &counts, + &DataType::Int64, + true, + ))); + } + let values = ScalarValue::iter_to_array(value_lists)?; + let counts = ScalarValue::iter_to_array(count_lists)?; + Ok(StructArray::new( + state_struct_fields(data_type), + vec![values, counts], + None, + )) +} + +impl AggregateUDFImpl for Mode { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.data_type.clone()) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + ScalarValue::try_from(&self.data_type) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ModeAccumulator::new(self.data_type.clone()))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new( + format_state_name(&self.name, "freq"), + DataType::Struct(state_struct_fields(&self.data_type)), + false, + ))]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(ModeGroupsAccumulator::new(self.data_type.clone()))) + } +} + +/// Normalize a scalar key so that Spark's floating-point normalization is honoured: `-0.0` and +/// `0.0` collapse to the same key and all `NaN` bit patterns collapse to a canonical `NaN`. +fn normalize_key(value: ScalarValue) -> ScalarValue { + /// Collapse `-0.0`/`0.0` and every `NaN` to a canonical form for one float variant. + macro_rules! normalize_float { + ($variant:path, $f:expr, $nan:expr) => { + if $f == 0.0 { + $variant(Some(0.0)) + } else if $f.is_nan() { + $variant(Some($nan)) + } else { + $variant(Some($f)) + } + }; + } + match value { + ScalarValue::Float32(Some(f)) => normalize_float!(ScalarValue::Float32, f, f32::NAN), + ScalarValue::Float64(Some(f)) => normalize_float!(ScalarValue::Float64, f, f64::NAN), + other => other, + } +} + +/// Add each non-null value in `array` to `map`, normalizing float keys. +fn count_values(map: &mut HashMap, array: &ArrayRef, idx: usize) -> Result<()> { + if array.is_null(idx) { + return Ok(()); + } + let key = normalize_key(ScalarValue::try_from_array(array, idx)?); + *map.entry(key).or_insert(0) += 1; + Ok(()) +} + +/// Fold row `row` of the struct-state columns (`{values, counts}`) into `map`. +fn merge_state_row( + map: &mut HashMap, + values_list: &arrow::array::ListArray, + counts_list: &arrow::array::ListArray, + row: usize, +) -> Result<()> { + if values_list.is_null(row) { + return Ok(()); + } + let values = values_list.value(row); + let counts = counts_list.value(row); + let counts = counts + .as_primitive_opt::() + .ok_or_else(|| internal_datafusion_err!("mode state counts must be Int64"))?; + for i in 0..values.len() { + if values.is_null(i) { + continue; + } + let key = normalize_key(ScalarValue::try_from_array(&values, i)?); + *map.entry(key).or_insert(0) += counts.value(i); + } + Ok(()) +} + +/// Pick the mode from a frequency map: the value with the highest count, breaking ties by the +/// smallest value. Returns a null scalar of `data_type` when the map is empty. +fn eval_mode(counts: &HashMap, data_type: &DataType) -> Result { + let mut best: Option<(&ScalarValue, i64)> = None; + for (value, &count) in counts.iter() { + let wins = match best { + None => true, + Some((best_value, best_count)) => { + count > best_count + || (count == best_count + && value.partial_cmp(best_value) == Some(Ordering::Less)) + } + }; + if wins { + best = Some((value, count)); + } + } + match best { + Some((value, _)) => Ok(value.clone()), + None => ScalarValue::try_from(data_type), + } +} + +/// Non-grouped accumulator backing global `mode` aggregation. +#[derive(Debug)] +pub struct ModeAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl ModeAccumulator { + fn new(data_type: DataType) -> Self { + Self { + counts: HashMap::new(), + data_type, + } + } +} + +impl Accumulator for ModeAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + for i in 0..array.len() { + count_values(&mut self.counts, array, i)?; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let structs = states[0].as_struct(); + let values_list = structs.column(0).as_list::(); + let counts_list = structs.column(1).as_list::(); + for row in 0..structs.len() { + merge_state_row(&mut self.counts, values_list, counts_list, row)?; + } + Ok(()) + } + + fn state(&mut self) -> Result> { + let array = build_state(&self.data_type, &[&self.counts])?; + Ok(vec![ScalarValue::Struct(Arc::new(array))]) + } + + fn evaluate(&mut self) -> Result { + eval_mode(&self.counts, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.counts.capacity() * size_of::<(ScalarValue, i64)>() + } +} + +/// Vectorized grouped accumulator: one frequency map per group. +#[derive(Debug)] +pub struct ModeGroupsAccumulator { + groups: Vec>, + data_type: DataType, +} + +impl ModeGroupsAccumulator { + fn new(data_type: DataType) -> Self { + Self { + groups: Vec::new(), + data_type, + } + } + + fn resize(&mut self, total_num_groups: usize) { + if self.groups.len() < total_num_groups { + self.groups.resize_with(total_num_groups, HashMap::new); + } + } +} + +impl GroupsAccumulator for ModeGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.resize(total_num_groups); + let array = &values[0]; + for (idx, &group_index) in group_indices.iter().enumerate() { + if let Some(f) = opt_filter { + if !f.is_valid(idx) || !f.value(idx) { + continue; + } + } + count_values(&mut self.groups[group_index], array, idx)?; + } + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.resize(total_num_groups); + let structs = values[0].as_struct(); + let values_list = structs.column(0).as_list::(); + let counts_list = structs.column(1).as_list::(); + for (row, &group_index) in group_indices.iter().enumerate() { + merge_state_row(&mut self.groups[group_index], values_list, counts_list, row)?; + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let emitted = emit_to.take_needed(&mut self.groups); + let mut results = Vec::with_capacity(emitted.len()); + for map in &emitted { + results.push(eval_mode(map, &self.data_type)?); + } + ScalarValue::iter_to_array(results) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let emitted = emit_to.take_needed(&mut self.groups); + let refs: Vec<&HashMap> = emitted.iter().collect(); + Ok(vec![Arc::new(build_state(&self.data_type, &refs)?)]) + } + + fn size(&self) -> usize { + size_of_val(self) + + self + .groups + .iter() + .map(|m| m.capacity() * size_of::<(ScalarValue, i64)>()) + .sum::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, Int32Array}; + use arrow::datatypes::Int32Type; + + fn i32_array(values: Vec>) -> ArrayRef { + Arc::new(Int32Array::from(values)) + } + + fn eval_acc(acc: &mut ModeAccumulator) -> ScalarValue { + acc.evaluate().unwrap() + } + + #[test] + fn most_frequent_value() { + let mut acc = ModeAccumulator::new(DataType::Int32); + acc.update_batch(&[i32_array(vec![Some(0), Some(10), Some(10)])]) + .unwrap(); + assert_eq!(eval_acc(&mut acc), ScalarValue::Int32(Some(10))); + } + + #[test] + fn nulls_are_ignored() { + let mut acc = ModeAccumulator::new(DataType::Int32); + acc.update_batch(&[i32_array(vec![ + Some(10), + None, + None, + None, + Some(10), + Some(7), + ])]) + .unwrap(); + assert_eq!(eval_acc(&mut acc), ScalarValue::Int32(Some(10))); + } + + #[test] + fn empty_input_is_null() { + let mut acc = ModeAccumulator::new(DataType::Int32); + acc.update_batch(&[i32_array(vec![None, None])]).unwrap(); + assert_eq!(eval_acc(&mut acc), ScalarValue::Int32(None)); + } + + #[test] + fn ties_break_to_smallest() { + let mut acc = ModeAccumulator::new(DataType::Int32); + // 10 and 20 each appear twice; Comet returns the smallest tied value. + acc.update_batch(&[i32_array(vec![Some(20), Some(10), Some(10), Some(20)])]) + .unwrap(); + assert_eq!(eval_acc(&mut acc), ScalarValue::Int32(Some(10))); + } + + /// Turn an accumulator's `Vec` state into the state arrays `merge_batch` consumes. + fn state_arrays(acc: &mut ModeAccumulator) -> Vec { + acc.state() + .unwrap() + .into_iter() + .map(|s| ScalarValue::iter_to_array(vec![s]).unwrap()) + .collect() + } + + #[test] + fn merge_matches_single_shot() { + let single = { + let mut a = ModeAccumulator::new(DataType::Int32); + a.update_batch(&[i32_array(vec![ + Some(1), + Some(1), + Some(2), + Some(3), + Some(3), + Some(3), + ])]) + .unwrap(); + eval_acc(&mut a) + }; + + let mut left = ModeAccumulator::new(DataType::Int32); + left.update_batch(&[i32_array(vec![Some(1), Some(1), Some(3)])]) + .unwrap(); + let lstate = state_arrays(&mut left); + + let mut right = ModeAccumulator::new(DataType::Int32); + right + .update_batch(&[i32_array(vec![Some(2), Some(3), Some(3)])]) + .unwrap(); + let rstate = state_arrays(&mut right); + + let mut merged = ModeAccumulator::new(DataType::Int32); + merged.merge_batch(&lstate).unwrap(); + merged.merge_batch(&rstate).unwrap(); + assert_eq!(eval_acc(&mut merged), single); + } + + #[test] + fn float_zero_and_nan_normalized() { + let mut acc = ModeAccumulator::new(DataType::Float64); + // -0.0 and 0.0 must count as one key. + let arr: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(-0.0), + Some(0.0), + Some(0.0), + Some(1.5), + ])); + acc.update_batch(&[arr]).unwrap(); + assert_eq!(eval_acc(&mut acc), ScalarValue::Float64(Some(0.0))); + } + + #[test] + fn groups_accumulator_per_group_mode() { + let mut acc = ModeGroupsAccumulator::new(DataType::Int32); + let values = i32_array(vec![Some(5), Some(5), Some(9), Some(9), Some(9)]); + acc.update_batch(&[values], &[0, 0, 1, 1, 1], None, 2) + .unwrap(); + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 5); + assert_eq!(result.value(1), 9); + } + + #[test] + fn groups_accumulator_merge_roundtrip() { + // Partial over two groups, then merge its state into a fresh accumulator. + let mut partial = ModeGroupsAccumulator::new(DataType::Int32); + let values = i32_array(vec![Some(5), Some(5), Some(7), Some(9), Some(9), Some(9)]); + partial + .update_batch(&[values], &[0, 0, 0, 1, 1, 1], None, 2) + .unwrap(); + let state = partial.state(EmitTo::All).unwrap(); + + let mut final_acc = ModeGroupsAccumulator::new(DataType::Int32); + final_acc.merge_batch(&state, &[0, 1], None, 2).unwrap(); + let result = final_acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 5); + assert_eq!(result.value(1), 9); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 44142e75ed..06ed92eb14 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -389,6 +389,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Last] -> CometLast, classOf[Max] -> CometMax, classOf[Min] -> CometMin, + classOf[Mode] -> CometMode, classOf[Percentile] -> CometPercentile, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..8d93d83287 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,14 +22,14 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Mode, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NumericType, ShortType, StringType, TimestampNTZType, TimestampType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} -import org.apache.comet.shims.CometEvalModeUtil +import org.apache.comet.shims.{CometEvalModeUtil, CometTypeShim} object CometMin extends CometAggregateExpressionSerde[Min] { @@ -823,6 +823,71 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { } } +object CometMode extends CometAggregateExpressionSerde[Mode] with CometTypeShim { + + private val tieBreakReason = + "mode breaks ties non-deterministically in Spark (the result depends on JVM hash-map" + + " iteration order); Comet returns the smallest of the tied values instead" + + " (https://github.com/apache/datafusion-comet/issues/3970)" + + override def getIncompatibleReasons(): Seq[String] = Seq(tieBreakReason) + + private def isSupportedType(dt: DataType): Boolean = dt match { + case BooleanType => true + case ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case DateType | TimestampType | TimestampNTZType => true + case StringType => true + case _ => false + } + + override def getSupportLevel(expr: Mode): SupportLevel = { + if (modeHasUnsupportedOrdering(expr)) { + // `mode(col, deterministic)` and `mode() WITHIN GROUP (ORDER BY col)` carry deterministic + // ordered tie-breaking that Comet does not implement yet (Spark 4.0+ only). + Unsupported( + Some("mode with a deterministic flag or WITHIN GROUP ordering is not supported")) + } else if (hasNonDefaultStringCollation(expr.child.dataType)) { + // Native counting is not collation-aware, so non-UTF8_BINARY collations would group keys + // differently from Spark. + Unsupported( + Some( + "mode does not support non-UTF8_BINARY collations " + + "(https://github.com/apache/datafusion-comet/issues/2190)")) + } else if (!isSupportedType(expr.child.dataType)) { + Unsupported(Some(s"mode does not support input type ${expr.child.dataType}")) + } else { + Incompatible(Some(tieBreakReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: Mode, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val child = expr.child + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(child.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.Mode.newBuilder() + builder.setChild(childExpr.get) + builder.setDatatype(dataType.get) + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setMode(builder) + .build()) + } else { + withFallbackReason(aggExpr, child) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..dcf30080c5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, First, Last, Partial, PartialMerge, Percentile} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, First, Last, Mode, Partial, PartialMerge, Percentile} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.io.ChunkedByteBuffer @@ -1925,6 +1925,19 @@ object CometObjectHashAggregateExec // Comet casts the child to double, so the native state is ArrayType(DoubleType). val nativeStateType = ArrayType(DoubleType, containsNull = true) output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) + case m: Mode => + // Comet's native mode accumulator keeps a frequency map encoded as parallel arrays + // (see ModeAccumulator in native/spark-expr): a struct of the distinct values and their + // counts. + val elementType = m.child.dataType + val nativeStateType = StructType( + Seq( + StructField( + "values", + ArrayType(elementType, containsNull = true), + nullable = false), + StructField("counts", ArrayType(LongType, containsNull = true), nullable = false))) + output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) case _ => } bufferIdx += bufferAttrs.length diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala index 97320be9e7..4d19bf35eb 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/CometTypeShim.scala @@ -21,12 +21,18 @@ package org.apache.comet.shims import scala.annotation.nowarn +import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.types.{DataType, StructType} trait CometTypeShim { @nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility. def isStringCollationType(dt: DataType): Boolean = false + // `mode() WITHIN GROUP (ORDER BY ...)` and the deterministic-flag form (which set `reverseOpt`) + // are Spark 4.0 features; Spark 3.x `Mode` is always the plain `mode(col)` form. + @nowarn + def modeHasUnsupportedOrdering(expr: Mode): Boolean = false + @nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility. def hasNonDefaultStringCollation(dt: DataType): Boolean = false diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/CometTypeShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/CometTypeShim.scala index 1d4a9f601e..535f4012af 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/CometTypeShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/CometTypeShim.scala @@ -19,10 +19,16 @@ package org.apache.comet.shims +import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.execution.datasources.VariantMetadata import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructType} trait CometTypeShim { + // `reverseOpt` is set for `mode() WITHIN GROUP (ORDER BY col [DESC])` and the + // `mode(col, deterministic)` form, both of which carry ordered tie-breaking that Comet does not + // implement yet. The plain `mode(col)` form leaves it as `None`. + def modeHasUnsupportedOrdering(expr: Mode): Boolean = expr.reverseOpt.isDefined + // A `StringType` carries collation metadata in Spark 4.0. Only non-default (non-UTF8_BINARY) // collations have semantics Comet's byte-level hashing/sorting/equality cannot honor. The // default `StringType` object is `StringType(UTF8_BINARY_COLLATION_ID)`, so comparing diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/mode.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/mode.sql new file mode 100644 index 0000000000..db7d1d2b90 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/mode.sql @@ -0,0 +1,186 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Comet's `mode` is opt-in via allowIncompatible because Spark breaks ties non-deterministically. +-- Every compared query below has a single value with the strictly-highest frequency per group so +-- that Comet's smallest-value tie-break agrees with Spark's arbitrary choice. +-- Config: spark.comet.expression.Mode.allowIncompatible=true + +-- ============================================================ +-- Setup: tables +-- ============================================================ + +statement +CREATE TABLE mode_int(v int, grp string) USING parquet + +statement +INSERT INTO mode_int VALUES + (10, 'a'), (10, 'a'), (7, 'a'), (NULL, 'a'), + (5, 'b'), (5, 'b'), (5, 'b'), (9, 'b'), (NULL, 'b'), + (NULL, 'c'), (NULL, 'c') + +statement +CREATE TABLE mode_all_null(v int) USING parquet + +statement +INSERT INTO mode_all_null VALUES (NULL), (NULL) + +-- ============================================================ +-- Global aggregate (no GROUP BY): unique mode +-- ============================================================ + +query +SELECT mode(v) FROM mode_int + +-- ============================================================ +-- GROUP BY: unique mode per group; NULLs ignored +-- ============================================================ + +query +SELECT grp, mode(v) FROM mode_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- All-NULL input returns NULL +-- ============================================================ + +query +SELECT mode(v) FROM mode_all_null + +-- ============================================================ +-- Mixed with other aggregates +-- ============================================================ + +query +SELECT grp, mode(v), count(*), sum(v) FROM mode_int GROUP BY grp ORDER BY grp + +-- ============================================================ +-- HAVING clause +-- ============================================================ + +query +SELECT grp, mode(v) FROM mode_int GROUP BY grp HAVING count(v) > 3 ORDER BY grp + +-- ============================================================ +-- Boolean +-- ============================================================ + +statement +CREATE TABLE mode_bool(v boolean, grp string) USING parquet + +statement +INSERT INTO mode_bool VALUES + (true, 'a'), (true, 'a'), (false, 'a'), (NULL, 'a'), + (false, 'b'), (false, 'b'), (true, 'b') + +query +SELECT grp, mode(v) FROM mode_bool GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Byte / Short / Long +-- ============================================================ + +statement +CREATE TABLE mode_nums(b tinyint, s smallint, l bigint, grp string) USING parquet + +statement +INSERT INTO mode_nums VALUES + (1, 100, 1000000000000, 'a'), (1, 100, 1000000000000, 'a'), (2, 200, 2000000000000, 'a'), + (3, 300, 3000000000000, 'b'), (3, 300, 3000000000000, 'b'), (4, 400, 4000000000000, 'b') + +query +SELECT grp, mode(b), mode(s), mode(l) FROM mode_nums GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Float / Double: -0.0 and 0.0 normalize to one key +-- ============================================================ + +statement +CREATE TABLE mode_double(v double, grp string) USING parquet + +statement +INSERT INTO mode_double VALUES + (1.5, 'a'), (1.5, 'a'), (2.5, 'a'), (NULL, 'a'), + (CAST(0.0 AS DOUBLE), 'b'), (CAST(-0.0 AS DOUBLE), 'b'), (CAST(-0.0 AS DOUBLE), 'b'), (7.0, 'b') + +query +SELECT grp, mode(v) FROM mode_double GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Decimal +-- ============================================================ + +statement +CREATE TABLE mode_decimal(v decimal(10,2), grp string) USING parquet + +statement +INSERT INTO mode_decimal VALUES + (1.50, 'a'), (1.50, 'a'), (2.50, 'a'), (NULL, 'a'), + (99999999.99, 'b'), (99999999.99, 'b'), (0.00, 'b') + +query +SELECT grp, mode(v) FROM mode_decimal GROUP BY grp ORDER BY grp + +-- ============================================================ +-- String +-- ============================================================ + +statement +CREATE TABLE mode_string(v string, grp string) USING parquet + +statement +INSERT INTO mode_string VALUES + ('hello', 'a'), ('hello', 'a'), ('world', 'a'), (NULL, 'a'), + ('', 'b'), ('', 'b'), ('x', 'b') + +query +SELECT grp, mode(v) FROM mode_string GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Date / Timestamp +-- ============================================================ + +statement +CREATE TABLE mode_temporal(d date, t timestamp, grp string) USING parquet + +statement +INSERT INTO mode_temporal VALUES + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 'a'), + (DATE '2024-01-01', TIMESTAMP '2024-01-01 00:00:00', 'a'), + (DATE '2024-06-15', TIMESTAMP '2024-06-15 12:30:00', 'a'), + (DATE '1970-01-01', TIMESTAMP '1970-01-01 00:00:00', 'b'), + (DATE '1970-01-01', TIMESTAMP '1970-01-01 00:00:00', 'b'), + (DATE '2000-12-31', TIMESTAMP '2000-12-31 23:59:59', 'b') + +query +SELECT grp, mode(d), mode(t) FROM mode_temporal GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Unsupported input type falls back to Spark +-- +-- A single row keeps the result deterministic: Spark's mode on BinaryType compares Array[Byte] +-- keys by reference, so any binary multiset with repeats is a full tie and returns an arbitrary +-- value. One row avoids that while still exercising the unsupported-type fallback. +-- ============================================================ + +statement +CREATE TABLE mode_binary(v binary) USING parquet + +statement +INSERT INTO mode_binary VALUES (X'CAFE') + +query expect_fallback(does not support input type) +SELECT mode(v) FROM mode_binary