From 795f6f9ae84875562565455e65f14b72b6f038cb Mon Sep 17 00:00:00 2001 From: Manu Zhang Date: Wed, 1 Jul 2026 23:52:28 +0800 Subject: [PATCH 1/2] fix: match Spark percentile interpolation precision Use a Comet-native percentile aggregate for Spark percentile so interpolation keeps Spark's full-precision weight and can run compatible by default. Add native and SQL regressions for deeply interpolated percentile values. Co-authored-by: Codex --- native/core/src/execution/planner.rs | 30 +- native/spark-expr/src/agg_funcs/mod.rs | 2 + native/spark-expr/src/agg_funcs/percentile.rs | 342 ++++++++++++++++++ .../org/apache/comet/serde/aggregates.scala | 22 +- .../apache/spark/sql/comet/operators.scala | 84 +++-- .../expressions/aggregate/percentile.sql | 19 +- .../aggregate/percentile_within_group.sql | 15 +- .../CometAggregateExpressionBenchmark.scala | 2 +- 8 files changed, 440 insertions(+), 76 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/percentile.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 049c8c13e9..cb7fa2dc1e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -41,7 +41,6 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::min_max::min_udaf; -use datafusion::functions_aggregate::percentile_cont::percentile_cont_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; @@ -74,7 +73,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + SparkBloomFilterVersion, SparkPercentile, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -2606,12 +2605,14 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let percentile = self.create_expr(expr.percentage.as_ref().unwrap(), Arc::clone(&schema))?; - // DataFusion's percentile_cont uses the same `index = p * (n - 1)` linear - // interpolation as Spark's exact Percentile, so results match for the single - // percentage case wired here. - AggregateExprBuilder::new(percentile_cont_udaf(), vec![child, percentile]) + // Spark's exact Percentile uses full-precision linear interpolation. Comet uses + // its own UDAF rather than DataFusion's percentile_cont because DataFusion + // quantizes the interpolation weight. + let percentile_value = percentile_value(expr.percentage.as_ref().unwrap())?; + let func = AggregateUDF::new_from_impl(SparkPercentile::try_new(percentile_value)?); + AggregateExprBuilder::new(func.into(), vec![child, percentile]) .schema(schema) - .alias("percentile_cont") + .alias("percentile") .with_ignore_nulls(false) .with_distinct(false) .build() @@ -3263,6 +3264,21 @@ impl PhysicalPlanner { } } +fn percentile_value(expr: &spark_expression::Expr) -> Result { + match &expr.expr_struct { + Some(ExprStruct::Literal(literal)) if !literal.is_null => match &literal.value { + Some(Value::DoubleVal(value)) => Ok(*value), + Some(Value::FloatVal(value)) => Ok(*value as f64), + _ => Err(GeneralError( + "Percentile value must be a floating-point literal".to_string(), + )), + }, + _ => Err(GeneralError( + "Percentile value must be a non-null literal".to_string(), + )), + } +} + /// Collects the indices of the columns in the input schema that are used in the expression /// and returns them as a pair of vectors, one for the left side and one for the right side. fn expr_to_columns( diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..70100c1a31 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod percentile; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use percentile::SparkPercentile; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/percentile.rs b/native/spark-expr/src/agg_funcs/percentile.rs new file mode 100644 index 0000000000..3fddfdcb68 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/percentile.rs @@ -0,0 +1,342 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, Float64Array, Float64Builder, ListArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, FieldRef, Float64Type}; +use datafusion::common::{internal_err, plan_err, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, +}; +use datafusion::physical_expr::expressions::format_state_name; +use std::cmp::Ordering; +use std::mem::{size_of, size_of_val}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SparkPercentile { + signature: Signature, + percentile_bits: u64, +} + +impl SparkPercentile { + pub fn try_new(percentile: f64) -> Result { + validate_percentile(percentile)?; + Ok(Self { + signature: Signature::user_defined(Immutable), + percentile_bits: percentile.to_bits(), + }) + } + + fn percentile(&self) -> f64 { + f64::from_bits(self.percentile_bits) + } +} + +impl AggregateUDFImpl for SparkPercentile { + fn accumulator(&self, _args: AccumulatorArgs) -> Result> { + Ok(Box::new(SparkPercentileAccumulator::new(self.percentile()))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let input_type = args.input_fields[0].data_type(); + if input_type != &DataType::Float64 { + return internal_err!("SparkPercentile expects Float64 input, got {input_type}"); + } + Ok(vec![Arc::new(Field::new( + format_state_name(args.name, self.name()), + DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))), + true, + ))]) + } + + fn name(&self) -> &str { + "percentile" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.first() != Some(&DataType::Float64) { + return internal_err!( + "SparkPercentile return type expects Float64 input, got {:?}", + arg_types.first() + ); + } + Ok(DataType::Float64) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct && args.expr_fields[0].data_type() == &DataType::Float64 + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SparkPercentileGroupsAccumulator::new( + self.percentile(), + ))) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn is_nullable(&self) -> bool { + true + } +} + +fn validate_percentile(percentile: f64) -> Result<()> { + if !(0.0..=1.0).contains(&percentile) { + return plan_err!( + "Percentile value must be between 0.0 and 1.0 inclusive, got {percentile}." + ); + } + Ok(()) +} + +#[derive(Debug)] +struct SparkPercentileAccumulator { + values: Vec, + percentile: f64, +} + +impl SparkPercentileAccumulator { + fn new(percentile: f64) -> Self { + Self { + values: vec![], + percentile, + } + } +} + +impl Accumulator for SparkPercentileAccumulator { + fn state(&mut self) -> Result> { + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, self.values.len() as i32])); + let values = Float64Array::new(ScalarBuffer::from(self.values.clone()), None); + let list = ListArray::new( + Arc::new(Field::new_list_field(DataType::Float64, true)), + offsets, + Arc::new(values), + None, + ); + Ok(vec![ScalarValue::List(Arc::new(list))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.values.reserve(values.len() - values.null_count()); + self.values.extend(values.iter().flatten()); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states = states[0].as_list::(); + for state in states.iter().flatten() { + self.update_batch(&[state])?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Float64(spark_percentile( + self.values.as_mut_slice(), + self.percentile, + ))) + } + + fn size(&self) -> usize { + size_of_val(self) + self.values.capacity() * size_of::() + } +} + +#[derive(Debug)] +struct SparkPercentileGroupsAccumulator { + group_values: Vec>, + percentile: f64, +} + +impl SparkPercentileGroupsAccumulator { + fn new(percentile: f64) -> Self { + Self { + group_values: vec![], + percentile, + } + } +} + +impl GroupsAccumulator for SparkPercentileGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let values = values[0].as_primitive::(); + self.group_values.resize(total_num_groups, Vec::new()); + + for (row, &group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.is_valid(row) || !filter.value(row) { + continue; + } + } + if values.is_null(row) { + continue; + } + self.group_values[group_index].push(values.value(row)); + } + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let input_group_values = values[0].as_list::(); + self.group_values.resize(total_num_groups, Vec::new()); + + for (&group_index, values) in group_indices.iter().zip(input_group_values.iter()) { + if let Some(values) = values { + let values = values.as_primitive::(); + self.group_values[group_index].extend(values.iter().flatten()); + } + } + + Ok(()) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let emit_group_values = emit_to.take_needed(&mut self.group_values); + + let mut offsets = Vec::with_capacity(emit_group_values.len() + 1); + offsets.push(0); + let mut len = 0_i32; + for values in &emit_group_values { + len += values.len() as i32; + offsets.push(len); + } + + let values = emit_group_values.into_iter().flatten().collect::>(); + let values = Float64Array::new(ScalarBuffer::from(values), None); + let list = ListArray::new( + Arc::new(Field::new_list_field(DataType::Float64, true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + ); + Ok(vec![Arc::new(list)]) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let mut emit_group_values = emit_to.take_needed(&mut self.group_values); + let mut builder = Float64Builder::with_capacity(emit_group_values.len()); + for values in &mut emit_group_values { + builder.append_option(spark_percentile(values.as_mut_slice(), self.percentile)); + } + Ok(Arc::new(builder.finish())) + } + + fn size(&self) -> usize { + self.group_values + .iter() + .map(|values| values.capacity() * size_of::()) + .sum::() + + self.group_values.capacity() * size_of::>() + } +} + +fn spark_percentile(values: &mut [f64], percentile: f64) -> Option { + let len = values.len(); + if len == 0 { + return None; + } + if len == 1 { + return Some(values[0]); + } + + let position = (len - 1) as f64 * percentile; + let lower = position.floor() as usize; + let higher = position.ceil() as usize; + + let (_, lower_value, _) = values.select_nth_unstable_by(lower, spark_double_cmp); + let lower_value = *lower_value; + if lower == higher { + return Some(lower_value); + } + + let (_, higher_value, _) = values.select_nth_unstable_by(higher, spark_double_cmp); + let higher_value = *higher_value; + if spark_double_cmp(&lower_value, &higher_value) == Ordering::Equal { + return Some(lower_value); + } + + Some((higher as f64 - position) * lower_value + (position - lower as f64) * higher_value) +} + +fn spark_double_cmp(x: &f64, y: &f64) -> Ordering { + if x == y || (x.is_nan() && y.is_nan()) { + Ordering::Equal + } else if x.is_nan() { + Ordering::Greater + } else if y.is_nan() { + Ordering::Less + } else { + x.partial_cmp(y) + .expect("non-NaN values should be comparable") + } +} + +#[cfg(test)] +mod tests { + use super::{spark_double_cmp, spark_percentile}; + use std::cmp::Ordering; + + #[test] + fn interpolates_with_full_spark_precision() { + let mut values = vec![0.0, 10_000_000.0]; + assert_eq!( + spark_percentile(&mut values, 0.123456789), + Some(1_234_567.89) + ); + } + + #[test] + fn matches_spark_double_ordering_for_nan_and_zero() { + assert_eq!(spark_double_cmp(&f64::NAN, &1.0), Ordering::Greater); + assert_eq!(spark_double_cmp(&1.0, &f64::NAN), Ordering::Less); + assert_eq!(spark_double_cmp(&f64::NAN, &f64::NAN), Ordering::Equal); + assert_eq!(spark_double_cmp(&-0.0, &0.0), Ordering::Equal); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..24e02aa2b3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -598,18 +598,11 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { private val nonLiteralPercentageReason = "The percentage argument must be a literal." private val frequencyReason = "A frequency argument is not supported." // `reverse` is set when `percentile_cont`/`percentile_disc` is used with - // `WITHIN GROUP (ORDER BY ... DESC)` on Spark 4.0+. The native `percentile_cont` always + // `WITHIN GROUP (ORDER BY ... DESC)` on Spark 4.0+. The native percentile UDAF always // interpolates in ascending order, so the descending form would return a wrong answer. private val descendingReason = "Descending order in `WITHIN GROUP (ORDER BY ... DESC)` is not supported." private val inputTypeReason = "Only numeric input types are supported." - // DataFusion's percentile_cont quantizes the linear interpolation weight to 6 decimal places, - // so an interpolated percentile may differ from Spark by up to `(upper - lower) * 1e-6`. - // See #4719. - private val precisionReason = - "Interpolated values may differ from Spark by up to `(upper - lower) * 1e-6` because" + - " DataFusion quantizes the interpolation weight to 6 decimal places (#4719)." - override def getUnsupportedReasons(): Seq[String] = Seq( arrayOfPercentagesReason, nonLiteralPercentageReason, @@ -617,15 +610,10 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { descendingReason, inputTypeReason) - override def getIncompatibleReasons(): Seq[String] = Seq(precisionReason) - override def getSupportLevel(expr: Percentile): SupportLevel = { // Only the single-percentage, default-frequency, numeric-input, ascending form is wired - // today. It maps to DataFusion's percentile_cont, which uses the same `index = p * (n - 1)` - // linear interpolation as Spark's exact Percentile, but quantizes the interpolation weight to - // 6 decimal places (see precisionReason / #4719), so the supported form is Incompatible rather - // than Compatible. Array-of-percentages, a non-default frequency argument, descending order, - // and interval inputs fall back to Spark. + // today. It maps to Comet's Spark-compatible percentile UDAF. Array-of-percentages, a + // non-default frequency argument, descending order, and interval inputs fall back to Spark. if (expr.percentageExpression.dataType != DoubleType) { return Unsupported(Some(arrayOfPercentagesReason)) } @@ -640,7 +628,7 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { return Unsupported(Some(descendingReason)) } expr.child.dataType match { - case _: NumericType => Incompatible(Some(precisionReason)) + case _: NumericType => Compatible() case _ => Unsupported(Some(inputTypeReason)) } } @@ -652,7 +640,7 @@ object CometPercentile extends CometAggregateExpressionSerde[Percentile] { binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { // Spark computes the percentile over the values as doubles; cast the child up front so the - // native percentile_cont returns Float64 / DoubleType to match Spark. + // native percentile UDAF returns Float64 / DoubleType to match Spark. val childExpr = exprToProto(Cast(percentile.child, DoubleType), inputs, binding) val percentageExpr = exprToProto(Literal(percentile.percentageExpression.eval(), DoubleType), inputs, binding) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..a73441dd45 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1775,6 +1775,46 @@ trait CometBaseAggregate { .build()) } + /** + * For partial-like aggregates containing TypedImperativeAggregate functions (like CollectSet + * and Percentile), the Spark-side output declares buffer columns as BinaryType because Spark + * serializes state to binary. Native Comet emits the actual state type, so fix the exposed + * output schema before shuffle/exchange code consumes it. + * + * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a + * case branch here mapping it to the native state type. + */ + protected def adjustOutputForNativeState(op: BaseAggregateExec): Seq[Attribute] = { + val modeSet = op.aggregateExpressions.map(_.mode).toSet + if (modeSet.isEmpty || !modeSet.subsetOf(Set(Partial, PartialMerge))) { + return op.output + } + + val numGrouping = op.groupingExpressions.length + val output = op.output.toArray + + var bufferIdx = numGrouping + for (aggExpr <- op.aggregateExpressions) { + val aggFunc = aggExpr.aggregateFunction + val bufferAttrs = aggFunc.aggBufferAttributes + aggFunc match { + case cs: CollectSet => + val elementType = cs.children.head.dataType + val nativeStateType = ArrayType(elementType, containsNull = true) + output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) + case _: Percentile => + // Comet's native percentile UDAF keeps all values in a List partial state. + // Comet casts the child to double, so the native state is ArrayType(DoubleType). + val nativeStateType = ArrayType(DoubleType, containsNull = true) + output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) + case _ => + } + bufferIdx += bufferAttrs.length + } + + output.toSeq + } + /** * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with * partial or partial-merge mode, it will return None. @@ -1832,7 +1872,7 @@ object CometHashAggregateExec CometHashAggregateExec( nativeOp, op, - op.output, + adjustOutputForNativeState(op), op.groupingExpressions, op.aggregateExpressions, op.resultExpressions, @@ -1890,48 +1930,6 @@ object CometObjectHashAggregateExec op.child, SerializedPlan(None)) } - - /** - * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet), - * the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to - * binary). However, the native Comet aggregate produces the actual state type (e.g., - * ArrayType(elementType) for CollectSet). This method corrects the output schema to match the - * native state types so the shuffle exchange schema is consistent with the actual data. - * - * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a - * case branch here mapping it to the native state type. - */ - private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = { - // This adjustment only applies to pure-Partial aggregates (checked below). - val modes = op.aggregateExpressions.map(_.mode).distinct - if (modes != Seq(Partial)) { - return op.output - } - - val numGrouping = op.groupingExpressions.length - val output = op.output.toArray - - var bufferIdx = numGrouping - for (aggExpr <- op.aggregateExpressions) { - val aggFunc = aggExpr.aggregateFunction - val bufferAttrs = aggFunc.aggBufferAttributes - aggFunc match { - case cs: CollectSet => - val elementType = cs.children.head.dataType - val nativeStateType = ArrayType(elementType, containsNull = true) - output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) - case _: Percentile => - // DataFusion's percentile_cont keeps all values in a List partial state. - // Comet casts the child to double, so the native state is ArrayType(DoubleType). - val nativeStateType = ArrayType(DoubleType, containsNull = true) - output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) - case _ => - } - bufferIdx += bufferAttrs.length - } - - output.toSeq - } } case class CometHashAggregateExec( diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql index ab3d1706a5..6ecf592d47 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql @@ -15,10 +15,7 @@ -- specific language governing permissions and limitations -- under the License. --- Native exact percentile via DataFusion percentile_cont (same (n-1)*p interpolation as Spark). --- Marked Incompatible because DataFusion quantizes the interpolation weight to 6 decimal places --- (#4719); allow it here so the native path is exercised. --- Config: spark.comet.expression.Percentile.allowIncompatible=true +-- Native exact percentile via Comet's Spark-compatible percentile UDAF. statement CREATE TABLE test_percentile(g int, v double, i int) USING parquet @@ -28,6 +25,12 @@ INSERT INTO test_percentile VALUES (1, 1.0, 10), (1, 2.0, 20), (1, 3.0, 30), (1, 4.0, 40), (2, 10.0, 5), (2, 20.0, 15), (2, NULL, 25) +statement +CREATE TABLE test_percentile_precision(v double) USING parquet + +statement +INSERT INTO test_percentile_precision VALUES (0.0), (10000000.0) + -- global percentile, interpolated and exact-rank cases query SELECT percentile(v, 0.5) FROM test_percentile @@ -35,10 +38,18 @@ SELECT percentile(v, 0.5) FROM test_percentile query SELECT percentile(v, 0.0), percentile(v, 1.0), percentile(v, 0.25), percentile(v, 0.9) FROM test_percentile +-- deeply interpolated percentile that would differ if the interpolation weight were quantized +query +SELECT percentile(v, 0.123456789) FROM test_percentile_precision + -- grouped query SELECT g, percentile(v, 0.5) FROM test_percentile GROUP BY g ORDER BY g +-- mixed distinct aggregate plans use PartialMerge; percentile must preserve its percentage +query +SELECT g, count(DISTINCT i), percentile(v, 0.5) FROM test_percentile GROUP BY g ORDER BY g + -- integer input (cast to double) query SELECT percentile(i, 0.5) FROM test_percentile diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_within_group.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_within_group.sql index 4def3feef7..e81529b454 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_within_group.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile_within_group.sql @@ -18,10 +18,7 @@ -- percentile_cont(p) WITHIN GROUP (ORDER BY col) was added in Spark 4.0. It is a -- RuntimeReplaceable that rewrites to Percentile(col, p, reverse), so the ascending form runs -- natively through Comet, while the descending (DESC) form falls back to Spark because the --- native percentile_cont always interpolates in ascending order. --- Percentile is marked Incompatible because DataFusion quantizes the interpolation weight to 6 --- decimal places (#4719); allow it here so the ascending native path is exercised. --- Config: spark.comet.expression.Percentile.allowIncompatible=true +-- native percentile UDAF always interpolates in ascending order. -- MinSparkVersion: 4.0 statement @@ -31,10 +28,20 @@ statement INSERT INTO test_pct_wg VALUES (1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 10.0), (2, 20.0), (2, NULL) +statement +CREATE TABLE test_pct_wg_precision(v double) USING parquet + +statement +INSERT INTO test_pct_wg_precision VALUES (0.0), (10000000.0) + -- ascending WITHIN GROUP runs natively query SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY v) FROM test_pct_wg +-- deeply interpolated percentile that would differ if the interpolation weight were quantized +query +SELECT percentile_cont(0.123456789) WITHIN GROUP (ORDER BY v) FROM test_pct_wg_precision + query SELECT g, percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FROM test_pct_wg GROUP BY g ORDER BY g diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala index a9ee46802a..256ba6c421 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateExpressionBenchmark.scala @@ -104,7 +104,7 @@ object CometAggregateExpressionBenchmark extends CometBenchmarkBase { "SELECT COUNT(DISTINCT c_int) FROM parquetV1Table GROUP BY high_card_grp")) // Exact percentile. Only the single-percentage, default-frequency, numeric-input form runs - // natively (maps to DataFusion's percentile_cont); other forms fall back to Spark. + // natively through Comet's Spark-compatible percentile UDAF; other forms fall back to Spark. private val percentileAggregates = List( AggExprConfig( "percentile_int_median", From 75cd77f6c2a12b3e1913e45c129f4eb14e436a0d Mon Sep 17 00:00:00 2001 From: Manu Zhang Date: Fri, 3 Jul 2026 07:25:23 +0800 Subject: [PATCH 2/2] fix: address percentile review comments Refresh percentile support docs, keep native-state schema adjustment on the object aggregate path, and add SQL coverage for special double ordering. Co-authored-by: Codex --- docs/source/user-guide/latest/expressions.md | 6 +++--- .../scala/org/apache/spark/sql/comet/operators.scala | 2 +- .../sql-tests/expressions/aggregate/percentile.sql | 10 ++++++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..3b650e6d57 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -97,12 +97,12 @@ The tables below list every Spark built-in expression with its current status. | `max` | ✅ | | | `max_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | | `mean` | ✅ | | -| `median` | ✅ | Rewrites to `percentile(col, 0.5)`; falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | +| `median` | ✅ | Rewrites to `percentile(col, 0.5)` and runs natively for supported percentile inputs | | `min` | ✅ | | | `min_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | | `mode` | 🔜 | [#3970](https://github.com/apache/datafusion-comet/issues/3970) | -| `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | -| `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | +| `percentile` | ✅ | Single literal percentage on numeric input runs natively; array of percentages and a frequency argument fall back to Spark | +| `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only runs natively, `DESC` falls back to Spark | | `percentile_disc` | 🔜 | Percentile aggregate | | `regr_avgx` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_avgy` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a73441dd45..f12e978065 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1872,7 +1872,7 @@ object CometHashAggregateExec CometHashAggregateExec( nativeOp, op, - adjustOutputForNativeState(op), + op.output, op.groupingExpressions, op.aggregateExpressions, op.resultExpressions, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql index 6ecf592d47..2a2306bf6f 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/percentile.sql @@ -95,6 +95,16 @@ INSERT INTO test_percentile_neg VALUES (-10.0), (-5.0), (0.0), (5.0), (10.0) query SELECT percentile(v, 0.5), percentile(v, 0.1), percentile(v, 0.9) FROM test_percentile_neg +statement +CREATE TABLE test_percentile_special(v double) USING parquet + +statement +INSERT INTO test_percentile_special VALUES + (double('-Infinity')), (-0.0), (0.0), (1.0), (double('Infinity')), (double('NaN')) + +query +SELECT percentile(v, 0.0), percentile(v, 0.5), percentile(v, 0.8), percentile(v, 1.0) FROM test_percentile_special + -- ============================================================ -- Unsupported forms fall back to Spark cleanly -- ============================================================