diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..8888776d85 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -27,6 +27,11 @@ - Spark 3.5.8 (audited 2026-05-26): identical to 3.4.3. - Spark 4.0.1 (audited 2026-05-26): identical to 3.4.3. +## approx_percentile + +- Spark 3.4.3, 3.5.8, 4.0.1, 4.1.1 (audited 2026-07-02): `ApproximatePercentile(child, percentageExpression, accuracyExpression)` is a `TypedImperativeAggregate` backed by a Greenwald-Khanna `PercentileDigest` quantile summary with relative error `1.0 / accuracy`. `child` accepts `NumericType`, `DateType`, `TimestampType`, `TimestampNTZType`, and interval types (all cast to `double` internally); `percentage` is a single literal or literal array in `[0.0, 1.0]`; `accuracy` is a positive literal (default 10000). NULL inputs are skipped; an empty or all-null group returns NULL. `approx_percentile` is a SQL alias for the primary function name `percentile_approx`. +- `CometApproxPercentile` maps the byte, short, int, long, float, and double input forms to a native Greenwald-Khanna quantile summary port with the same insert/compress/merge/query algorithm and relative error, casting the result back to the input type. `percentage` and `accuracy` must be foldable literals, matching Spark. Date, timestamp, interval, and decimal inputs fall back to Spark. + ## avg - Spark 3.4.3 (2026-05-26) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..dd3b3e6cdc 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -60,7 +60,7 @@ expressions. The following function families are **not currently planned** for n The file-metadata functions `input_file_name`, `input_file_block_start`, and `input_file_block_length` depend on scan-internal per-row file information rather than the expression layer; their support status is covered in the [scan compatibility guide](compatibility/scans.md). -Note that `approx_count_distinct`, `median`, and `mode` are planned: they are mainstream (`median` and `mode` are exact aggregates). `approx_percentile` / `percentile_approx` are not currently planned because their approximate results cannot be made bit-identical to Spark. +Note that `approx_count_distinct`, `median`, and `mode` are planned: they are mainstream (`median` and `mode` are exact aggregates). The tables below list every Spark built-in expression with its current status. @@ -71,6 +71,7 @@ The tables below list every Spark built-in expression with its current status. | `any` | ✅ | | | `any_value` | ✅ | | | `approx_count_distinct` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) | +| `approx_percentile` | ✅ | Byte, short, int, long, float, and double input; other input types fall back to Spark | | `array_agg` | 🔜 | Array aggregate (related to `collect_list`, [#2524](https://github.com/apache/datafusion-comet/issues/2524)) | | `avg` | ✅ | Interval types fall back | | `bit_and` | ✅ | | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..fa2fed394b 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -128,8 +128,8 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, - Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, + jvm_udf::JvmScalarUdfExpr, ApproxPercentile, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, + Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; @@ -2627,6 +2627,17 @@ impl PhysicalPlanner { .build() .map_err(|e| e.into()) } + AggExprStruct::ApproxPercentile(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let input_type = to_arrow_datatype(expr.input_type.as_ref().unwrap()); + let func = AggregateUDF::new_from_impl(ApproxPercentile::new( + expr.percentiles.clone(), + expr.accuracy, + input_type, + expr.return_array, + )); + Self::create_aggr_func_expr("approx_percentile", schema, vec![child], func) + } AggExprStruct::BloomFilterAgg(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let num_items = diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..b617132e52 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + ApproxPercentile approxPercentile = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -253,6 +254,22 @@ message Percentile { DataType datatype = 3; } +message ApproxPercentile { + // Child value expression, already cast to Float64 by the serde. + Expr child = 1; + // The percentiles and accuracy are carried as resolved scalars rather than + // child Exprs (unlike Percentile/BloomFilterAgg) because they are needed at + // UDAF construction time to drive return_type and accumulator shape. + // One or more percentiles in [0.0, 1.0]. + repeated double percentiles = 2; + // Spark's accuracy argument; relative_error = 1.0 / accuracy. + int64 accuracy = 3; + // True when the percentile argument was an array (output is a list). + bool return_array = 4; + // Spark's input/output type, used to cast results back from Float64. + DataType input_type = 5; +} + message BloomFilterAgg { Expr child = 1; Expr numItems = 2; diff --git a/native/spark-expr/src/agg_funcs/approx_percentile.rs b/native/spark-expr/src/agg_funcs/approx_percentile.rs new file mode 100644 index 0000000000..4e3a4822a5 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/approx_percentile.rs @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::quantile_summaries::QuantileSummaries; +use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, ListArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::utils::SingleRowListArrayBuilder; +use datafusion::common::{downcast_value, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; +use std::sync::Arc; + +/// Native implementation of Spark's `approx_percentile` / `percentile_approx`, +/// backed by a bit-for-bit `QuantileSummaries` (Greenwald-Khanna) port. The +/// child value is cast to Float64 by the serde; the original `input_type` is +/// carried so results can be cast back to Spark's output type. +#[derive(Debug)] +pub struct ApproxPercentile { + name: String, + signature: Signature, + percentiles: Vec, + accuracy: i64, + input_type: DataType, + return_array: bool, +} + +impl PartialEq for ApproxPercentile { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.percentiles == other.percentiles + && self.accuracy == other.accuracy + && self.input_type == other.input_type + && self.return_array == other.return_array + } +} +impl Eq for ApproxPercentile {} + +impl std::hash::Hash for ApproxPercentile { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.percentiles + .iter() + .for_each(|p| p.to_bits().hash(state)); + self.accuracy.hash(state); + self.input_type.hash(state); + self.return_array.hash(state); + } +} + +impl ApproxPercentile { + pub fn new( + percentiles: Vec, + accuracy: i64, + input_type: DataType, + return_array: bool, + ) -> Self { + Self { + name: "approx_percentile".to_string(), + signature: Signature::numeric(1, Immutable), + percentiles, + accuracy, + input_type, + return_array, + } + } +} + +impl AggregateUDFImpl for ApproxPercentile { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + if self.return_array { + Ok(DataType::List(Arc::new(Field::new( + "item", + self.input_type.clone(), + false, + )))) + } else { + Ok(self.input_type.clone()) + } + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ApproxPercentileAccumulator::new( + self.percentiles.clone(), + self.accuracy, + self.input_type.clone(), + self.return_array, + ))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new( + format_state_name(&self.name, "digest"), + DataType::Binary, + true, + ))]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + false + } +} + +#[derive(Debug)] +struct ApproxPercentileAccumulator { + summary: QuantileSummaries, + percentiles: Vec, + input_type: DataType, + return_array: bool, +} + +impl ApproxPercentileAccumulator { + fn new(percentiles: Vec, accuracy: i64, input_type: DataType, return_array: bool) -> Self { + let relative_error = 1.0 / accuracy as f64; + Self { + summary: QuantileSummaries::new( + QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, + relative_error, + ), + percentiles, + input_type, + return_array, + } + } + + /// Cast a double quantile back to Spark's output type. GK always returns an + /// actual inserted value (never an interpolation), so for the supported + /// numeric types this round-trips exactly and is always in range. + fn cast_back(&self, d: f64) -> ScalarValue { + match &self.input_type { + DataType::Int8 => ScalarValue::Int8(Some(d as i8)), + DataType::Int16 => ScalarValue::Int16(Some(d as i16)), + DataType::Int32 => ScalarValue::Int32(Some(d as i32)), + DataType::Int64 => ScalarValue::Int64(Some(d as i64)), + DataType::Float32 => ScalarValue::Float32(Some(d as f32)), + _ => ScalarValue::Float64(Some(d)), + } + } +} + +impl Accumulator for ApproxPercentileAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(&values[0], Float64Array); + if arr.null_count() == 0 { + // Fast path: no validity checks needed, iterate the raw values. + for &v in arr.values() { + self.summary.insert(v); + } + } else { + for v in arr.iter().flatten() { + self.summary.insert(v); + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let digests = downcast_value!(&states[0], BinaryArray); + self.summary.compress(); + for i in 0..digests.len() { + if digests.is_null(i) { + continue; + } + let peer = QuantileSummaries::from_bytes( + QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, + digests.value(i), + ); + if self.summary.count() == 0 { + // Empty self: merge would just clone the peer, so move it in. + self.summary = peer; + } else { + self.summary = self.summary.merge(&peer); + } + } + Ok(()) + } + + fn state(&mut self) -> Result> { + self.summary.compress(); + Ok(vec![ScalarValue::Binary(Some(self.summary.to_bytes()))]) + } + + fn evaluate(&mut self) -> Result { + self.summary.compress(); + match self.summary.query(&self.percentiles) { + None => { + if self.return_array { + // Empty digest still yields null overall in Spark. + Ok(ScalarValue::List(Arc::new(ListArray::new_null( + Arc::new(Field::new("item", self.input_type.clone(), false)), + 1, + )))) + } else { + Ok(ScalarValue::try_from(&self.input_type)?) + } + } + Some(results) => { + let scalars: Vec = + results.into_iter().map(|d| self.cast_back(d)).collect(); + if self.return_array { + let values = ScalarValue::iter_to_array(scalars)?; + Ok(SingleRowListArrayBuilder::new(values) + .with_nullable(false) + .build_list_scalar()) + } else { + Ok(scalars.into_iter().next().unwrap()) + } + } + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.summary.heap_size() + + self.percentiles.capacity() * std::mem::size_of::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn f64_array(v: Vec) -> ArrayRef { + Arc::new(Float64Array::from(v)) + } + + #[test] + fn scalar_median_of_int_column() { + let mut acc = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Int32, false); + acc.update_batch(&[f64_array((1..=100).map(|i| i as f64).collect())]) + .unwrap(); + match acc.evaluate().unwrap() { + ScalarValue::Int32(Some(v)) => assert!((49..=51).contains(&v)), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn array_of_percentiles() { + let mut acc = + ApproxPercentileAccumulator::new(vec![0.25, 0.5, 0.75], 10000, DataType::Float64, true); + acc.update_batch(&[f64_array((1..=1000).map(|i| i as f64).collect())]) + .unwrap(); + match acc.evaluate().unwrap() { + ScalarValue::List(arr) => assert_eq!(arr.value_length(0), 3), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn empty_input_is_null() { + let mut acc = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Int64, false); + assert!(acc.evaluate().unwrap().is_null()); + } + + #[test] + fn state_then_merge_matches_single_shot() { + let mut single = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + single + .update_batch(&[f64_array((1..=1000).map(|i| i as f64).collect())]) + .unwrap(); + let single_val = single.evaluate().unwrap(); + + let mut left = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + left.update_batch(&[f64_array((1..=500).map(|i| i as f64).collect())]) + .unwrap(); + let left_state = left.state().unwrap(); + + let mut right = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + right + .update_batch(&[f64_array((501..=1000).map(|i| i as f64).collect())]) + .unwrap(); + let right_state = right.state().unwrap(); + + let mut merged = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + merged + .merge_batch(&[ScalarValue::iter_to_array(left_state).unwrap()]) + .unwrap(); + merged + .merge_batch(&[ScalarValue::iter_to_array(right_state).unwrap()]) + .unwrap(); + let merged_val = merged.evaluate().unwrap(); + + // Both within the same accuracy bound of the true median (~500). + for v in [single_val, merged_val] { + match v { + ScalarValue::Float64(Some(x)) => assert!((450.0..=550.0).contains(&x)), + other => panic!("unexpected {other:?}"), + } + } + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..18a821bf4a 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -15,16 +15,19 @@ // specific language governing permissions and limitations // under the License. +mod approx_percentile; mod avg; mod avg_decimal; mod correlation; mod covariance; +mod quantile_summaries; mod stddev; mod sum_decimal; mod sum_int; mod variance; mod welford; +pub use approx_percentile::ApproxPercentile; pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; diff --git a/native/spark-expr/src/agg_funcs/quantile_summaries.rs b/native/spark-expr/src/agg_funcs/quantile_summaries.rs new file mode 100644 index 0000000000..b43faebb03 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/quantile_summaries.rs @@ -0,0 +1,440 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A bit-for-bit port of Spark's `QuantileSummaries` (Greenwald-Khanna), the +//! sketch behind `approx_percentile` / `percentile_approx`. Kept free of any +//! DataFusion dependency so it can be unit-tested in isolation. +//! +//! Reference: `org.apache.spark.sql.catalyst.util.QuantileSummaries`. + +use std::collections::VecDeque; + +/// A single sampled statistic: the value, its minimum rank jump `g`, and the +/// maximum span of the rank `delta`. +#[derive(Debug, Clone, PartialEq)] +pub struct Stats { + pub value: f64, + pub g: i64, + pub delta: i64, +} + +#[derive(Debug, Clone)] +pub struct QuantileSummaries { + compress_threshold: usize, + relative_error: f64, + sampled: Vec, + count: i64, + compressed: bool, + head_sampled: Vec, +} + +impl QuantileSummaries { + pub const DEFAULT_COMPRESS_THRESHOLD: usize = 10000; + pub const DEFAULT_HEAD_SIZE: usize = 50000; + + /// Mirrors Spark's `PercentileDigest` ctor which builds a summary with + /// `compressed = true`. + pub fn new(compress_threshold: usize, relative_error: f64) -> Self { + Self { + compress_threshold, + relative_error, + sampled: Vec::new(), + count: 0, + compressed: true, + head_sampled: Vec::new(), + } + } + + pub fn count(&self) -> i64 { + self.count + } + + /// Heap bytes held by this summary (excluding the struct itself). + pub fn heap_size(&self) -> usize { + self.sampled.capacity() * std::mem::size_of::() + + self.head_sampled.capacity() * std::mem::size_of::() + } + + pub fn insert(&mut self, x: f64) { + self.head_sampled.push(x); + self.compressed = false; + if self.head_sampled.len() >= Self::DEFAULT_HEAD_SIZE { + self.with_head_buffer_inserted(); + if self.sampled.len() >= self.compress_threshold { + self.compress(); + } + } + } + + fn with_head_buffer_inserted(&mut self) { + if self.head_sampled.is_empty() { + return; + } + let mut current_count = self.count; + let mut sorted = std::mem::take(&mut self.head_sampled); + // Spark relies on `Array[Double].sorted`; use total ordering so the port + // is deterministic even in the presence of NaN (Spark's typical inputs + // are NaN-free). + sorted.sort_by(|a, b| a.total_cmp(b)); + + let mut new_samples: Vec = Vec::new(); + let mut sample_idx = 0usize; + let mut ops_idx = 0usize; + while ops_idx < sorted.len() { + let current_sample = sorted[ops_idx]; + while sample_idx < self.sampled.len() + && self.sampled[sample_idx].value <= current_sample + { + new_samples.push(self.sampled[sample_idx].clone()); + sample_idx += 1; + } + current_count += 1; + let delta = if new_samples.is_empty() + || (sample_idx == self.sampled.len() && ops_idx == sorted.len() - 1) + { + 0 + } else { + // Spark uses `.toInt` here (unlike `.toLong` in merge/compress); we use i64 + // uniformly. They diverge only past ~Int.MAX rows in one accumulator, which + // is not reachable in practice. + (2.0 * self.relative_error * current_count as f64).floor() as i64 + }; + new_samples.push(Stats { + value: current_sample, + g: 1, + delta, + }); + ops_idx += 1; + } + while sample_idx < self.sampled.len() { + new_samples.push(self.sampled[sample_idx].clone()); + sample_idx += 1; + } + self.sampled = new_samples; + self.count = current_count; + } + + pub fn compress(&mut self) { + // Already compressed and the head buffer is empty (insert clears the + // flag whenever it stages a value), so there is nothing to do. This + // mirrors Spark's `PercentileDigest.isCompressed` guard, which also + // compresses at most once. + if self.compressed { + return; + } + self.with_head_buffer_inserted(); + let merge_threshold = 2.0 * self.relative_error * self.count as f64; + self.sampled = Self::compress_immut(&self.sampled, merge_threshold); + self.compressed = true; + } + + fn compress_immut(current_samples: &[Stats], merge_threshold: f64) -> Vec { + if current_samples.is_empty() { + return Vec::new(); + } + let mut res: VecDeque = VecDeque::new(); + let mut head = current_samples[current_samples.len() - 1].clone(); + // Traverse backward from size-2 down to index 1 (index 0 is preserved + // separately so the minimum is always kept). + let mut i = current_samples.len() as isize - 2; + while i >= 1 { + let sample1 = ¤t_samples[i as usize]; + if ((sample1.g + head.g + head.delta) as f64) < merge_threshold { + head.g += sample1.g; + } else { + res.push_front(head.clone()); + head = sample1.clone(); + } + i -= 1; + } + res.push_front(head.clone()); + let curr_head = ¤t_samples[0]; + if curr_head.value <= head.value && current_samples.len() > 1 { + res.push_front(curr_head.clone()); + } + res.into() + } + + pub fn merge(&self, other: &QuantileSummaries) -> QuantileSummaries { + debug_assert!(self.head_sampled.is_empty(), "compress before merge"); + debug_assert!(other.head_sampled.is_empty(), "compress before merge"); + if other.count == 0 { + return self.clone(); + } + if self.count == 0 { + return other.clone(); + } + let merged_relative_error = self.relative_error.max(other.relative_error); + let merged_count = self.count + other.count; + let additional_self_delta = + (2.0 * other.relative_error * other.count as f64).floor() as i64; + let additional_other_delta = (2.0 * self.relative_error * self.count as f64).floor() as i64; + + let mut merged_sampled: Vec = + Vec::with_capacity(self.sampled.len() + other.sampled.len()); + let mut self_idx = 0usize; + let mut other_idx = 0usize; + while self_idx < self.sampled.len() && other_idx < other.sampled.len() { + let self_sample = &self.sampled[self_idx]; + let other_sample = &other.sampled[other_idx]; + let (mut next_sample, additional_delta) = if self_sample.value < other_sample.value { + self_idx += 1; + ( + self_sample.clone(), + if other_idx > 0 { + additional_self_delta + } else { + 0 + }, + ) + } else { + other_idx += 1; + ( + other_sample.clone(), + if self_idx > 0 { + additional_other_delta + } else { + 0 + }, + ) + }; + next_sample.delta += additional_delta; + merged_sampled.push(next_sample); + } + while self_idx < self.sampled.len() { + merged_sampled.push(self.sampled[self_idx].clone()); + self_idx += 1; + } + while other_idx < other.sampled.len() { + merged_sampled.push(other.sampled[other_idx].clone()); + other_idx += 1; + } + let comp = Self::compress_immut( + &merged_sampled, + 2.0 * merged_relative_error * merged_count as f64, + ); + QuantileSummaries { + compress_threshold: other.compress_threshold, + relative_error: merged_relative_error, + sampled: comp, + count: merged_count, + compressed: true, + head_sampled: Vec::new(), + } + } + + pub fn query(&self, percentiles: &[f64]) -> Option> { + debug_assert!(self.head_sampled.is_empty(), "compress before query"); + if self.sampled.is_empty() { + return None; + } + let target_error = self + .sampled + .iter() + .fold(i64::MIN, |m, s| m.max(s.delta + s.g)) + / 2; + + let mut indexed: Vec<(f64, usize)> = percentiles + .iter() + .enumerate() + .map(|(i, p)| (*p, i)) + .collect(); + indexed.sort_by(|a, b| a.0.total_cmp(&b.0)); + + let mut result = vec![0.0f64; percentiles.len()]; + let mut index = 0usize; + let mut min_rank = self.sampled[0].g; + for (percentile, pos) in indexed { + if percentile <= self.relative_error { + result[pos] = self.sampled[0].value; + } else if percentile >= 1.0 - self.relative_error { + result[pos] = self.sampled[self.sampled.len() - 1].value; + } else { + let (new_index, new_min_rank, approx) = + self.find_approx_quantile(index, min_rank, target_error, percentile); + index = new_index; + min_rank = new_min_rank; + result[pos] = approx; + } + } + Some(result) + } + + fn find_approx_quantile( + &self, + index: usize, + min_rank_at_index: i64, + target_error: i64, + percentile: f64, + ) -> (usize, i64, f64) { + let mut cur_sample = &self.sampled[index]; + let rank = (percentile * self.count as f64).ceil() as i64; + let mut i = index; + let mut min_rank = min_rank_at_index; + while i < self.sampled.len() - 1 { + let max_rank = min_rank + cur_sample.delta; + if max_rank - target_error <= rank && rank <= min_rank + target_error { + return (i, min_rank, cur_sample.value); + } else { + i += 1; + cur_sample = &self.sampled[i]; + min_rank += cur_sample.g; + } + } + ( + self.sampled.len() - 1, + 0, + self.sampled[self.sampled.len() - 1].value, + ) + } + + /// Comet-internal little-endian layout (NOT Spark's big-endian serializer): + /// count(i64) | relative_error(f64) | n(u32) | n * [value(f64) g(i64) delta(i64)]. + /// Callers must `compress()` first. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(8 + 8 + 4 + self.sampled.len() * 24); + buf.extend_from_slice(&self.count.to_le_bytes()); + buf.extend_from_slice(&self.relative_error.to_le_bytes()); + buf.extend_from_slice(&(self.sampled.len() as u32).to_le_bytes()); + for s in &self.sampled { + buf.extend_from_slice(&s.value.to_le_bytes()); + buf.extend_from_slice(&s.g.to_le_bytes()); + buf.extend_from_slice(&s.delta.to_le_bytes()); + } + buf + } + + pub fn from_bytes(compress_threshold: usize, bytes: &[u8]) -> Self { + let mut off = 0usize; + let take = |off: &mut usize, n: usize| { + let s = &bytes[*off..*off + n]; + *off += n; + s + }; + let count = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let relative_error = f64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let n = u32::from_le_bytes(take(&mut off, 4).try_into().unwrap()) as usize; + let mut sampled = Vec::with_capacity(n); + for _ in 0..n { + let value = f64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let g = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let delta = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + sampled.push(Stats { value, g, delta }); + } + QuantileSummaries { + compress_threshold, + relative_error, + sampled, + count, + compressed: true, + head_sampled: Vec::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const EPS: f64 = 1.0 / 10000.0; + + fn summary_of(values: &[f64]) -> QuantileSummaries { + let mut qs = QuantileSummaries::new(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, EPS); + for &v in values { + qs.insert(v); + } + qs.compress(); + qs + } + + /// Brute-force Spark-equivalent exact rank used to bound the approximation. + fn exact_percentile(sorted: &[f64], p: f64) -> f64 { + let rank = (p * sorted.len() as f64).ceil() as usize; + let idx = rank.saturating_sub(1).min(sorted.len() - 1); + sorted[idx] + } + + #[test] + fn empty_summary_queries_none() { + let qs = QuantileSummaries::new(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, EPS); + assert_eq!(qs.query(&[0.5]), None); + } + + #[test] + fn query_returns_actual_inserted_values() { + let values: Vec = (1..=1000).map(|i| i as f64).collect(); + let qs = summary_of(&values); + // GK returns an actual inserted value, never an interpolation. + for p in [0.1, 0.25, 0.5, 0.75, 0.9] { + let got = qs.query(&[p]).unwrap()[0]; + assert!(values.contains(&got), "p={p} produced non-member {got}"); + } + } + + #[test] + fn query_within_relative_error_bound() { + let values: Vec = (1..=10000).map(|i| i as f64).collect(); + let mut sorted = values.clone(); + sorted.sort_by(|a, b| a.total_cmp(b)); + let qs = summary_of(&values); + for p in [0.01, 0.1, 0.5, 0.9, 0.99] { + let got = qs.query(&[p]).unwrap()[0]; + let exact = exact_percentile(&sorted, p); + // rank error bounded by relativeError * count. + let rank_err = (got - exact).abs(); + assert!( + rank_err <= EPS * values.len() as f64 + 1.0, + "p={p} got={got} exact={exact}" + ); + } + } + + #[test] + fn multi_percentile_matches_single() { + let values: Vec = (1..=5000).map(|i| i as f64).collect(); + let qs = summary_of(&values); + let ps = [0.9, 0.1, 0.5, 0.99, 0.01]; + let batch = qs.query(&ps).unwrap(); + for (i, &p) in ps.iter().enumerate() { + assert_eq!(batch[i], qs.query(&[p]).unwrap()[0]); + } + } + + #[test] + fn merge_is_within_bound() { + let left: Vec = (1..=5000).map(|i| i as f64).collect(); + let right: Vec = (5001..=10000).map(|i| i as f64).collect(); + let a = summary_of(&left); + let b = summary_of(&right); + let merged = a.merge(&b); + let mut all: Vec = left.iter().chain(right.iter()).cloned().collect(); + all.sort_by(|x, y| x.total_cmp(y)); + let got = merged.query(&[0.5]).unwrap()[0]; + let exact = exact_percentile(&all, 0.5); + assert!((got - exact).abs() <= EPS * all.len() as f64 + 1.0); + } + + #[test] + fn serde_round_trips() { + let qs = summary_of(&(1..=2000).map(|i| i as f64).collect::>()); + let bytes = qs.to_bytes(); + let back = + QuantileSummaries::from_bytes(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, &bytes); + assert_eq!(qs.count(), back.count()); + assert_eq!(qs.query(&[0.5]), back.query(&[0.5])); + } +} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 52f39c59ad..dc8268950e 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -837,19 +837,21 @@ case class CometExecRule(session: SparkSession) private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = { plan.foreach { case agg: BaseAggregateExec => - // Only consider single-mode Final aggregates. Multi-mode Finals come from Spark's - // distinct-aggregate rewrite, where the Comet partial (if any) feeds into a Spark - // PartialMerge rather than directly into a Final, which is a different code path - // than the Comet-Partial → Spark-Final crash scenario from issue #1389. + // A single-mode Final that consumes an incompatible intermediate buffer and cannot itself + // be converted to Comet must not sit above a Comet aggregate that produces that buffer, + // otherwise Spark's Final would try to read a Comet-encoded buffer and crash. Tagging the + // bottom Partial so it falls back is enough: once it is Spark, the missingCometProducer + // guard in CometBaseAggregate.doConvert cascades the fallback up through any intermediate + // PartialMerge stages of a distinct-aggregate rewrite. See issues #1389 and #4813. val modes = agg.aggregateExpressions.map(_.mode).distinct if (modes == Seq(Final) && !QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions) && !canAggregateBeConverted(agg, Final)) { findPartialAggInPlan(agg.child).foreach { partial => - // Only tag if the Partial would otherwise have been converted. If the Partial - // itself cannot be converted (e.g. the aggregate function is incompatible for the - // input type), there is no buffer-format mismatch to guard against, and tagging - // would mask the natural, more specific fallback reason. + // Only tag if the Partial would otherwise have been converted. If the Partial itself + // cannot be converted (e.g. an incompatible input type or a map-typed grouping key), + // there is no buffer-format mismatch to guard against, and tagging would mask the + // natural, more specific fallback reason. if (canAggregateBeConverted(partial, Partial)) { partial.setTagValue( CometExecRule.COMET_UNSAFE_PARTIAL, @@ -862,6 +864,32 @@ case class CometExecRule(session: SparkSession) } } + /** + * Look for the bottom Partial-mode aggregate that feeds into the given plan (the child of a + * Final). Walks through exchanges and AQE stages, and continues down through intermediate + * aggregate stages whose modes are all Partial / PartialMerge - these are the PartialMerge (and + * mixed Partial/PartialMerge) stages that Spark's distinct-aggregate rewrite inserts between + * the Partial and the Final. Stops at anything else. Requires `aggregateExpressions.nonEmpty` + * so that group-by-only dedup stages are traversed rather than mistaken for the partial. + */ + private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match { + case agg: BaseAggregateExec + if agg.aggregateExpressions.nonEmpty && + agg.aggregateExpressions.forall(e => e.mode == Partial) => + Some(agg) + case agg: BaseAggregateExec + if agg.aggregateExpressions.forall(e => e.mode == Partial || e.mode == PartialMerge) => + // Intermediate PartialMerge / mixed stage of a distinct-aggregate rewrite, or a group-by + // only dedup stage; keep walking down towards the bottom Partial. + findPartialAggInPlan(agg.child) + case a: AQEShuffleReadExec => findPartialAggInPlan(a.child) + case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan) + case e: ShuffleExchangeExec => findPartialAggInPlan(e.child) + case other => + logDebug(s"findPartialAggInPlan: stopping at ${other.nodeName}; not a known passthrough") + None + } + /** * Conservative check for whether an aggregate could be converted to Comet. Checks operator * enablement, grouping expressions, aggregate expressions, and result expressions. @@ -933,25 +961,4 @@ case class CometExecRule(session: SparkSession) } } - /** - * Look for a Partial-mode aggregate that feeds directly into the given plan (the child of a - * Final). Walks through exchanges and AQE stages only, stopping at anything else including - * other aggregate stages. This avoids tagging unrelated Partials found deeper in the plan (e.g. - * the non-distinct Partial in a distinct-aggregate rewrite, which is separated from the Final - * by intermediate PartialMerge stages). Requires `aggregateExpressions.nonEmpty` so that - * group-by-only dedup stages are not mistaken for the partial we want to tag. - */ - private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match { - case agg: BaseAggregateExec - if agg.aggregateExpressions.nonEmpty && - agg.aggregateExpressions.forall(e => e.mode == Partial) => - Some(agg) - case a: AQEShuffleReadExec => findPartialAggInPlan(a.child) - case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan) - case e: ShuffleExchangeExec => findPartialAggInPlan(e.child) - case other => - logDebug(s"findPartialAggInPlan: stopping at ${other.nodeName}; not a known passthrough") - None - } - } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..15f3360f8b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -385,6 +385,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { * Mapping of Spark aggregate expression class to Comet expression handler. */ val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( + classOf[ApproximatePercentile] -> CometApproxPercentile, classOf[Average] -> CometAverage, classOf[BitAndAgg] -> CometBitAndAgg, classOf[BitOrAgg] -> CometBitOrAgg, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..0ad57a870d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,9 +22,10 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -675,6 +676,78 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { } } +object CometApproxPercentile extends CometAggregateExpressionSerde[ApproximatePercentile] { + + private val nonLiteralPercentageReason = + "The percentage argument must be a foldable literal." + private val nonLiteralAccuracyReason = + "The accuracy argument must be a foldable literal." + private val inputTypeReason = + "Only byte, short, int, long, float, and double input types are supported." + + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralPercentageReason, nonLiteralAccuracyReason, inputTypeReason) + + override def getSupportLevel(expr: ApproximatePercentile): SupportLevel = { + if (!expr.percentageExpression.foldable) { + return Unsupported(Some(nonLiteralPercentageReason)) + } + if (!expr.accuracyExpression.foldable) { + return Unsupported(Some(nonLiteralAccuracyReason)) + } + expr.child.dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + Compatible(None) + case _ => Unsupported(Some(inputTypeReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: ApproximatePercentile, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + // Spark accumulates values as doubles; cast the child up front so the + // accumulator sees a single Float64 column, then cast results back to the + // input type natively via input_type. + val childExpr = exprToProto(Cast(expr.child, DoubleType), inputs, binding) + val inputType = serializeDataType(expr.child.dataType) + + val (percentiles, returnArray) = expr.percentageExpression.eval() match { + case d: Double => (Seq(d), false) + case arr: ArrayData => (arr.toDoubleArray().toSeq, true) + case other => + withFallbackReason(aggExpr, s"Unsupported percentage literal: $other", expr.child) + return None + } + val accuracy = expr.accuracyExpression.eval() match { + case i: Int => i.toLong + case l: Long => l + case other => + withFallbackReason(aggExpr, s"Unsupported accuracy literal: $other", expr.child) + return None + } + + if (childExpr.isDefined && inputType.isDefined) { + val builder = ExprOuterClass.ApproxPercentile.newBuilder() + builder.setChild(childExpr.get) + percentiles.foreach(builder.addPercentiles(_)) + builder.setAccuracy(accuracy) + builder.setReturnArray(returnArray) + builder.setInputType(inputType.get) + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setApproxPercentile(builder) + .build()) + } else { + withFallbackReason(aggExpr, expr.child) + None + } + } +} + object CometCorr extends CometAggregateExpressionSerde[Corr] { override def convert( aggExpr: AggregateExpression, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..ec4311b746 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1526,10 +1526,14 @@ trait CometBaseAggregate { // In distinct aggregates there can be a combination of modes. // We support {Partial, PartialMerge} mix; other combinations are rejected. val multiMode = modes.size > 1 && modeSet != Set(Partial, PartialMerge) - // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation, unless all aggregates have compatible - // intermediate buffer formats (safe for mixed Spark/Comet execution). - val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty + // An aggregate that consumes intermediate buffers (Final, or the PartialMerge stages of a + // distinct-aggregate rewrite) must have a Comet aggregate producing those buffers below it. + // Otherwise Comet would try to read a Spark partial's buffer, which is only safe when every + // aggregate has a buffer format compatible between Spark and Comet. This guards the + // Spark-Partial to Comet-Merge direction; the Comet-Partial to Spark-Final direction is + // guarded by the COMET_UNSAFE_PARTIAL tagging pass in CometExecRule. See issues #1389, #4813. + val consumesBuffers = modes.contains(Final) || modes.contains(PartialMerge) + val missingCometProducer = consumesBuffers && findCometPartialAgg(aggregate.child).isEmpty if (multiMode) { withFallbackReason( @@ -1538,12 +1542,12 @@ trait CometBaseAggregate { return None } - if (sparkFinalMode && + if (missingCometProducer && !QueryPlanSerde.allAggsSupportMixedExecution(aggregate.aggregateExpressions)) { withFallbackReason( aggregate, - "Spark Final aggregate without Comet Partial requires compatible " + - "intermediate buffer formats") + "Comet aggregate that merges intermediate buffers requires a Comet child aggregate " + + "when the intermediate buffer formats are incompatible with Spark") return None } diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql new file mode 100644 index 0000000000..18a4b1d31b --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql @@ -0,0 +1,68 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Native approx_percentile via a GK (Greenwald-Khanna) quantile summary port, +-- matching Spark's algorithm and default relative error, so results are +-- bit-identical. Only byte, short, int, long, float, and double inputs are +-- supported. Every query below uses the default `query` mode, which asserts +-- native execution, so an all-fallback run of this file cannot vacuously pass. + +-- scalar percentile over a bigint (long) column +query +SELECT approx_percentile(id, 0.5) FROM range(1000) + +-- explicit, non-default accuracy +query +SELECT approx_percentile(id, 0.9, 100) FROM range(1000) + +-- array of percentiles +query +SELECT approx_percentile(id, array(0.25, 0.5, 0.75)) FROM range(1000) + +-- group by +query +SELECT id % 3 AS g, approx_percentile(id, 0.5) FROM range(1000) GROUP BY g ORDER BY g + +-- doubles +query +SELECT approx_percentile(cast(id AS double) / 7.0, 0.5) FROM range(1000) + +-- floats +query +SELECT approx_percentile(cast(id AS float), 0.5) FROM range(1000) + +-- byte input type +query +SELECT approx_percentile(cast(id % 100 AS byte), 0.5) FROM range(1000) + +-- short input type +query +SELECT approx_percentile(cast(id AS short), 0.5) FROM range(1000) + +statement +CREATE TABLE test_approx_percentile_nulls(v int) USING parquet + +statement +INSERT INTO test_approx_percentile_nulls VALUES (1), (2), (null), (3), (4) + +-- nulls are ignored +query +SELECT approx_percentile(v, 0.5) FROM test_approx_percentile_nulls + +-- empty input yields null +query +SELECT approx_percentile(v, 0.5) FROM (SELECT id AS v FROM range(1000) WHERE id < 0) diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt index 14ad77dad4..bbea84df75 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt @@ -4,7 +4,7 @@ CometNativeColumnarToRow +- CometWindowExec +- CometSort +- CometColumnarExchange - +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Expand diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt index f95c69368f..489056587b 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt index 10ea854de0..9df0eb8a01 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt @@ -8,7 +8,7 @@ CometNativeColumnarToRow +- CometColumnarExchange +- HashAggregate +- Union - :- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + :- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate : +- Project @@ -58,10 +58,10 @@ CometNativeColumnarToRow : +- CometProject : +- CometFilter : +- CometNativeScan parquet spark_catalog.default.date_dim - :- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + :- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate - : +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + : +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate : +- Project @@ -111,10 +111,10 @@ CometNativeColumnarToRow : +- CometProject : +- CometFilter : +- CometNativeScan parquet spark_catalog.default.date_dim - +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate - +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..1ff73757e9 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -108,6 +108,31 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // Regression test for the approx_percentile distinct-aggregate crash: an aggregate with an + // incompatible intermediate buffer (percentile_approx) combined with a distinct aggregate is + // rewritten by Spark into a multi-stage plan. If part of that chain runs in Comet and part in + // Spark, the incompatible buffer crosses the boundary and crashes. Here we force the split with + // the partial/final debug configs and assert results still match Spark (the whole chain must + // fall back to Spark). See https://github.com/apache/datafusion-comet/issues/4813. + test("approx_percentile with distinct aggregate does not split across Comet and Spark") { + val data = (0 until 500).map(i => (i % 10, i % 100, i % 37)) + withParquetTable(data, "tbl", false) { + for (disablePartial <- Seq(false, true); + disableFinal <- Seq(false, true); + groupBy <- Seq("", " GROUP BY _1")) { + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> (!disablePartial).toString, + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> (!disableFinal).toString) { + checkSparkAnswer( + s"SELECT percentile_approx(_2, 0.5), count(DISTINCT _3) FROM tbl$groupBy") + } + } + } + } + test("stddev_pop should return NaN for some cases") { withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { Seq(true, false).foreach { nullOnDivideByZero => diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 7fa06a26cc..3573567620 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -308,6 +308,58 @@ class CometExecRuleSuite extends CometTestBase { } } + // Regression tests for https://github.com/apache/datafusion-comet/issues/4813. An aggregate with + // an incompatible intermediate buffer (percentile_approx) combined with a distinct aggregate is + // rewritten by Spark into a multi-stage plan whose partial is separated from the final by + // intermediate PartialMerge stages. If part of that chain runs in Comet and part in Spark the + // incompatible buffer crosses the boundary and crashes, so the whole chain must fall back. + test( + "CometExecRule should not split distinct aggregate with incompatible buffer (Spark final)") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan( + spark, + "SELECT percentile_approx(id, 0.5), COUNT(DISTINCT name) FROM test_data") + + // The distinct rewrite produces a multi-stage ObjectHashAggregate chain. + assert(countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) > 1) + + withSQLConf( + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // percentile_approx has an incompatible buffer, so with the final forced to Spark the + // entire partial/merge chain must also stay in Spark. + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } + + test( + "CometExecRule should not split distinct aggregate with incompatible buffer (Spark part)") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan( + spark, + "SELECT percentile_approx(id, 0.5), COUNT(DISTINCT name) FROM test_data") + + assert(countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) > 1) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // With the partial/merge stages forced to Spark, no Comet aggregate may consume their + // incompatible buffers either. + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } + test("CometExecRule should not convert hash aggregate when grouping key contains map type") { // Spark 3.4/3.5 reject `array>` as a grouping key in the analyzer (not orderable), // so the plan never reaches CometExecRule on those versions. The guard we're exercising diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index a9ee46802a..5fb23c34ff 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -125,6 +125,34 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "percentile_double_high_card", "SELECT percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + // Approximate percentile (Greenwald-Khanna). All numeric input types and the + // scalar, array, and explicit-accuracy forms run natively. + private val approxPercentileAggregates = List( + AggExprConfig( + "approx_percentile_int_median", + "SELECT approx_percentile(c_int, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_long_median", + "SELECT approx_percentile(c_long, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_median", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_p90", + "SELECT approx_percentile(c_double, 0.9) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_array", + "SELECT approx_percentile(c_double, array(0.25, 0.5, 0.75)) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_accuracy", + "SELECT approx_percentile(c_double, 0.5, 100) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_global", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table"), + AggExprConfig( + "approx_percentile_double_high_card", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024 @@ -148,7 +176,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { val allAggregates = basicAggregates ++ statisticalAggregates ++ bitwiseAggregates ++ multiKeyAggregates ++ multiAggregates ++ decimalAggregates ++ - highCardinalityAggregates ++ percentileAggregates + highCardinalityAggregates ++ percentileAggregates ++ approxPercentileAggregates allAggregates.foreach { config => runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala index 71ff2000b3..42dba5ba30 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala @@ -51,6 +51,15 @@ trait CometBenchmarkBase .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") + // Use Comet's shuffle manager so operators that require Comet shuffle can + // run natively, notably aggregates planned as ObjectHashAggregate such as + // percentile and approx_percentile. `spark.shuffle.manager` is static and + // must be set before the context starts. CometShuffleManager falls back to + // Spark's shuffle when Comet is disabled, so the Spark baseline cases are + // unaffected. + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") val sparkSession = SparkSession.builder .config(conf)