From c9960aada96532a2ac0526664acef9033e912565 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:36:09 -0600 Subject: [PATCH 01/12] feat: add QuantileSummaries GK port for approx_percentile --- native/spark-expr/src/agg_funcs/mod.rs | 2 + .../src/agg_funcs/quantile_summaries.rs | 423 ++++++++++++++++++ 2 files changed, 425 insertions(+) create mode 100644 native/spark-expr/src/agg_funcs/quantile_summaries.rs diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..681f1c582b 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod quantile_summaries; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use quantile_summaries::{QuantileSummaries, Stats}; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/quantile_summaries.rs b/native/spark-expr/src/agg_funcs/quantile_summaries.rs new file mode 100644 index 0000000000..532a0475b7 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/quantile_summaries.rs @@ -0,0 +1,423 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A bit-for-bit port of Spark's `QuantileSummaries` (Greenwald-Khanna), the +//! sketch behind `approx_percentile` / `percentile_approx`. Kept free of any +//! DataFusion dependency so it can be unit-tested in isolation. +//! +//! Reference: `org.apache.spark.sql.catalyst.util.QuantileSummaries`. + +use std::collections::VecDeque; + +/// A single sampled statistic: the value, its minimum rank jump `g`, and the +/// maximum span of the rank `delta`. +#[derive(Debug, Clone, PartialEq)] +pub struct Stats { + pub value: f64, + pub g: i64, + pub delta: i64, +} + +#[derive(Debug, Clone)] +pub struct QuantileSummaries { + compress_threshold: usize, + relative_error: f64, + sampled: Vec, + count: i64, + compressed: bool, + head_sampled: Vec, +} + +impl QuantileSummaries { + pub const DEFAULT_COMPRESS_THRESHOLD: usize = 10000; + pub const DEFAULT_HEAD_SIZE: usize = 50000; + + /// Mirrors Spark's `PercentileDigest` ctor which builds a summary with + /// `compressed = true`. 
+ pub fn new(compress_threshold: usize, relative_error: f64) -> Self { + Self { + compress_threshold, + relative_error, + sampled: Vec::new(), + count: 0, + compressed: true, + head_sampled: Vec::new(), + } + } + + pub fn count(&self) -> i64 { + self.count + } + + pub fn insert(&mut self, x: f64) { + self.head_sampled.push(x); + self.compressed = false; + if self.head_sampled.len() >= Self::DEFAULT_HEAD_SIZE { + self.with_head_buffer_inserted(); + if self.sampled.len() >= self.compress_threshold { + self.compress(); + } + } + } + + fn with_head_buffer_inserted(&mut self) { + if self.head_sampled.is_empty() { + return; + } + let mut current_count = self.count; + let mut sorted = std::mem::take(&mut self.head_sampled); + // Spark relies on `Array[Double].sorted`; use total ordering so the port + // is deterministic even in the presence of NaN (Spark's typical inputs + // are NaN-free). + sorted.sort_by(|a, b| a.total_cmp(b)); + + let mut new_samples: Vec = Vec::new(); + let mut sample_idx = 0usize; + let mut ops_idx = 0usize; + while ops_idx < sorted.len() { + let current_sample = sorted[ops_idx]; + while sample_idx < self.sampled.len() + && self.sampled[sample_idx].value <= current_sample + { + new_samples.push(self.sampled[sample_idx].clone()); + sample_idx += 1; + } + current_count += 1; + let delta = if new_samples.is_empty() + || (sample_idx == self.sampled.len() && ops_idx == sorted.len() - 1) + { + 0 + } else { + (2.0 * self.relative_error * current_count as f64).floor() as i64 + }; + new_samples.push(Stats { + value: current_sample, + g: 1, + delta, + }); + ops_idx += 1; + } + while sample_idx < self.sampled.len() { + new_samples.push(self.sampled[sample_idx].clone()); + sample_idx += 1; + } + self.sampled = new_samples; + self.count = current_count; + } + + pub fn compress(&mut self) { + self.with_head_buffer_inserted(); + let merge_threshold = 2.0 * self.relative_error * self.count as f64; + self.sampled = Self::compress_immut(&self.sampled, 
merge_threshold); + self.compressed = true; + } + + fn compress_immut(current_samples: &[Stats], merge_threshold: f64) -> Vec { + if current_samples.is_empty() { + return Vec::new(); + } + let mut res: VecDeque = VecDeque::new(); + let mut head = current_samples[current_samples.len() - 1].clone(); + // Traverse backward from size-2 down to index 1 (index 0 is preserved + // separately so the minimum is always kept). + let mut i = current_samples.len() as isize - 2; + while i >= 1 { + let sample1 = ¤t_samples[i as usize]; + if ((sample1.g + head.g + head.delta) as f64) < merge_threshold { + head.g += sample1.g; + } else { + res.push_front(head.clone()); + head = sample1.clone(); + } + i -= 1; + } + res.push_front(head.clone()); + let curr_head = ¤t_samples[0]; + if curr_head.value <= head.value && current_samples.len() > 1 { + res.push_front(curr_head.clone()); + } + res.into() + } + + pub fn merge(&self, other: &QuantileSummaries) -> QuantileSummaries { + debug_assert!(self.head_sampled.is_empty(), "compress before merge"); + debug_assert!(other.head_sampled.is_empty(), "compress before merge"); + if other.count == 0 { + return self.clone(); + } + if self.count == 0 { + return other.clone(); + } + let merged_relative_error = self.relative_error.max(other.relative_error); + let merged_count = self.count + other.count; + let additional_self_delta = + (2.0 * other.relative_error * other.count as f64).floor() as i64; + let additional_other_delta = (2.0 * self.relative_error * self.count as f64).floor() as i64; + + let mut merged_sampled: Vec = Vec::new(); + let mut self_idx = 0usize; + let mut other_idx = 0usize; + while self_idx < self.sampled.len() && other_idx < other.sampled.len() { + let self_sample = &self.sampled[self_idx]; + let other_sample = &other.sampled[other_idx]; + let (mut next_sample, additional_delta) = if self_sample.value < other_sample.value { + self_idx += 1; + ( + self_sample.clone(), + if other_idx > 0 { + additional_self_delta + } else { + 0 + 
}, + ) + } else { + other_idx += 1; + ( + other_sample.clone(), + if self_idx > 0 { + additional_other_delta + } else { + 0 + }, + ) + }; + next_sample.delta += additional_delta; + merged_sampled.push(next_sample); + } + while self_idx < self.sampled.len() { + merged_sampled.push(self.sampled[self_idx].clone()); + self_idx += 1; + } + while other_idx < other.sampled.len() { + merged_sampled.push(other.sampled[other_idx].clone()); + other_idx += 1; + } + let comp = Self::compress_immut( + &merged_sampled, + 2.0 * merged_relative_error * merged_count as f64, + ); + QuantileSummaries { + compress_threshold: other.compress_threshold, + relative_error: merged_relative_error, + sampled: comp, + count: merged_count, + compressed: true, + head_sampled: Vec::new(), + } + } + + pub fn query(&self, percentiles: &[f64]) -> Option> { + debug_assert!(self.head_sampled.is_empty(), "compress before query"); + if self.sampled.is_empty() { + return None; + } + let target_error = self + .sampled + .iter() + .fold(i64::MIN, |m, s| m.max(s.delta + s.g)) + / 2; + + let mut indexed: Vec<(f64, usize)> = percentiles + .iter() + .enumerate() + .map(|(i, p)| (*p, i)) + .collect(); + indexed.sort_by(|a, b| a.0.total_cmp(&b.0)); + + let mut result = vec![0.0f64; percentiles.len()]; + let mut index = 0usize; + let mut min_rank = self.sampled[0].g; + for (percentile, pos) in indexed { + if percentile <= self.relative_error { + result[pos] = self.sampled[0].value; + } else if percentile >= 1.0 - self.relative_error { + result[pos] = self.sampled[self.sampled.len() - 1].value; + } else { + let (new_index, new_min_rank, approx) = + self.find_approx_quantile(index, min_rank, target_error, percentile); + index = new_index; + min_rank = new_min_rank; + result[pos] = approx; + } + } + Some(result) + } + + fn find_approx_quantile( + &self, + index: usize, + min_rank_at_index: i64, + target_error: i64, + percentile: f64, + ) -> (usize, i64, f64) { + let mut cur_sample = &self.sampled[index]; + let rank = 
(percentile * self.count as f64).ceil() as i64; + let mut i = index; + let mut min_rank = min_rank_at_index; + while i < self.sampled.len() - 1 { + let max_rank = min_rank + cur_sample.delta; + if max_rank - target_error <= rank && rank <= min_rank + target_error { + return (i, min_rank, cur_sample.value); + } else { + i += 1; + cur_sample = &self.sampled[i]; + min_rank += cur_sample.g; + } + } + ( + self.sampled.len() - 1, + 0, + self.sampled[self.sampled.len() - 1].value, + ) + } + + /// Comet-internal little-endian layout (NOT Spark's big-endian serializer): + /// count(i64) | relative_error(f64) | n(u32) | n * [value(f64) g(i64) delta(i64)]. + /// Callers must `compress()` first. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(8 + 8 + 4 + self.sampled.len() * 24); + buf.extend_from_slice(&self.count.to_le_bytes()); + buf.extend_from_slice(&self.relative_error.to_le_bytes()); + buf.extend_from_slice(&(self.sampled.len() as u32).to_le_bytes()); + for s in &self.sampled { + buf.extend_from_slice(&s.value.to_le_bytes()); + buf.extend_from_slice(&s.g.to_le_bytes()); + buf.extend_from_slice(&s.delta.to_le_bytes()); + } + buf + } + + pub fn from_bytes(compress_threshold: usize, bytes: &[u8]) -> Self { + let mut off = 0usize; + let take = |off: &mut usize, n: usize| { + let s = &bytes[*off..*off + n]; + *off += n; + s + }; + let count = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let relative_error = f64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let n = u32::from_le_bytes(take(&mut off, 4).try_into().unwrap()) as usize; + let mut sampled = Vec::with_capacity(n); + for _ in 0..n { + let value = f64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let g = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + let delta = i64::from_le_bytes(take(&mut off, 8).try_into().unwrap()); + sampled.push(Stats { value, g, delta }); + } + QuantileSummaries { + compress_threshold, + relative_error, + sampled, + count, 
+ compressed: true, + head_sampled: Vec::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const EPS: f64 = 1.0 / 10000.0; + + fn summary_of(values: &[f64]) -> QuantileSummaries { + let mut qs = QuantileSummaries::new(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, EPS); + for &v in values { + qs.insert(v); + } + qs.compress(); + qs + } + + /// Brute-force Spark-equivalent exact rank used to bound the approximation. + fn exact_percentile(sorted: &[f64], p: f64) -> f64 { + let rank = (p * sorted.len() as f64).ceil() as usize; + let idx = rank.saturating_sub(1).min(sorted.len() - 1); + sorted[idx] + } + + #[test] + fn empty_summary_queries_none() { + let qs = QuantileSummaries::new(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, EPS); + assert_eq!(qs.query(&[0.5]), None); + } + + #[test] + fn query_returns_actual_inserted_values() { + let values: Vec = (1..=1000).map(|i| i as f64).collect(); + let qs = summary_of(&values); + // GK returns an actual inserted value, never an interpolation. + for p in [0.1, 0.25, 0.5, 0.75, 0.9] { + let got = qs.query(&[p]).unwrap()[0]; + assert!(values.contains(&got), "p={p} produced non-member {got}"); + } + } + + #[test] + fn query_within_relative_error_bound() { + let values: Vec = (1..=10000).map(|i| i as f64).collect(); + let mut sorted = values.clone(); + sorted.sort_by(|a, b| a.total_cmp(b)); + let qs = summary_of(&values); + for p in [0.01, 0.1, 0.5, 0.9, 0.99] { + let got = qs.query(&[p]).unwrap()[0]; + let exact = exact_percentile(&sorted, p); + // rank error bounded by relativeError * count. 
+ let rank_err = (got - exact).abs(); + assert!( + rank_err <= EPS * values.len() as f64 + 1.0, + "p={p} got={got} exact={exact}" + ); + } + } + + #[test] + fn multi_percentile_matches_single() { + let values: Vec = (1..=5000).map(|i| i as f64).collect(); + let qs = summary_of(&values); + let ps = [0.9, 0.1, 0.5, 0.99, 0.01]; + let batch = qs.query(&ps).unwrap(); + for (i, &p) in ps.iter().enumerate() { + assert_eq!(batch[i], qs.query(&[p]).unwrap()[0]); + } + } + + #[test] + fn merge_is_within_bound() { + let left: Vec = (1..=5000).map(|i| i as f64).collect(); + let right: Vec = (5001..=10000).map(|i| i as f64).collect(); + let a = summary_of(&left); + let b = summary_of(&right); + let merged = a.merge(&b); + let mut all: Vec = left.iter().chain(right.iter()).cloned().collect(); + all.sort_by(|x, y| x.total_cmp(y)); + let got = merged.query(&[0.5]).unwrap()[0]; + let exact = exact_percentile(&all, 0.5); + assert!((got - exact).abs() <= EPS * all.len() as f64 + 1.0); + } + + #[test] + fn serde_round_trips() { + let qs = summary_of(&(1..=2000).map(|i| i as f64).collect::>()); + let bytes = qs.to_bytes(); + let back = + QuantileSummaries::from_bytes(QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, &bytes); + assert_eq!(qs.count(), back.count()); + assert_eq!(qs.query(&[0.5]), back.query(&[0.5])); + } +} From 4bfb51bf6db68556fed8a5837707665f5d75dc00 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:47:56 -0600 Subject: [PATCH 02/12] feat: add ApproxPercentile UDAF and accumulator Wrap the QuantileSummaries (Greenwald-Khanna) sketch in a DataFusion AggregateUDFImpl and per-row Accumulator for Spark's approx_percentile / percentile_approx. State is carried as a single Binary ScalarValue using the Comet-internal little-endian digest format. Supports byte, short, int, long, float, and double inputs, casting the query result back to the original input type. groups_accumulator_supported is false; this runs as a row accumulator per group. 
--- .../src/agg_funcs/approx_percentile.rs | 310 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + 2 files changed, 312 insertions(+) create mode 100644 native/spark-expr/src/agg_funcs/approx_percentile.rs diff --git a/native/spark-expr/src/agg_funcs/approx_percentile.rs b/native/spark-expr/src/agg_funcs/approx_percentile.rs new file mode 100644 index 0000000000..afb314fe0b --- /dev/null +++ b/native/spark-expr/src/agg_funcs/approx_percentile.rs @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::quantile_summaries::QuantileSummaries; +use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{downcast_value, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; +use std::sync::Arc; + +/// Native implementation of Spark's `approx_percentile` / `percentile_approx`, +/// backed by a bit-for-bit `QuantileSummaries` (Greenwald-Khanna) port. The +/// child value is cast to Float64 by the serde; the original `input_type` is +/// carried so results can be cast back to Spark's output type. +#[derive(Debug)] +pub struct ApproxPercentile { + name: String, + signature: Signature, + percentiles: Vec, + accuracy: i64, + input_type: DataType, + return_array: bool, +} + +impl PartialEq for ApproxPercentile { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.percentiles == other.percentiles + && self.accuracy == other.accuracy + && self.input_type == other.input_type + && self.return_array == other.return_array + } +} +impl Eq for ApproxPercentile {} + +impl std::hash::Hash for ApproxPercentile { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.percentiles + .iter() + .for_each(|p| p.to_bits().hash(state)); + self.accuracy.hash(state); + self.input_type.hash(state); + self.return_array.hash(state); + } +} + +impl ApproxPercentile { + pub fn new( + percentiles: Vec, + accuracy: i64, + input_type: DataType, + return_array: bool, + ) -> Self { + Self { + name: "approx_percentile".to_string(), + signature: Signature::numeric(1, Immutable), + percentiles, + accuracy, + input_type, + return_array, + } + } + + fn output_element_type(&self) -> DataType { + 
self.input_type.clone() + } +} + +impl AggregateUDFImpl for ApproxPercentile { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + if self.return_array { + Ok(DataType::List(Arc::new(Field::new( + "item", + self.output_element_type(), + false, + )))) + } else { + Ok(self.output_element_type()) + } + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ApproxPercentileAccumulator::new( + self.percentiles.clone(), + self.accuracy, + self.input_type.clone(), + self.return_array, + ))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new( + format_state_name(&self.name, "digest"), + DataType::Binary, + true, + ))]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + false + } +} + +#[derive(Debug)] +struct ApproxPercentileAccumulator { + summary: QuantileSummaries, + percentiles: Vec, + input_type: DataType, + return_array: bool, +} + +impl ApproxPercentileAccumulator { + fn new(percentiles: Vec, accuracy: i64, input_type: DataType, return_array: bool) -> Self { + let relative_error = 1.0 / accuracy as f64; + Self { + summary: QuantileSummaries::new( + QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, + relative_error, + ), + percentiles, + input_type, + return_array, + } + } + + /// Cast a double quantile back to Spark's output type. GK always returns an + /// actual inserted value (never an interpolation), so for the supported + /// numeric types this round-trips exactly and is always in range. 
+ fn cast_back(&self, d: f64) -> ScalarValue { + match &self.input_type { + DataType::Int8 => ScalarValue::Int8(Some(d as i8)), + DataType::Int16 => ScalarValue::Int16(Some(d as i16)), + DataType::Int32 => ScalarValue::Int32(Some(d as i32)), + DataType::Int64 => ScalarValue::Int64(Some(d as i64)), + DataType::Float32 => ScalarValue::Float32(Some(d as f32)), + _ => ScalarValue::Float64(Some(d)), + } + } +} + +impl Accumulator for ApproxPercentileAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(&values[0], Float64Array); + for v in arr.iter().flatten() { + self.summary.insert(v); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let digests = downcast_value!(&states[0], BinaryArray); + self.summary.compress(); + for i in 0..digests.len() { + if digests.is_null(i) { + continue; + } + let peer = QuantileSummaries::from_bytes( + QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, + digests.value(i), + ); + self.summary = self.summary.merge(&peer); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + self.summary.compress(); + Ok(vec![ScalarValue::Binary(Some(self.summary.to_bytes()))]) + } + + fn evaluate(&mut self) -> Result { + self.summary.compress(); + let element_type = self.input_type.clone(); + match self.summary.query(&self.percentiles) { + None => { + if self.return_array { + // Empty digest still yields null overall in Spark. + Ok(ScalarValue::List(Arc::new(ListArray::new_null( + Arc::new(Field::new("item", element_type, false)), + 1, + )))) + } else { + Ok(ScalarValue::try_from(&element_type)?) 
+ } + } + Some(results) => { + let scalars: Vec = + results.into_iter().map(|d| self.cast_back(d)).collect(); + if self.return_array { + let values = ScalarValue::iter_to_array(scalars)?; + let field = Arc::new(Field::new("item", element_type, false)); + let offsets = OffsetBuffer::from_lengths([values.len()]); + let list = ListArray::new(field, offsets, values, None); + Ok(ScalarValue::List(Arc::new(list))) + } else { + Ok(scalars.into_iter().next().unwrap()) + } + } + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn f64_array(v: Vec) -> ArrayRef { + Arc::new(Float64Array::from(v)) + } + + #[test] + fn scalar_median_of_int_column() { + let mut acc = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Int32, false); + acc.update_batch(&[f64_array((1..=100).map(|i| i as f64).collect())]) + .unwrap(); + match acc.evaluate().unwrap() { + ScalarValue::Int32(Some(v)) => assert!((49..=51).contains(&v)), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn array_of_percentiles() { + let mut acc = + ApproxPercentileAccumulator::new(vec![0.25, 0.5, 0.75], 10000, DataType::Float64, true); + acc.update_batch(&[f64_array((1..=1000).map(|i| i as f64).collect())]) + .unwrap(); + match acc.evaluate().unwrap() { + ScalarValue::List(arr) => assert_eq!(arr.value_length(0), 3), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn empty_input_is_null() { + let mut acc = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Int64, false); + assert!(acc.evaluate().unwrap().is_null()); + } + + #[test] + fn state_then_merge_matches_single_shot() { + let mut single = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + single + .update_batch(&[f64_array((1..=1000).map(|i| i as f64).collect())]) + .unwrap(); + let single_val = single.evaluate().unwrap(); + + let mut left = ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, 
false); + left.update_batch(&[f64_array((1..=500).map(|i| i as f64).collect())]) + .unwrap(); + let left_state = left.state().unwrap(); + + let mut right = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + right + .update_batch(&[f64_array((501..=1000).map(|i| i as f64).collect())]) + .unwrap(); + let right_state = right.state().unwrap(); + + let mut merged = + ApproxPercentileAccumulator::new(vec![0.5], 10000, DataType::Float64, false); + merged + .merge_batch(&[ScalarValue::iter_to_array(left_state).unwrap()]) + .unwrap(); + merged + .merge_batch(&[ScalarValue::iter_to_array(right_state).unwrap()]) + .unwrap(); + let merged_val = merged.evaluate().unwrap(); + + // Both within the same accuracy bound of the true median (~500). + for v in [single_val, merged_val] { + match v { + ScalarValue::Float64(Some(x)) => assert!((450.0..=550.0).contains(&x)), + other => panic!("unexpected {other:?}"), + } + } + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 681f1c582b..dc609c6f7d 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+mod approx_percentile; mod avg; mod avg_decimal; mod correlation; @@ -26,6 +27,7 @@ mod sum_int; mod variance; mod welford; +pub use approx_percentile::ApproxPercentile; pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; From de80d431e62736c42db46ffba05c0b9a992accf2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:53:23 -0600 Subject: [PATCH 03/12] feat: wire ApproxPercentile through proto and native planner --- native/core/src/execution/planner.rs | 15 +++++++++++++-- native/proto/src/proto/expr.proto | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..fa2fed394b 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -128,8 +128,8 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, - Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, + jvm_udf::JvmScalarUdfExpr, ApproxPercentile, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, + Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; @@ -2627,6 +2627,17 @@ impl PhysicalPlanner { .build() .map_err(|e| e.into()) } + AggExprStruct::ApproxPercentile(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let input_type = to_arrow_datatype(expr.input_type.as_ref().unwrap()); + let func = AggregateUDF::new_from_impl(ApproxPercentile::new( + expr.percentiles.clone(), + expr.accuracy, + input_type, + expr.return_array, + )); + 
Self::create_aggr_func_expr("approx_percentile", schema, vec![child], func) + } AggExprStruct::BloomFilterAgg(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let num_items = diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..467f8afad8 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + ApproxPercentile approxPercentile = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -253,6 +254,19 @@ message Percentile { DataType datatype = 3; } +message ApproxPercentile { + // Child value expression, already cast to Float64 by the serde. + Expr child = 1; + // One or more percentiles in [0.0, 1.0]. + repeated double percentiles = 2; + // Spark's accuracy argument; relative_error = 1.0 / accuracy. + int64 accuracy = 3; + // True when the percentile argument was an array (output is a list). + bool return_array = 4; + // Spark's input/output type, used to cast results back from Float64. 
+ DataType input_type = 5; +} + message BloomFilterAgg { Expr child = 1; Expr numItems = 2; From 35b5f28eb7ee55b920361b278113bfd9be10c0c0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:58:07 -0600 Subject: [PATCH 04/12] feat: add CometApproxPercentile serde and register it --- .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/aggregates.scala | 77 ++++++++++++++++++- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..15f3360f8b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -385,6 +385,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { * Mapping of Spark aggregate expression class to Comet expression handler. */ val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( + classOf[ApproximatePercentile] -> CometApproxPercentile, classOf[Average] -> CometAverage, classOf[BitAndAgg] -> CometBitAndAgg, classOf[BitOrAgg] -> CometBitOrAgg, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..0ad57a870d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,9 +22,10 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} 
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -675,6 +676,78 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { } } +object CometApproxPercentile extends CometAggregateExpressionSerde[ApproximatePercentile] { + + private val nonLiteralPercentageReason = + "The percentage argument must be a foldable literal." + private val nonLiteralAccuracyReason = + "The accuracy argument must be a foldable literal." + private val inputTypeReason = + "Only byte, short, int, long, float, and double input types are supported." 
+ + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralPercentageReason, nonLiteralAccuracyReason, inputTypeReason) + + override def getSupportLevel(expr: ApproximatePercentile): SupportLevel = { + if (!expr.percentageExpression.foldable) { + return Unsupported(Some(nonLiteralPercentageReason)) + } + if (!expr.accuracyExpression.foldable) { + return Unsupported(Some(nonLiteralAccuracyReason)) + } + expr.child.dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + Compatible(None) + case _ => Unsupported(Some(inputTypeReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: ApproximatePercentile, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + // Spark accumulates values as doubles; cast the child up front so the + // accumulator sees a single Float64 column, then cast results back to the + // input type natively via input_type. + val childExpr = exprToProto(Cast(expr.child, DoubleType), inputs, binding) + val inputType = serializeDataType(expr.child.dataType) + + val (percentiles, returnArray) = expr.percentageExpression.eval() match { + case d: Double => (Seq(d), false) + case arr: ArrayData => (arr.toDoubleArray().toSeq, true) + case other => + withFallbackReason(aggExpr, s"Unsupported percentage literal: $other", expr.child) + return None + } + val accuracy = expr.accuracyExpression.eval() match { + case i: Int => i.toLong + case l: Long => l + case other => + withFallbackReason(aggExpr, s"Unsupported accuracy literal: $other", expr.child) + return None + } + + if (childExpr.isDefined && inputType.isDefined) { + val builder = ExprOuterClass.ApproxPercentile.newBuilder() + builder.setChild(childExpr.get) + percentiles.foreach(builder.addPercentiles(_)) + builder.setAccuracy(accuracy) + builder.setReturnArray(returnArray) + builder.setInputType(inputType.get) + Some( + ExprOuterClass.AggExpr + .newBuilder() + 
.setApproxPercentile(builder) + .build()) + } else { + withFallbackReason(aggExpr, expr.child) + None + } + } +} + object CometCorr extends CometAggregateExpressionSerde[Corr] { override def convert( aggExpr: AggregateExpression, From f98b5fd58fbd11953222842543a80aa61b9f60dd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 15:06:58 -0600 Subject: [PATCH 05/12] test: add approx_percentile end-to-end tests and docs Add a SQL file test that runs approx_percentile scalar, array, group-by, accuracy, doubles, floats, null, and empty-input forms through both Spark and Comet natively, then update the compatibility docs to reflect that approx_percentile now runs natively for the six supported numeric types. --- .../expression-audits/agg_funcs.md | 5 ++ docs/source/user-guide/latest/expressions.md | 3 +- .../aggregate/approx_percentile.sql | 60 +++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..8888776d85 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -27,6 +27,11 @@ - Spark 3.5.8 (audited 2026-05-26): identical to 3.4.3. - Spark 4.0.1 (audited 2026-05-26): identical to 3.4.3. +## approx_percentile + +- Spark 3.4.3, 3.5.8, 4.0.1, 4.1.1 (audited 2026-07-02): `ApproximatePercentile(child, percentageExpression, accuracyExpression)` is a `TypedImperativeAggregate` backed by a Greenwald-Khanna `PercentileDigest` quantile summary with relative error `1.0 / accuracy`. 
`child` accepts `NumericType`, `DateType`, `TimestampType`, `TimestampNTZType`, and interval types (all cast to `double` internally); `percentage` is a single literal or literal array in `[0.0, 1.0]`; `accuracy` is a positive literal (default 10000). NULL inputs are skipped; an empty or all-null group returns NULL. `approx_percentile` is a SQL alias for the primary function name `percentile_approx`. +- `CometApproxPercentile` maps the byte, short, int, long, float, and double input forms to a native Greenwald-Khanna quantile summary port with the same insert/compress/merge/query algorithm and relative error, casting the result back to the input type. `percentage` and `accuracy` must be foldable literals, matching Spark. Date, timestamp, interval, and decimal inputs fall back to Spark. + ## avg - Spark 3.4.3 (2026-05-26) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..dd3b3e6cdc 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -60,7 +60,7 @@ expressions. The following function families are **not currently planned** for n The file-metadata functions `input_file_name`, `input_file_block_start`, and `input_file_block_length` depend on scan-internal per-row file information rather than the expression layer; their support status is covered in the [scan compatibility guide](compatibility/scans.md). -Note that `approx_count_distinct`, `median`, and `mode` are planned: they are mainstream (`median` and `mode` are exact aggregates). `approx_percentile` / `percentile_approx` are not currently planned because their approximate results cannot be made bit-identical to Spark. +Note that `approx_count_distinct`, `median`, and `mode` are planned: they are mainstream (`median` and `mode` are exact aggregates). The tables below list every Spark built-in expression with its current status. 
@@ -71,6 +71,7 @@ The tables below list every Spark built-in expression with its current status. | `any` | ✅ | | | `any_value` | ✅ | | | `approx_count_distinct` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) | +| `approx_percentile` | ✅ | Byte, short, int, long, float, and double input; other input types fall back to Spark | | `array_agg` | 🔜 | Array aggregate (related to `collect_list`, [#2524](https://github.com/apache/datafusion-comet/issues/2524)) | | `avg` | ✅ | Interval types fall back | | `bit_and` | ✅ | | diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql new file mode 100644 index 0000000000..a9d147adcf --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql @@ -0,0 +1,60 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Native approx_percentile via a GK (Greenwald-Khanna) quantile summary port, +-- matching Spark's algorithm and default relative error, so results are +-- bit-identical. Only byte, short, int, long, float, and double inputs are +-- supported. 
Every query below uses the default `query` mode, which asserts +-- native execution, so an all-fallback run of this file cannot vacuously pass. + +-- scalar percentile over an int column +query +SELECT approx_percentile(id, 0.5) FROM range(1000) + +-- explicit accuracy +query +SELECT approx_percentile(id, 0.9, 10000) FROM range(1000) + +-- array of percentiles +query +SELECT approx_percentile(id, array(0.25, 0.5, 0.75)) FROM range(1000) + +-- group by +query +SELECT id % 3 AS g, approx_percentile(id, 0.5) FROM range(1000) GROUP BY g ORDER BY g + +-- doubles +query +SELECT approx_percentile(cast(id AS double) / 7.0, 0.5) FROM range(1000) + +-- floats +query +SELECT approx_percentile(cast(id AS float), 0.5) FROM range(1000) + +statement +CREATE TABLE test_approx_percentile_nulls(v int) USING parquet + +statement +INSERT INTO test_approx_percentile_nulls VALUES (1), (2), (null), (3), (4) + +-- nulls are ignored +query +SELECT approx_percentile(v, 0.5) FROM test_approx_percentile_nulls + +-- empty input yields null +query +SELECT approx_percentile(v, 0.5) FROM (SELECT id AS v FROM range(1000) WHERE id < 0) From b02864fbd473c604952d51791ef58d522ea5659b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 15:20:12 -0600 Subject: [PATCH 06/12] test: broaden approx_percentile coverage and fix accumulator size accounting Add heap_size() to QuantileSummaries so the accumulator's size() reflects actual sampled/head buffer memory instead of only the struct footprint. Document the intentional i64-vs-toInt delta deviation from Spark. Extend the SQL file test with byte and short input coverage, a non-default accuracy case, and fix a comment describing range().id as int. 
--- .../spark-expr/src/agg_funcs/approx_percentile.rs | 2 ++ .../spark-expr/src/agg_funcs/quantile_summaries.rs | 9 +++++++++ .../expressions/aggregate/approx_percentile.sql | 14 +++++++++++--- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/approx_percentile.rs b/native/spark-expr/src/agg_funcs/approx_percentile.rs index afb314fe0b..04053ee450 100644 --- a/native/spark-expr/src/agg_funcs/approx_percentile.rs +++ b/native/spark-expr/src/agg_funcs/approx_percentile.rs @@ -228,6 +228,8 @@ impl Accumulator for ApproxPercentileAccumulator { fn size(&self) -> usize { std::mem::size_of_val(self) + + self.summary.heap_size() + + self.percentiles.capacity() * std::mem::size_of::<f64>() } } diff --git a/native/spark-expr/src/agg_funcs/quantile_summaries.rs b/native/spark-expr/src/agg_funcs/quantile_summaries.rs index 532a0475b7..1b835ca121 100644 --- a/native/spark-expr/src/agg_funcs/quantile_summaries.rs +++ b/native/spark-expr/src/agg_funcs/quantile_summaries.rs @@ -63,6 +63,12 @@ impl QuantileSummaries { self.count } + /// Heap bytes held by this summary (excluding the struct itself). + pub fn heap_size(&self) -> usize { + self.sampled.capacity() * std::mem::size_of::<Stats>() + + self.head_sampled.capacity() * std::mem::size_of::<f64>() + } + pub fn insert(&mut self, x: f64) { self.head_sampled.push(x); self.compressed = false; @@ -102,6 +108,9 @@ impl QuantileSummaries { { 0 } else { + // Spark uses `.toInt` here (unlike `.toLong` in merge/compress); we use i64 + // uniformly. They diverge only past ~Int.MAX rows in one accumulator, which + // is not reachable in practice.
(2.0 * self.relative_error * current_count as f64).floor() as i64 }; new_samples.push(Stats { diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql index a9d147adcf..18a4b1d31b 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/approx_percentile.sql @@ -21,13 +21,13 @@ -- supported. Every query below uses the default `query` mode, which asserts -- native execution, so an all-fallback run of this file cannot vacuously pass. --- scalar percentile over an int column +-- scalar percentile over a bigint (long) column query SELECT approx_percentile(id, 0.5) FROM range(1000) --- explicit accuracy +-- explicit, non-default accuracy query -SELECT approx_percentile(id, 0.9, 10000) FROM range(1000) +SELECT approx_percentile(id, 0.9, 100) FROM range(1000) -- array of percentiles query @@ -45,6 +45,14 @@ SELECT approx_percentile(cast(id AS double) / 7.0, 0.5) FROM range(1000) query SELECT approx_percentile(cast(id AS float), 0.5) FROM range(1000) +-- byte input type +query +SELECT approx_percentile(cast(id % 100 AS byte), 0.5) FROM range(1000) + +-- short input type +query +SELECT approx_percentile(cast(id AS short), 0.5) FROM range(1000) + statement CREATE TABLE test_approx_percentile_nulls(v int) USING parquet From caae6e12ed724ce682edc388e0314568f25cee5a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 15:38:03 -0600 Subject: [PATCH 07/12] refactor: simplify and speed up approx_percentile aggregate - Use SingleRowListArrayBuilder for the array output branch. - Inline the trivial output_element_type wrapper. - Fast path in update_batch when the batch has no nulls. - Honor the compressed flag in QuantileSummaries::compress to avoid redundant recompression, matching Spark's compress-once semantics. 
- Move the peer digest into an empty accumulator instead of cloning. - Pre-size the merged sampled vector. - Keep the QuantileSummaries GK helper private (mirror welford). --- native/proto/src/proto/expr.proto | 3 ++ .../src/agg_funcs/approx_percentile.rs | 39 +++++++++++-------- native/spark-expr/src/agg_funcs/mod.rs | 1 - .../src/agg_funcs/quantile_summaries.rs | 10 ++++- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 467f8afad8..b617132e52 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -257,6 +257,9 @@ message Percentile { message ApproxPercentile { // Child value expression, already cast to Float64 by the serde. Expr child = 1; + // The percentiles and accuracy are carried as resolved scalars rather than + // child Exprs (unlike Percentile/BloomFilterAgg) because they are needed at + // UDAF construction time to drive return_type and accumulator shape. // One or more percentiles in [0.0, 1.0]. repeated double percentiles = 2; // Spark's accuracy argument; relative_error = 1.0 / accuracy. 
diff --git a/native/spark-expr/src/agg_funcs/approx_percentile.rs b/native/spark-expr/src/agg_funcs/approx_percentile.rs index 04053ee450..4b2d1515c8 100644 --- a/native/spark-expr/src/agg_funcs/approx_percentile.rs +++ b/native/spark-expr/src/agg_funcs/approx_percentile.rs @@ -17,8 +17,8 @@ use super::quantile_summaries::QuantileSummaries; use arrow::array::{Array, ArrayRef, BinaryArray, Float64Array, ListArray}; -use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::common::{downcast_value, Result, ScalarValue}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; @@ -80,9 +80,6 @@ impl ApproxPercentile { } } - fn output_element_type(&self) -> DataType { - self.input_type.clone() - } } impl AggregateUDFImpl for ApproxPercentile { @@ -98,11 +95,11 @@ impl AggregateUDFImpl for ApproxPercentile { if self.return_array { Ok(DataType::List(Arc::new(Field::new( "item", - self.output_element_type(), + self.input_type.clone(), false, )))) } else { - Ok(self.output_element_type()) + Ok(self.input_type.clone()) } } @@ -168,8 +165,15 @@ impl ApproxPercentileAccumulator { impl Accumulator for ApproxPercentileAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = downcast_value!(&values[0], Float64Array); - for v in arr.iter().flatten() { - self.summary.insert(v); + if arr.null_count() == 0 { + // Fast path: no validity checks needed, iterate the raw values. 
+ for &v in arr.values() { + self.summary.insert(v); + } + } else { + for v in arr.iter().flatten() { + self.summary.insert(v); + } } Ok(()) } @@ -185,7 +189,12 @@ impl Accumulator for ApproxPercentileAccumulator { QuantileSummaries::DEFAULT_COMPRESS_THRESHOLD, digests.value(i), ); - self.summary = self.summary.merge(&peer); + if self.summary.count() == 0 { + // Empty self: merge would just clone the peer, so move it in. + self.summary = peer; + } else { + self.summary = self.summary.merge(&peer); + } } Ok(()) } @@ -197,17 +206,16 @@ impl Accumulator for ApproxPercentileAccumulator { fn evaluate(&mut self) -> Result<ScalarValue> { self.summary.compress(); - let element_type = self.input_type.clone(); match self.summary.query(&self.percentiles) { None => { if self.return_array { // Empty digest still yields null overall in Spark. Ok(ScalarValue::List(Arc::new(ListArray::new_null( - Arc::new(Field::new("item", element_type, false)), + Arc::new(Field::new("item", self.input_type.clone(), false)), 1, )))) } else { - Ok(ScalarValue::try_from(&element_type)?) + Ok(ScalarValue::try_from(&self.input_type)?)
} } Some(results) => { @@ -215,10 +223,9 @@ impl Accumulator for ApproxPercentileAccumulator { results.into_iter().map(|d| self.cast_back(d)).collect(); if self.return_array { let values = ScalarValue::iter_to_array(scalars)?; - let field = Arc::new(Field::new("item", element_type, false)); - let offsets = OffsetBuffer::from_lengths([values.len()]); - let list = ListArray::new(field, offsets, values, None); - Ok(ScalarValue::List(Arc::new(list))) + Ok(SingleRowListArrayBuilder::new(values) + .with_nullable(false) + .build_list_scalar()) } else { Ok(scalars.into_iter().next().unwrap()) } diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index dc609c6f7d..18a821bf4a 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -32,7 +32,6 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; -pub use quantile_summaries::{QuantileSummaries, Stats}; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/quantile_summaries.rs b/native/spark-expr/src/agg_funcs/quantile_summaries.rs index 1b835ca121..b43faebb03 100644 --- a/native/spark-expr/src/agg_funcs/quantile_summaries.rs +++ b/native/spark-expr/src/agg_funcs/quantile_summaries.rs @@ -129,6 +129,13 @@ impl QuantileSummaries { } pub fn compress(&mut self) { + // Already compressed and the head buffer is empty (insert clears the + // flag whenever it stages a value), so there is nothing to do. This + // mirrors Spark's `PercentileDigest.isCompressed` guard, which also + // compresses at most once. 
+ if self.compressed { + return; + } self.with_head_buffer_inserted(); let merge_threshold = 2.0 * self.relative_error * self.count as f64; self.sampled = Self::compress_immut(&self.sampled, merge_threshold); @@ -177,7 +184,8 @@ impl QuantileSummaries { (2.0 * other.relative_error * other.count as f64).floor() as i64; let additional_other_delta = (2.0 * self.relative_error * self.count as f64).floor() as i64; - let mut merged_sampled: Vec<Stats> = Vec::new(); + let mut merged_sampled: Vec<Stats> = + Vec::with_capacity(self.sampled.len() + other.sampled.len()); let mut self_idx = 0usize; let mut other_idx = 0usize; while self_idx < self.sampled.len() && other_idx < other.sampled.len() { From b47fbabc0757ed60c13255a59027474ac68dc291 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 15:39:26 -0600 Subject: [PATCH 08/12] test: add approx_percentile microbenchmarks --- .../CometAggregateExpressionBenchmark.scala | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index a9ee46802a..5fb23c34ff 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -125,6 +125,34 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "percentile_double_high_card", "SELECT percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + // Approximate percentile (Greenwald-Khanna). All numeric input types and the + // scalar, array, and explicit-accuracy forms run natively.
+ private val approxPercentileAggregates = List( + AggExprConfig( + "approx_percentile_int_median", + "SELECT approx_percentile(c_int, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_long_median", + "SELECT approx_percentile(c_long, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_median", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_p90", + "SELECT approx_percentile(c_double, 0.9) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_array", + "SELECT approx_percentile(c_double, array(0.25, 0.5, 0.75)) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_accuracy", + "SELECT approx_percentile(c_double, 0.5, 100) FROM parquetV1Table GROUP BY grp"), + AggExprConfig( + "approx_percentile_double_global", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table"), + AggExprConfig( + "approx_percentile_double_high_card", + "SELECT approx_percentile(c_double, 0.5) FROM parquetV1Table GROUP BY high_card_grp")) + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 1024 * 1024 @@ -148,7 +176,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { val allAggregates = basicAggregates ++ statisticalAggregates ++ bitwiseAggregates ++ multiKeyAggregates ++ multiAggregates ++ decimalAggregates ++ - highCardinalityAggregates ++ percentileAggregates + highCardinalityAggregates ++ percentileAggregates ++ approxPercentileAggregates allAggregates.foreach { config => runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs) From 92ebc5d8af3963b2c1d1317dc228042ebd3110ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 16:24:22 -0600 Subject: [PATCH 09/12] test: enable Comet shuffle manager in benchmark base ObjectHashAggregate-based aggregates (e.g. 
approx_percentile) only run natively when Comet shuffle is enabled, which requires the Comet shuffle manager. Set it on the SparkConf at startup so all benchmarks using this base measure native execution instead of falling back to Spark. --- .../apache/spark/sql/benchmark/CometBenchmarkBase.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala index 71ff2000b3..42dba5ba30 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala @@ -51,6 +51,15 @@ trait CometBenchmarkBase .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") + // Use Comet's shuffle manager so operators that require Comet shuffle can + // run natively, notably aggregates planned as ObjectHashAggregate such as + // percentile and approx_percentile. `spark.shuffle.manager` is static and + // must be set before the context starts. CometShuffleManager falls back to + // Spark's shuffle when Comet is disabled, so the Spark baseline cases are + // unaffected. 
+ .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") val sparkSession = SparkSession.builder .config(conf) From d1ed88f72ddfda75d67b3876dae70a8fa932c110 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 16:25:25 -0600 Subject: [PATCH 10/12] style: apply cargo fmt to approx_percentile --- native/spark-expr/src/agg_funcs/approx_percentile.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/native/spark-expr/src/agg_funcs/approx_percentile.rs b/native/spark-expr/src/agg_funcs/approx_percentile.rs index 4b2d1515c8..4e3a4822a5 100644 --- a/native/spark-expr/src/agg_funcs/approx_percentile.rs +++ b/native/spark-expr/src/agg_funcs/approx_percentile.rs @@ -79,7 +79,6 @@ impl ApproxPercentile { return_array, } } - } impl AggregateUDFImpl for ApproxPercentile { From c61fcd324b614b807e58dbc76463124b7b76935c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 08:36:55 -0600 Subject: [PATCH 11/12] fix: prevent incompatible-buffer aggregate from splitting across Comet and Spark in distinct rewrite A distinct-aggregate rewrite separates a non-distinct aggregate's Partial from its Final by intermediate PartialMerge stages. The existing mixed- execution guards (#1389) only covered a direct Partial -> Final pair, so an aggregate with an incompatible intermediate buffer (e.g. percentile_approx) could run part of the chain in Comet and part in Spark, handing a Comet- encoded buffer to a Spark aggregate (or vice versa) and crashing. - CometExecRule: walk findPartialAggInPlan through intermediate PartialMerge stages to tag the bottom Partial so the whole chain falls back to Spark. - CometBaseAggregate.doConvert: generalize the sparkFinalMode guard so a Comet PartialMerge that merges buffers (not just a Final) requires a Comet producer below it, covering the Spark-Partial -> Comet-Merge direction. Fixes the ObjectHashAggregateSuite "[typed, with distinct]" crash. See #4813. 
--- .../apache/comet/rules/CometExecRule.scala | 67 ++++++++++--------- .../apache/spark/sql/comet/operators.scala | 18 +++-- .../comet/exec/CometAggregateSuite.scala | 25 +++++++ .../comet/rules/CometExecRuleSuite.scala | 52 ++++++++++++++ 4 files changed, 125 insertions(+), 37 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 52f39c59ad..dc8268950e 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -837,19 +837,21 @@ case class CometExecRule(session: SparkSession) private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = { plan.foreach { case agg: BaseAggregateExec => - // Only consider single-mode Final aggregates. Multi-mode Finals come from Spark's - // distinct-aggregate rewrite, where the Comet partial (if any) feeds into a Spark - // PartialMerge rather than directly into a Final, which is a different code path - // than the Comet-Partial → Spark-Final crash scenario from issue #1389. 
+ // A single-mode Final that consumes an incompatible intermediate buffer and cannot itself + // be converted to Comet must not sit above a Comet aggregate that produces that buffer, + // otherwise Spark's Final would try to read a Comet-encoded buffer and crash. Tagging the + // bottom Partial so it falls back is enough: once it is Spark, the missingCometProducer + // guard in CometBaseAggregate.doConvert cascades the fallback up through any intermediate + // PartialMerge stages of a distinct-aggregate rewrite. See issues #1389 and #4813. val modes = agg.aggregateExpressions.map(_.mode).distinct if (modes == Seq(Final) && !QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions) && !canAggregateBeConverted(agg, Final)) { findPartialAggInPlan(agg.child).foreach { partial => - // Only tag if the Partial would otherwise have been converted. If the Partial - // itself cannot be converted (e.g. the aggregate function is incompatible for the - // input type), there is no buffer-format mismatch to guard against, and tagging - // would mask the natural, more specific fallback reason. + // Only tag if the Partial would otherwise have been converted. If the Partial itself + // cannot be converted (e.g. an incompatible input type or a map-typed grouping key), + // there is no buffer-format mismatch to guard against, and tagging would mask the + // natural, more specific fallback reason. if (canAggregateBeConverted(partial, Partial)) { partial.setTagValue( CometExecRule.COMET_UNSAFE_PARTIAL, @@ -862,6 +864,32 @@ case class CometExecRule(session: SparkSession) } } + /** + * Look for the bottom Partial-mode aggregate that feeds into the given plan (the child of a + * Final). 
Walks through exchanges and AQE stages, and continues down through intermediate + * aggregate stages whose modes are all Partial / PartialMerge - these are the PartialMerge (and + * mixed Partial/PartialMerge) stages that Spark's distinct-aggregate rewrite inserts between + * the Partial and the Final. Stops at anything else. Requires `aggregateExpressions.nonEmpty` + * so that group-by-only dedup stages are traversed rather than mistaken for the partial. + */ + private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match { + case agg: BaseAggregateExec + if agg.aggregateExpressions.nonEmpty && + agg.aggregateExpressions.forall(e => e.mode == Partial) => + Some(agg) + case agg: BaseAggregateExec + if agg.aggregateExpressions.forall(e => e.mode == Partial || e.mode == PartialMerge) => + // Intermediate PartialMerge / mixed stage of a distinct-aggregate rewrite, or a group-by + // only dedup stage; keep walking down towards the bottom Partial. + findPartialAggInPlan(agg.child) + case a: AQEShuffleReadExec => findPartialAggInPlan(a.child) + case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan) + case e: ShuffleExchangeExec => findPartialAggInPlan(e.child) + case other => + logDebug(s"findPartialAggInPlan: stopping at ${other.nodeName}; not a known passthrough") + None + } + /** * Conservative check for whether an aggregate could be converted to Comet. Checks operator * enablement, grouping expressions, aggregate expressions, and result expressions. @@ -933,25 +961,4 @@ case class CometExecRule(session: SparkSession) } } - /** - * Look for a Partial-mode aggregate that feeds directly into the given plan (the child of a - * Final). Walks through exchanges and AQE stages only, stopping at anything else including - * other aggregate stages. This avoids tagging unrelated Partials found deeper in the plan (e.g. 
- * the non-distinct Partial in a distinct-aggregate rewrite, which is separated from the Final - * by intermediate PartialMerge stages). Requires `aggregateExpressions.nonEmpty` so that - * group-by-only dedup stages are not mistaken for the partial we want to tag. - */ - private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = plan match { - case agg: BaseAggregateExec - if agg.aggregateExpressions.nonEmpty && - agg.aggregateExpressions.forall(e => e.mode == Partial) => - Some(agg) - case a: AQEShuffleReadExec => findPartialAggInPlan(a.child) - case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan) - case e: ShuffleExchangeExec => findPartialAggInPlan(e.child) - case other => - logDebug(s"findPartialAggInPlan: stopping at ${other.nodeName}; not a known passthrough") - None - } - } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..ec4311b746 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1526,10 +1526,14 @@ trait CometBaseAggregate { // In distinct aggregates there can be a combination of modes. // We support {Partial, PartialMerge} mix; other combinations are rejected. val multiMode = modes.size > 1 && modeSet != Set(Partial, PartialMerge) - // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation, unless all aggregates have compatible - // intermediate buffer formats (safe for mixed Spark/Comet execution). - val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty + // An aggregate that consumes intermediate buffers (Final, or the PartialMerge stages of a + // distinct-aggregate rewrite) must have a Comet aggregate producing those buffers below it. 
+ // Otherwise Comet would try to read a Spark partial's buffer, which is only safe when every + // aggregate has a buffer format compatible between Spark and Comet. This guards the + // Spark-Partial to Comet-Merge direction; the Comet-Partial to Spark-Final direction is + // guarded by the COMET_UNSAFE_PARTIAL tagging pass in CometExecRule. See issues #1389, #4813. + val consumesBuffers = modes.contains(Final) || modes.contains(PartialMerge) + val missingCometProducer = consumesBuffers && findCometPartialAgg(aggregate.child).isEmpty if (multiMode) { withFallbackReason( @@ -1538,12 +1542,12 @@ trait CometBaseAggregate { return None } - if (sparkFinalMode && + if (missingCometProducer && !QueryPlanSerde.allAggsSupportMixedExecution(aggregate.aggregateExpressions)) { withFallbackReason( aggregate, - "Spark Final aggregate without Comet Partial requires compatible " + - "intermediate buffer formats") + "Comet aggregate that merges intermediate buffers requires a Comet child aggregate " + + "when the intermediate buffer formats are incompatible with Spark") return None } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..1ff73757e9 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -108,6 +108,31 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // Regression test for the approx_percentile distinct-aggregate crash: an aggregate with an + // incompatible intermediate buffer (percentile_approx) combined with a distinct aggregate is + // rewritten by Spark into a multi-stage plan. If part of that chain runs in Comet and part in + // Spark, the incompatible buffer crosses the boundary and crashes. 
Here we force the split with + // the partial/final debug configs and assert results still match Spark (the whole chain must + // fall back to Spark). See https://github.com/apache/datafusion-comet/issues/4813. + test("approx_percentile with distinct aggregate does not split across Comet and Spark") { + val data = (0 until 500).map(i => (i % 10, i % 100, i % 37)) + withParquetTable(data, "tbl", false) { + for (disablePartial <- Seq(false, true); + disableFinal <- Seq(false, true); + groupBy <- Seq("", " GROUP BY _1")) { + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> (!disablePartial).toString, + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> (!disableFinal).toString) { + checkSparkAnswer( + s"SELECT percentile_approx(_2, 0.5), count(DISTINCT _3) FROM tbl$groupBy") + } + } + } + } + test("stddev_pop should return NaN for some cases") { withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { Seq(true, false).foreach { nullOnDivideByZero => diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 7fa06a26cc..3573567620 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -308,6 +308,58 @@ class CometExecRuleSuite extends CometTestBase { } } + // Regression tests for https://github.com/apache/datafusion-comet/issues/4813. An aggregate with + // an incompatible intermediate buffer (percentile_approx) combined with a distinct aggregate is + // rewritten by Spark into a multi-stage plan whose partial is separated from the final by + // intermediate PartialMerge stages. 
If part of that chain runs in Comet and part in Spark the + // incompatible buffer crosses the boundary and crashes, so the whole chain must fall back. + test( + "CometExecRule should not split distinct aggregate with incompatible buffer (Spark final)") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan( + spark, + "SELECT percentile_approx(id, 0.5), COUNT(DISTINCT name) FROM test_data") + + // The distinct rewrite produces a multi-stage ObjectHashAggregate chain. + assert(countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) > 1) + + withSQLConf( + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // percentile_approx has an incompatible buffer, so with the final forced to Spark the + // entire partial/merge chain must also stay in Spark. + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } + + test( + "CometExecRule should not split distinct aggregate with incompatible buffer (Spark part)") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan( + spark, + "SELECT percentile_approx(id, 0.5), COUNT(DISTINCT name) FROM test_data") + + assert(countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) > 1) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // With the partial/merge stages forced to Spark, no Comet aggregate may consume their + // incompatible buffers either. 
+ assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } + test("CometExecRule should not convert hash aggregate when grouping key contains map type") { // Spark 3.4/3.5 reject `array>` as a grouping key in the analyzer (not orderable), // so the plan never reaches CometExecRule on those versions. The guard we're exercising From edb767934f7718ecbad3ac4dba2f2b7908b69e45 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 11:45:45 -0600 Subject: [PATCH 12/12] test: regenerate TPC-DS golden files for updated aggregate fallback reason --- .../approved-plans-v1_4-spark3_5/q70/extended.txt | 2 +- .../approved-plans-v1_4/q10/extended.txt | 2 +- .../approved-plans-v1_4/q35/extended.txt | 2 +- .../approved-plans-v1_4/q45/extended.txt | 2 +- .../approved-plans-v2_7-spark3_5/q70a/extended.txt | 10 +++++----- .../approved-plans-v2_7/q35/extended.txt | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt index 14ad77dad4..bbea84df75 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q70/extended.txt @@ -4,7 +4,7 @@ CometNativeColumnarToRow +- CometWindowExec +- CometSort +- CometColumnarExchange - +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Expand diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt 
index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q10/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q35/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt index f95c69368f..489056587b 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q45/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet 
child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt index 10ea854de0..9df0eb8a01 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q70a/extended.txt @@ -8,7 +8,7 @@ CometNativeColumnarToRow +- CometColumnarExchange +- HashAggregate +- Union - :- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + :- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate : +- Project @@ -58,10 +58,10 @@ CometNativeColumnarToRow : +- CometProject : +- CometFilter : +- CometNativeScan parquet spark_catalog.default.date_dim - :- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + :- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate - : +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + : +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] : +- Exchange : +- HashAggregate : +- Project @@ -111,10 +111,10 @@ CometNativeColumnarToRow : +- CometProject : +- CometFilter : +- CometNativeScan parquet spark_catalog.default.date_dim - +- HashAggregate [COMET: Spark Final aggregate 
without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate - +- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] + +- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt index 07af300183..f5c83651a7 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q35/extended.txt @@ -1,5 +1,5 @@ TakeOrderedAndProject -+- HashAggregate [COMET: Spark Final aggregate without Comet Partial requires compatible intermediate buffer formats] ++- HashAggregate [COMET: Comet aggregate that merges intermediate buffers requires a Comet child aggregate when the intermediate buffer formats are incompatible with Spark] +- Exchange +- HashAggregate +- Project