From a303884ece12256bbdf86c6768ea1a081422abda Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:13:50 -0600 Subject: [PATCH 1/2] feat: support PivotFirst aggregate for the optimized PIVOT fast path Wire the Spark-internal PivotFirst aggregate through Comet so that PIVOT queries which trigger Spark's two-phase fast path (12+ pivot values with an aggregate whose output type is in PivotFirst.supportsDataType) run natively instead of falling back. Scaffolded with the implement-comet-expression skill. --- .../expression-audits/agg_funcs.md | 8 + docs/source/user-guide/latest/expressions.md | 1 + native/core/src/execution/planner.rs | 26 +- native/proto/src/proto/expr.proto | 16 + native/spark-expr/src/agg_funcs/mod.rs | 2 + .../spark-expr/src/agg_funcs/pivot_first.rs | 347 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/aggregates.scala | 81 +++- .../expressions/aggregate/PivotFirst.sql | 326 ++++++++++++++++ .../comet/exec/CometAggregateSuite.scala | 42 +++ 10 files changed, 847 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/pivot_first.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..c56082f283 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -54,4 +54,12 @@ - Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1. - `CometPercentile` reports `Incompatible` for the otherwise-supported form because DataFusion's `percentile_cont` quantizes the interpolation weight to 6 decimal places (`INTERPOLATION_PRECISION = 1e6`), so a deeply-interpolated value can differ from Spark by up to roughly `(upper - lower) * 1e-6`. 
The native path is opt-in via `spark.comet.expression.Percentile.allowIncompatible=true` ([#4719](https://github.com/apache/datafusion-comet/issues/4719)). +## pivot_first + +- Spark 3.4.3 (audited 2026-07-02): `PivotFirst(pivotColumn, valueColumn, pivotColumnValues)` is an internal `ImperativeAggregate` emitted only by the optimized-pivot fast path in `Analyzer.ResolvePivot`. It buckets `valueColumn` into an array of length `pivotColumnValues.size` indexed by matching `pivotColumn`. Null value columns are ignored, unmatched pivot values are dropped. Value types are gated by `PivotFirst.supportsDataType` (Boolean, Byte, Short, Int, Long, Float, Double, Decimal). +- Spark 3.5.8 (audited 2026-07-02): swaps `StructType.fromAttributes` for `DataTypeUtils.fromAttributes`; no behavior change. +- Spark 4.0.1 (audited 2026-07-02): identical to 3.5.8. +- Spark 4.1.1 (audited 2026-07-02): identical to 4.0.1. (Spark `master` adds a defensive `findPivotIndex` wrapper that returns `-1` for null keys on the non-`AtomicType` `TreeMap` path; not present in any released version Comet builds against, and Comet's HashMap-backed lookup handles `ScalarValue::Null` safely on all types.) +- `CometPivotFirst` (in `spark/src/main/scala/org/apache/comet/serde/aggregates.scala`) forwards the aggregate to the native `SparkPivotFirst` UDAF (`native/spark-expr/src/agg_funcs/pivot_first.rs`) when the value type is in the supported set. State layout matches Spark's `aggBufferAttributes` (one scalar column per pivot slot) so the shuffle schema between Partial and Final stays consistent. `evaluate()` reassembles the slots into a `ListArray` matching `PivotFirst.dataType = ArrayType(valueDataType)`. 
+ [Spark Expression Support]: ../../user-guide/latest/expressions.md diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..78386a8a6f 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -104,6 +104,7 @@ The tables below list every Spark built-in expression with its current status. | `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_disc` | 🔜 | Percentile aggregate | +| `pivot_first` | ✅ | Internal aggregate for the optimized `PIVOT` fast path. Value type must be in `PivotFirst.supportsDataType` (Boolean, Byte, Short, Int, Long, Float, Double, Decimal); other types keep Spark's standard filtered-aggregate path. 
| | `regr_avgx` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_avgy` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_count` | ✅ | Native: Spark rewrites to `Count` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..2ab693198b 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -74,7 +74,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + SparkBloomFilterVersion, SparkPivotFirst, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -2653,6 +2653,30 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::PivotFirst(expr) => { + let pivot_col = + self.create_expr(expr.pivot_column.as_ref().unwrap(), Arc::clone(&schema))?; + let value_col = + self.create_expr(expr.value_column.as_ref().unwrap(), Arc::clone(&schema))?; + let value_type = to_arrow_datatype(expr.value_datatype.as_ref().unwrap()); + // Reconstruct pivot values as ScalarValues from the serialized Literal exprs. + // The Scala serde builds them as Literal(v, pivot_column.dataType), so extracting + // via Literal downcast gives us the exact ScalarValue the update path compares + // against per input row. 
+ let mut pivot_values = Vec::with_capacity(expr.pivot_values.len()); + for lit_expr in &expr.pivot_values { + let physical = self.create_expr(lit_expr, Arc::clone(&schema))?; + let literal = physical.downcast_ref::().ok_or_else(|| { + GeneralError( + "PivotFirst pivot_values must be Literal expressions".to_string(), + ) + })?; + pivot_values.push(literal.value().clone()); + } + let func = + AggregateUDF::new_from_impl(SparkPivotFirst::new(value_type, pivot_values)); + Self::create_aggr_func_expr("pivot_first", schema, vec![pivot_col, value_col], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..1055c32c16 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + PivotFirst pivotFirst = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -276,6 +277,21 @@ message CollectSet { DataType datatype = 2; } +// Optimized Pivot second-phase aggregate. Given a `pivot_column` and a +// pre-computed list of `pivot_values`, buckets the incoming `value_column` +// into an array whose length is `pivot_values.size()`. Rows whose pivot +// value is not in `pivot_values` are ignored; if multiple rows in the same +// group map to the same bucket, the last non-null value wins (matches +// Spark's `PivotFirst.update`/`merge`). Output type is +// ArrayType(value_datatype). The pivot_values are serialized as Literal +// Exprs so the value bytes and data type reach native code faithfully. 
+message PivotFirst { + Expr pivot_column = 1; + Expr value_column = 2; + repeated Expr pivot_values = 3; + DataType value_datatype = 4; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..2955e5ef54 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod pivot_first; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use pivot_first::SparkPivotFirst; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/pivot_first.rs b/native/spark-expr/src/agg_funcs/pivot_first.rs new file mode 100644 index 0000000000..f08153ce3d --- /dev/null +++ b/native/spark-expr/src/agg_funcs/pivot_first.rs @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark's `PivotFirst` aggregate. Used only by the second phase of the optimized pivot plan +//! generated by `PivotTransformer`. 
For each group, `PivotFirst` maintains an array of +//! `pivot_values.len()` slots; on each input row it evaluates the pivot column, looks up its +//! index in `pivot_values`, and writes the value column into that slot when a match is found +//! and the value is non-null. Rows with unmatched pivot values are ignored; matched rows with +//! a null value column leave the slot unchanged (matches Spark). +//! +//! State layout is one column per pivot slot, matching Spark's `aggBufferAttributes` (which +//! declares `indexSize` `AttributeReference`s, one per pivot value). This keeps the shuffle +//! schema between Partial and Final consistent with what Spark catalyst declared — otherwise +//! the shuffle exchange rejects the batch. `evaluate()` reassembles the slots into a +//! `ListArray` matching `PivotFirst.dataType = ArrayType(value_type)`. + +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::utils::SingleRowListArrayBuilder; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use std::collections::HashMap; +use std::sync::Arc; + +/// UDAF implementation of Spark's `PivotFirst`. +/// +/// `pivot_values` is a fixed, plan-time list of the pivot column values that occupy each +/// output slot; `pivot_index[v] = i` means an input row whose pivot column equals `v` writes +/// into slot `i`. 
+#[derive(Debug)] +pub struct SparkPivotFirst { + signature: Signature, + value_type: DataType, + pivot_values: Vec, + pivot_index: HashMap, +} + +impl PartialEq for SparkPivotFirst { + fn eq(&self, other: &Self) -> bool { + self.value_type == other.value_type && self.pivot_values == other.pivot_values + } +} + +impl Eq for SparkPivotFirst {} + +impl std::hash::Hash for SparkPivotFirst { + fn hash(&self, state: &mut H) { + self.value_type.hash(state); + self.pivot_values.hash(state); + } +} + +impl SparkPivotFirst { + pub fn new(value_type: DataType, pivot_values: Vec) -> Self { + let mut pivot_index = HashMap::with_capacity(pivot_values.len()); + // Duplicate pivot values are not expected; if any are present, only the FIRST occurrence's + // index is kept. NOTE(review): Scala's `HashMap(pairs: _*)` / `TreeMap(pairs: _*)` let a + // later duplicate key replace an earlier one, so confirm this matches Spark's `pivotIndex`. + for (i, v) in pivot_values.iter().enumerate() { + pivot_index.entry(v.clone()).or_insert(i); + } + Self { + signature: Signature::user_defined(Immutable), + value_type, + pivot_values, + pivot_index, + } + } +} + +impl AggregateUDFImpl for SparkPivotFirst { + fn name(&self) -> &str { + "pivot_first" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::List(Arc::new(Field::new_list_field( + self.value_type.clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> DFResult> { + // One field per pivot slot, matching Spark's aggBufferAttributes so the shuffle + // exchange sees the same schema catalyst declared. 
+ Ok(self + .pivot_values + .iter() + .enumerate() + .map(|(i, _)| { + Arc::new(Field::new( + format!("{}[{}]", args.name, i), + self.value_type.clone(), + true, + )) + }) + .collect()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(PivotFirstAccumulator::new( + self.value_type.clone(), + self.pivot_values.clone(), + self.pivot_index.clone(), + ))) + } +} + +/// Per-group state: `slots[i]` holds the latest non-null value assigned to pivot slot `i`, or +/// `None` when nothing has written to that slot yet. +#[derive(Debug)] +struct PivotFirstAccumulator { + value_type: DataType, + pivot_index: HashMap, + slots: Vec>, +} + +impl PivotFirstAccumulator { + fn new( + value_type: DataType, + pivot_values: Vec, + pivot_index: HashMap, + ) -> Self { + let slots = vec![None; pivot_values.len()]; + Self { + value_type, + pivot_index, + slots, + } + } + + /// Turn slot `i` into a `ScalarValue`, substituting a typed null when the slot is empty. + fn slot_or_null(&self, i: usize) -> DFResult { + Ok(match &self.slots[i] { + Some(v) => v.clone(), + None => ScalarValue::try_from(&self.value_type)?, + }) + } +} + +impl Accumulator for PivotFirstAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + if values.len() != 2 { + return Err(DataFusionError::Internal(format!( + "PivotFirst expects 2 inputs (pivot, value); got {}", + values.len() + ))); + } + let pivot_arr = &values[0]; + let value_arr = &values[1]; + if pivot_arr.len() != value_arr.len() { + return Err(DataFusionError::Internal( + "PivotFirst pivot and value arrays have different lengths".into(), + )); + } + for row in 0..pivot_arr.len() { + // Spark ignores the row entirely if either the pivot value is unmatched (index<0) + // or the value is null. Matching Spark exactly here is important because + // `PivotFirst.update` never writes for a null value, so a mid-batch null does not + // clobber an earlier non-null. 
+ let pivot_scalar = ScalarValue::try_from_array(pivot_arr, row)?; + if let Some(&slot_idx) = self.pivot_index.get(&pivot_scalar) { + if !value_arr.is_null(row) { + self.slots[slot_idx] = Some(ScalarValue::try_from_array(value_arr, row)?); + } + } + } + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + // Collapse the slot vector into a single ScalarValue::List whose inner array is + // `slots.len()` items long, with typed nulls for empty slots. This matches Spark's + // `PivotFirst.dataType = ArrayType(value_type)`. + let scalars = (0..self.slots.len()) + .map(|i| self.slot_or_null(i)) + .collect::>>()?; + let flat = ScalarValue::iter_to_array(scalars)?; + Ok(SingleRowListArrayBuilder::new(flat).build_list_scalar()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.slots.capacity() * std::mem::size_of::>() + } + + fn state(&mut self) -> DFResult> { + // One ScalarValue per pivot slot; matches state_fields. + (0..self.slots.len()) + .map(|i| self.slot_or_null(i)) + .collect() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + if states.len() != self.slots.len() { + return Err(DataFusionError::Internal(format!( + "PivotFirst merge expects {} state columns; got {}", + self.slots.len(), + states.len() + ))); + } + // Each column is one slot; each row is one incoming partial state. Any non-null cell + // overwrites our current slot with "last write wins" (matches Spark's `PivotFirst.merge`). 
+ let n_rows = states.first().map(|a| a.len()).unwrap_or(0); + for (slot_idx, col) in states.iter().enumerate() { + for row in 0..n_rows { + if !col.is_null(row) { + self.slots[slot_idx] = Some(ScalarValue::try_from_array(col, row)?); + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + + fn scalar(v: i32) -> ScalarValue { + ScalarValue::Int32(Some(v)) + } + fn s(v: &str) -> ScalarValue { + ScalarValue::Utf8(Some(v.to_string())) + } + + fn make_acc() -> PivotFirstAccumulator { + let pivot_values = vec![s("a"), s("b"), s("c")]; + let mut pivot_index = HashMap::new(); + for (i, v) in pivot_values.iter().enumerate() { + pivot_index.insert(v.clone(), i); + } + PivotFirstAccumulator::new(DataType::Int32, pivot_values, pivot_index) + } + + #[test] + fn update_writes_by_pivot_index() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "c", "a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 10])); + acc.update_batch(&[pivots, values]).unwrap(); + // Slot 'a' = 10 (last non-null write wins for the same pivot value), slot 'b' = null, + // slot 'c' = 3. + assert_eq!(acc.slots, vec![Some(scalar(10)), None, Some(scalar(3))]); + } + + #[test] + fn update_ignores_unmatched_pivot() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "z", "b"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 99, 2])); + acc.update_batch(&[pivots, values]).unwrap(); + // 'z' is not in the pivot list - its row is ignored entirely. 
+ assert_eq!(acc.slots, vec![Some(scalar(1)), Some(scalar(2)), None]); + } + + #[test] + fn update_ignores_null_value() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(5), None])); + acc.update_batch(&[pivots, values]).unwrap(); + // The second row's null must not clobber the first row's 5. + assert_eq!(acc.slots, vec![Some(scalar(5)), None, None]); + } + + #[test] + fn evaluate_returns_list_with_nulls_for_empty_slots() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["c"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![7])); + acc.update_batch(&[pivots, values]).unwrap(); + + let result = acc.evaluate().unwrap(); + match result { + ScalarValue::List(list) => { + assert_eq!(list.len(), 1); + let inner = list.value(0); + let inner = inner.as_any().downcast_ref::().unwrap(); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + assert_eq!(inner.value(2), 7); + } + other => panic!("expected ScalarValue::List, got {other:?}"), + } + } + + #[test] + fn state_returns_one_scalar_per_slot() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "c"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 3])); + acc.update_batch(&[pivots, values]).unwrap(); + + let state = acc.state().unwrap(); + assert_eq!(state.len(), 3); + assert_eq!(state[0], scalar(1)); + assert!(matches!(state[1], ScalarValue::Int32(None))); + assert_eq!(state[2], scalar(3)); + } + + #[test] + fn merge_pastes_non_null_slots() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1])); + acc.update_batch(&[pivots, values]).unwrap(); + + // Build the other side's state and merge - slot 'a' stays 1, 'b' becomes 2, 'c' 3. 
+ let mut other = make_acc(); + let pivots2: ArrayRef = Arc::new(StringArray::from(vec!["b", "c"])); + let values2: ArrayRef = Arc::new(Int32Array::from(vec![2, 3])); + other.update_batch(&[pivots2, values2]).unwrap(); + let state = other.state().unwrap(); + // Feed each state scalar as a single-row column, mirroring what the shuffle produces. + let state_arrays: Vec = state + .into_iter() + .map(|sv| sv.to_array_of_size(1).unwrap()) + .collect(); + acc.merge_batch(&state_arrays).unwrap(); + + assert_eq!( + acc.slots, + vec![Some(scalar(1)), Some(scalar(2)), Some(scalar(3))] + ); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..07d461b084 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -400,6 +400,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Max] -> CometMax, classOf[Min] -> CometMin, classOf[Percentile] -> CometPercentile, + classOf[PivotFirst] -> CometPivotFirst, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, classOf[Sum] -> CometSum, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..aef7a85910 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,9 +22,9 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, 
StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, PivotFirst, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -823,6 +823,83 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { } } +object CometPivotFirst extends CometAggregateExpressionSerde[PivotFirst] { + + // Mirrors Spark's `PivotFirst.supportsDataType`, which is the same gate the analyzer + // uses to decide between the two-phase fast path (this aggregate) and the standard + // filtered-aggregate path. Only the fast path emits `PivotFirst`, so we only need to + // cover these value types. 
+ private def isValueTypeSupported(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case _ => false + } + + private def unsupportedValueTypeReason(dt: DataType): String = + s"Unsupported value data type: $dt" + + private val emptyPivotValuesReason = "Pivot values list is empty" + + override def getUnsupportedReasons(): Seq[String] = Seq( + "Value data type outside PivotFirst.supportsDataType " + + "(Boolean, Byte, Short, Int, Long, Float, Double, Decimal)", + emptyPivotValuesReason) + + override def getSupportLevel(expr: PivotFirst): SupportLevel = { + if (!isValueTypeSupported(expr.valueDataType)) { + Unsupported(Some(unsupportedValueTypeReason(expr.valueDataType))) + } else if (expr.pivotColumnValues.isEmpty) { + Unsupported(Some(emptyPivotValuesReason)) + } else { + Compatible() + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: PivotFirst, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val pivotColExpr = exprToProto(expr.pivotColumn, inputs, binding) + val valueColExpr = exprToProto(expr.valueColumn, inputs, binding) + val valueDt = serializeDataType(expr.valueDataType) + val pivotDt = expr.pivotColumn.dataType + + // Spark stores the pivot values as already-evaluated raw Scala values. Rebuild them as + // Literal expressions carrying the pivot column data type so the native side receives + // the exact bytes it will later compare against. Preserve list order - the position in + // this list is the output array index. 
+ val pivotValueExprs = + expr.pivotColumnValues.map(v => exprToProto(Literal(v, pivotDt), inputs, binding)) + + if (pivotColExpr.isDefined && valueColExpr.isDefined && valueDt.isDefined && + pivotValueExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.PivotFirst.newBuilder() + builder.setPivotColumn(pivotColExpr.get) + builder.setValueColumn(valueColExpr.get) + builder.setValueDatatype(valueDt.get) + pivotValueExprs.foreach(v => builder.addPivotValues(v.get)) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setPivotFirst(builder) + .build()) + } else if (valueDt.isEmpty) { + withFallbackReason( + aggExpr, + unsupportedValueTypeReason(expr.valueDataType), + expr.valueColumn) + None + } else { + withFallbackReason(aggExpr, expr.pivotColumn, expr.valueColumn) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql new file mode 100644 index 0000000000..17063bd93c --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql @@ -0,0 +1,326 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- PivotFirst is Spark's internal aggregate emitted by the two-phase optimized pivot plan +-- (PivotTransformer). It only runs when every aggregate output type is in +-- PivotFirst.supportsDataType (Boolean, Byte, Short, Int, Long, Float, Double, Decimal). +-- These SQL tests exercise that path via SQL PIVOT clauses, one per supported value type +-- plus a handful of edge cases (unmatched pivot values, all-null groups, multiple +-- aggregates). They intentionally use `spark_answer_only` because the pivot plan's final +-- projection is a Spark expression (ExtractValue on the aggregate output) so the whole +-- query does not always run natively even when PivotFirst itself does. + +-- ============================================================ +-- Setup: sales-like table used for most cases +-- ============================================================ + +statement +CREATE TABLE pf_sales (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales VALUES + (2012, 'dotNET', 15000), + (2012, 'Java', 20000), + (2013, 'dotNET', 48000), + (2013, 'Java', 30000) + +-- ============================================================ +-- Int values: canonical optimized-pivot query. Extra pivot values +-- (1..10) push the aggregate over the threshold that turns on the +-- PivotFirst plan. +-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Long values: sum of a bigint aggregates to bigint. 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_bi (year int, course string, earnings bigint) USING parquet + +statement +INSERT INTO pf_sales_bi VALUES + (2012, 'dotNET', 15000000000), + (2012, 'Java', 20000000000), + (2013, 'dotNET', 48000000000), + (2013, 'Java', 30000000000) + +query spark_answer_only +SELECT * FROM pf_sales_bi + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Double values: avg(earnings) is double. +-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (avg(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Float values: min over a float column keeps the aggregate output type float. +-- ============================================================ + +statement +CREATE TABLE pf_sales_f (year int, course string, earnings float) USING parquet + +statement +INSERT INTO pf_sales_f VALUES + (2012, 'dotNET', CAST(15000 AS FLOAT)), + (2012, 'Java', CAST(20000 AS FLOAT)), + (2013, 'dotNET', CAST(48000 AS FLOAT)), + (2013, 'Java', CAST(30000 AS FLOAT)) + +query spark_answer_only +SELECT * FROM pf_sales_f + PIVOT (min(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Short and Byte value types (sum promotes them to bigint, so use min +-- to keep the aggregate output type equal to the input type). 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_small ( + year int, course string, s smallint, b tinyint) USING parquet + +statement +INSERT INTO pf_sales_small VALUES + (2012, 'dotNET', 150, 10), + (2012, 'Java', 200, 20), + (2013, 'dotNET', 480, 30), + (2013, 'Java', 300, 40) + +query spark_answer_only +SELECT * FROM pf_sales_small + PIVOT (min(s) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +query spark_answer_only +SELECT * FROM pf_sales_small + PIVOT (min(b) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Boolean values: bool_or's output type is boolean. +-- ============================================================ + +statement +CREATE TABLE pf_flags (year int, course string, flag boolean) USING parquet + +statement +INSERT INTO pf_flags VALUES + (2012, 'dotNET', true), + (2012, 'Java', false), + (2013, 'dotNET', false), + (2013, 'Java', true) + +query spark_answer_only +SELECT * FROM pf_flags + PIVOT (bool_or(flag) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Decimal values. +-- ============================================================ + +statement +CREATE TABLE pf_sales_dec (year int, course string, earnings decimal(10,2)) USING parquet + +statement +INSERT INTO pf_sales_dec VALUES + (2012, 'dotNET', 15000.00), + (2012, 'Java', 20000.00), + (2013, 'dotNET', 48000.00), + (2013, 'Java', 30000.00) + +query spark_answer_only +SELECT * FROM pf_sales_dec + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Multiple aggregates in a single pivot. 
+-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) AS s, avg(earnings) AS a + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Pivot values that do not appear in the data: the corresponding +-- output slots must be NULL. +-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', 'MissingA', 'MissingB', + '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Rows whose pivot column is not in the pivot list must be ignored: +-- 'Python' below should not contribute to any output slot. +-- ============================================================ + +statement +CREATE TABLE pf_sales_extra (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales_extra VALUES + (2012, 'dotNET', 15000), + (2012, 'Java', 20000), + (2012, 'Python', 12345), + (2013, 'dotNET', 48000), + (2013, 'Java', 30000), + (2013, 'Python', 67890) + +query spark_answer_only +SELECT * FROM pf_sales_extra + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Group with no matching rows: every output slot is NULL for that +-- pivot combination. Ensures NULL initialization matches Spark. 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_partial (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales_partial VALUES + (2012, 'dotNET', 15000), + (2012, 'Java', 20000), + (2013, 'Python', 42000) + +query spark_answer_only +SELECT * FROM pf_sales_partial + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- NULL value in an aggregated column: the null must not clobber a +-- previously written slot (Spark PivotFirst.update ignores nulls). +-- ============================================================ + +statement +CREATE TABLE pf_sales_nullable (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales_nullable VALUES + (2012, 'dotNET', 15000), + (2012, 'dotNET', NULL), + (2012, 'Java', NULL), + (2013, 'dotNET', 48000), + (2013, 'Java', 30000) + +query spark_answer_only +SELECT * FROM pf_sales_nullable + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- NULL in the pivot COLUMN. Spark tracks this with a dedicated +-- test ("pivot with null should not throw NPE") because early +-- versions of PivotFirst NPE'd on null keys with a TreeMap. +-- Comet uses a HashMap where ScalarValue::Null +-- hashes safely, so this must produce the same rows as Spark. 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_nullpivot (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales_nullpivot VALUES + (2012, 'dotNET', 15000), + (2012, NULL, 9999), + (2013, 'Java', 30000), + (2013, NULL, 42) + +query spark_answer_only +SELECT * FROM pf_sales_nullpivot + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + ORDER BY year + +-- ============================================================ +-- Non-string pivot column: integer. Exercises the pivot value +-- serialization path with a numeric literal. +-- ============================================================ + +statement +CREATE TABLE pf_sales_intpivot (region string, year int, earnings int) USING parquet + +statement +INSERT INTO pf_sales_intpivot VALUES + ('NA', 2012, 15000), + ('NA', 2013, 48000), + ('EU', 2012, 20000), + ('EU', 2013, 30000) + +query spark_answer_only +SELECT * FROM pf_sales_intpivot + PIVOT (sum(earnings) + FOR year IN (2012, 2013, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + ORDER BY region + +-- ============================================================ +-- Non-string pivot column: date. Exercises the pivot value +-- serialization path with a DateType literal. 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_datepivot (region string, d date, earnings int) USING parquet + +statement +INSERT INTO pf_sales_datepivot VALUES + ('NA', DATE '2024-01-01', 15000), + ('NA', DATE '2024-06-15', 48000), + ('EU', DATE '2024-01-01', 20000), + ('EU', DATE '2024-06-15', 30000) + +query spark_answer_only +SELECT * FROM pf_sales_datepivot + PIVOT (sum(earnings) + FOR d IN (DATE '2024-01-01', DATE '2024-06-15', + DATE '1970-01-01', DATE '1970-01-02', DATE '1970-01-03', + DATE '1970-01-04', DATE '1970-01-05', DATE '1970-01-06', + DATE '1970-01-07', DATE '1970-01-08', DATE '1970-01-09', + DATE '1970-01-10')) + ORDER BY region diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..ae9c2cf64b 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -2150,4 +2150,46 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // Optimized-pivot fast path: with 12+ pivot values Spark's PivotTransformer emits a + // two-phase aggregate whose second phase uses PivotFirst. We assert: + // 1. The physical plan actually contains a PivotFirst (Spark still picked the fast path). + // 2. Comet converted the second-phase aggregate to CometHashAggregateExec (i.e. our + // CometPivotFirst serde accepted the aggregate). + // 3. Results match Spark exactly. + // If any of the three fails, the test surfaces the specific breakage. 
+ test("PivotFirst runs natively on the optimized pivot fast path") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "sales") + spark + .sql("""SELECT * FROM VALUES + | (2012, 'dotNET', 15000), + | (2012, 'Java', 20000), + | (2013, 'dotNET', 48000), + | (2013, 'Java', 30000) + |AS t(year, course, earnings)""".stripMargin) + .write + .parquet(path.toUri.toString) + withParquetTable(path.toUri.toString, "sales") { + val pivotSql = + """SELECT * FROM sales + | PIVOT (sum(earnings) + | FOR course IN ('dotNET', 'Java', + | '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + | ORDER BY year""".stripMargin + val df = spark.sql(pivotSql) + val plan = df.queryExecution.executedPlan + val pivotFirstAggs = collectWithSubqueries(plan) { + case a: CometHashAggregateExec + if a.aggregateExpressions.exists(_.aggregateFunction + .isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst]) => + a + } + assert( + pivotFirstAggs.nonEmpty, + s"Expected a CometHashAggregateExec containing PivotFirst in:\n$plan") + checkSparkAnswer(df) + } + } + } + } From f632bb57be48af3f215321a8d2a40bbb64724a48 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 14:31:53 -0600 Subject: [PATCH 2/2] simplify: address /simplify review of PivotFirst - Delegate value-type gate to Spark's `PivotFirst.supportsDataType` instead of maintaining a parallel list in the serde. - Wrap the plan-time pivot map/vector in `Arc` so per-group accumulator init bumps a refcount instead of deep-cloning. - Drop the redundant `pivot_values` field from the accumulator (only the map and `slots.len()` are read at runtime); the accumulator now derives its slot count from the shared map. - Use the shared `format_state_name` helper for state field names. - Delegate `evaluate()` to `state()` so the two paths share their slot materialization. - Walk each state column backwards in `merge_batch` and stop at the first non-null so we build at most one `ScalarValue` per slot per merge. 
- Drop the "pad every IN clause to 12 values" workaround from the SQL tests and the CometAggregateSuite test - the fast-path gate is supported-data-type only, no minimum count. --- .../spark-expr/src/agg_funcs/pivot_first.rs | 80 ++++++++++--------- .../org/apache/comet/serde/aggregates.scala | 24 ++---- .../expressions/aggregate/PivotFirst.sql | 47 +++++------ .../comet/exec/CometAggregateSuite.scala | 9 +-- 4 files changed, 71 insertions(+), 89 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/pivot_first.rs b/native/spark-expr/src/agg_funcs/pivot_first.rs index f08153ce3d..a80259fa89 100644 --- a/native/spark-expr/src/agg_funcs/pivot_first.rs +++ b/native/spark-expr/src/agg_funcs/pivot_first.rs @@ -24,7 +24,7 @@ //! //! State layout is one column per pivot slot, matching Spark's `aggBufferAttributes` (which //! declares `indexSize` `AttributeReference`s, one per pivot value). This keeps the shuffle -//! schema between Partial and Final consistent with what Spark catalyst declared — otherwise +//! schema between Partial and Final consistent with what Spark catalyst declared; otherwise //! the shuffle exchange rejects the batch. `evaluate()` reassembles the slots into a //! `ListArray` matching `PivotFirst.dataType = ArrayType(value_type)`. @@ -35,6 +35,7 @@ use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; use std::collections::HashMap; use std::sync::Arc; @@ -42,13 +43,18 @@ use std::sync::Arc; /// /// `pivot_values` is a fixed, plan-time list of the pivot column values that occupy each /// output slot; `pivot_index[v] = i` means an input row whose pivot column equals `v` writes -/// into slot `i`. +/// into slot `i`. 
Both the vector and the map are wrapped in `Arc` because `accumulator()` +/// fires once per group in a grouped aggregate and we want that path to bump a refcount +/// rather than deep-clone. #[derive(Debug)] pub struct SparkPivotFirst { signature: Signature, value_type: DataType, - pivot_values: Vec, - pivot_index: HashMap, + // Kept for `PartialEq`/`Hash` (identity of the aggregate for plan comparison) and for the + // deterministic slot ordering `state_fields` needs. `HashMap` alone would give us the map + // but not a stable order or a `Hash` impl. + pivot_values: Arc>, + pivot_index: Arc>, } impl PartialEq for SparkPivotFirst { @@ -77,8 +83,8 @@ impl SparkPivotFirst { Self { signature: Signature::user_defined(Immutable), value_type, - pivot_values, - pivot_index, + pivot_values: Arc::new(pivot_values), + pivot_index: Arc::new(pivot_index), } } } @@ -101,14 +107,12 @@ impl AggregateUDFImpl for SparkPivotFirst { fn state_fields(&self, args: StateFieldsArgs) -> DFResult> { // One field per pivot slot, matching Spark's aggBufferAttributes so the shuffle - // exchange sees the same schema catalyst declared. - Ok(self - .pivot_values - .iter() - .enumerate() - .map(|(i, _)| { + // exchange sees the same schema catalyst declared. `format_state_name` is the same + // helper other aggregates in this crate use (see `avg.rs`, `stddev.rs`). 
+ Ok((0..self.pivot_values.len()) + .map(|i| { Arc::new(Field::new( - format!("{}[{}]", args.name, i), + format_state_name(args.name, &i.to_string()), self.value_type.clone(), true, )) @@ -119,8 +123,7 @@ impl AggregateUDFImpl for SparkPivotFirst { fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { Ok(Box::new(PivotFirstAccumulator::new( self.value_type.clone(), - self.pivot_values.clone(), - self.pivot_index.clone(), + Arc::clone(&self.pivot_index), ))) } } @@ -130,17 +133,13 @@ impl AggregateUDFImpl for SparkPivotFirst { #[derive(Debug)] struct PivotFirstAccumulator { value_type: DataType, - pivot_index: HashMap, + pivot_index: Arc>, slots: Vec>, } impl PivotFirstAccumulator { - fn new( - value_type: DataType, - pivot_values: Vec, - pivot_index: HashMap, - ) -> Self { - let slots = vec![None; pivot_values.len()]; + fn new(value_type: DataType, pivot_index: Arc>) -> Self { + let slots = vec![None; pivot_index.len()]; Self { value_type, pivot_index, @@ -188,13 +187,9 @@ impl Accumulator for PivotFirstAccumulator { } fn evaluate(&mut self) -> DFResult { - // Collapse the slot vector into a single ScalarValue::List whose inner array is - // `slots.len()` items long, with typed nulls for empty slots. This matches Spark's - // `PivotFirst.dataType = ArrayType(value_type)`. - let scalars = (0..self.slots.len()) - .map(|i| self.slot_or_null(i)) - .collect::>>()?; - let flat = ScalarValue::iter_to_array(scalars)?; + // Collapse the per-slot state into a single ScalarValue::List with typed nulls for + // empty slots. This matches Spark's `PivotFirst.dataType = ArrayType(value_type)`. + let flat = ScalarValue::iter_to_array(self.state()?)?; Ok(SingleRowListArrayBuilder::new(flat).build_list_scalar()) } @@ -204,7 +199,7 @@ impl Accumulator for PivotFirstAccumulator { } fn state(&mut self) -> DFResult> { - // One ScalarValue per pivot slot; matches state_fields. + // One ScalarValue per pivot slot; matches `state_fields`. 
(0..self.slots.len()) .map(|i| self.slot_or_null(i)) .collect() @@ -218,13 +213,19 @@ impl Accumulator for PivotFirstAccumulator { states.len() ))); } + if states.is_empty() { + return Ok(()); + } // Each column is one slot; each row is one incoming partial state. Any non-null cell - // overwrites our current slot with "last write wins" (matches Spark's `PivotFirst.merge`). - let n_rows = states.first().map(|a| a.len()).unwrap_or(0); + // overwrites our current slot with "last write wins" (matches Spark's + // `PivotFirst.merge`). Walk backwards and break on the first non-null so we only build + // at most one `ScalarValue` per slot per merge, not one per row per slot. + let n_rows = states[0].len(); for (slot_idx, col) in states.iter().enumerate() { - for row in 0..n_rows { + for row in (0..n_rows).rev() { if !col.is_null(row) { self.slots[slot_idx] = Some(ScalarValue::try_from_array(col, row)?); + break; } } } @@ -245,12 +246,13 @@ mod tests { } fn make_acc() -> PivotFirstAccumulator { - let pivot_values = vec![s("a"), s("b"), s("c")]; - let mut pivot_index = HashMap::new(); - for (i, v) in pivot_values.iter().enumerate() { - pivot_index.insert(v.clone(), i); - } - PivotFirstAccumulator::new(DataType::Int32, pivot_values, pivot_index) + let pivot_values = [s("a"), s("b"), s("c")]; + let pivot_index: HashMap = pivot_values + .iter() + .enumerate() + .map(|(i, v)| (v.clone(), i)) + .collect(); + PivotFirstAccumulator::new(DataType::Int32, Arc::new(pivot_index)) } #[test] diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index aef7a85910..c71b1dada9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, PivotFirst, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -825,17 +825,8 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { object CometPivotFirst extends CometAggregateExpressionSerde[PivotFirst] { - // Mirrors Spark's `PivotFirst.supportsDataType`, which is the same gate the analyzer - // uses to decide between the two-phase fast path (this aggregate) and the standard - // filtered-aggregate path. Only the fast path emits `PivotFirst`, so we only need to - // cover these value types. - private def isValueTypeSupported(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType => true - case FloatType | DoubleType => true - case _: DecimalType => true - case _ => false - } - + // Delegate to Spark's own PivotFirst.supportsDataType so the two lists cannot drift if a + // future Spark version adds a value type to the fast-path gate. 
private def unsupportedValueTypeReason(dt: DataType): String = s"Unsupported value data type: $dt" @@ -847,7 +838,7 @@ object CometPivotFirst extends CometAggregateExpressionSerde[PivotFirst] { emptyPivotValuesReason) override def getSupportLevel(expr: PivotFirst): SupportLevel = { - if (!isValueTypeSupported(expr.valueDataType)) { + if (!PivotFirst.supportsDataType(expr.valueDataType)) { Unsupported(Some(unsupportedValueTypeReason(expr.valueDataType))) } else if (expr.pivotColumnValues.isEmpty) { Unsupported(Some(emptyPivotValuesReason)) @@ -881,12 +872,7 @@ object CometPivotFirst extends CometAggregateExpressionSerde[PivotFirst] { builder.setValueColumn(valueColExpr.get) builder.setValueDatatype(valueDt.get) pivotValueExprs.foreach(v => builder.addPivotValues(v.get)) - - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setPivotFirst(builder) - .build()) + Some(ExprOuterClass.AggExpr.newBuilder().setPivotFirst(builder).build()) } else if (valueDt.isEmpty) { withFallbackReason( aggExpr, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql index 17063bd93c..bf9c45279a 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql @@ -16,13 +16,13 @@ -- under the License. -- PivotFirst is Spark's internal aggregate emitted by the two-phase optimized pivot plan --- (PivotTransformer). It only runs when every aggregate output type is in +-- (PivotTransformer). It runs whenever every aggregate output type is in -- PivotFirst.supportsDataType (Boolean, Byte, Short, Int, Long, Float, Double, Decimal). -- These SQL tests exercise that path via SQL PIVOT clauses, one per supported value type --- plus a handful of edge cases (unmatched pivot values, all-null groups, multiple --- aggregates). 
They intentionally use `spark_answer_only` because the pivot plan's final --- projection is a Spark expression (ExtractValue on the aggregate output) so the whole --- query does not always run natively even when PivotFirst itself does. +-- plus a handful of edge cases (unmatched pivot values, all-null groups, null pivot column). +-- They intentionally use `spark_answer_only` because the pivot plan's final projection is a +-- Spark expression (ExtractValue on the aggregate output) so the whole query does not always +-- run natively even when PivotFirst itself does. -- ============================================================ -- Setup: sales-like table used for most cases @@ -47,7 +47,7 @@ INSERT INTO pf_sales VALUES query spark_answer_only SELECT * FROM pf_sales PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -67,7 +67,7 @@ INSERT INTO pf_sales_bi VALUES query spark_answer_only SELECT * FROM pf_sales_bi PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -77,7 +77,7 @@ SELECT * FROM pf_sales_bi query spark_answer_only SELECT * FROM pf_sales PIVOT (avg(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -97,7 +97,7 @@ INSERT INTO pf_sales_f VALUES query spark_answer_only SELECT * FROM pf_sales_f PIVOT (min(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -119,13 +119,13 @@ INSERT INTO pf_sales_small 
VALUES query spark_answer_only SELECT * FROM pf_sales_small PIVOT (min(s) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year query spark_answer_only SELECT * FROM pf_sales_small PIVOT (min(b) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -145,7 +145,7 @@ INSERT INTO pf_flags VALUES query spark_answer_only SELECT * FROM pf_flags PIVOT (bool_or(flag) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -165,7 +165,7 @@ INSERT INTO pf_sales_dec VALUES query spark_answer_only SELECT * FROM pf_sales_dec PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -175,7 +175,7 @@ SELECT * FROM pf_sales_dec query spark_answer_only SELECT * FROM pf_sales PIVOT (sum(earnings) AS s, avg(earnings) AS a - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -186,8 +186,7 @@ SELECT * FROM pf_sales query spark_answer_only SELECT * FROM pf_sales PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', 'MissingA', 'MissingB', - '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java', 'MissingA', 'MissingB')) ORDER BY year -- ============================================================ @@ -210,7 +209,7 @@ INSERT INTO pf_sales_extra VALUES query spark_answer_only SELECT * FROM pf_sales_extra PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', 
'3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -230,7 +229,7 @@ INSERT INTO pf_sales_partial VALUES query spark_answer_only SELECT * FROM pf_sales_partial PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -252,7 +251,7 @@ INSERT INTO pf_sales_nullable VALUES query spark_answer_only SELECT * FROM pf_sales_nullable PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -276,7 +275,7 @@ INSERT INTO pf_sales_nullpivot VALUES query spark_answer_only SELECT * FROM pf_sales_nullpivot PIVOT (sum(earnings) - FOR course IN ('dotNET', 'Java', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + FOR course IN ('dotNET', 'Java')) ORDER BY year -- ============================================================ @@ -297,7 +296,7 @@ INSERT INTO pf_sales_intpivot VALUES query spark_answer_only SELECT * FROM pf_sales_intpivot PIVOT (sum(earnings) - FOR year IN (2012, 2013, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + FOR year IN (2012, 2013)) ORDER BY region -- ============================================================ @@ -318,9 +317,5 @@ INSERT INTO pf_sales_datepivot VALUES query spark_answer_only SELECT * FROM pf_sales_datepivot PIVOT (sum(earnings) - FOR d IN (DATE '2024-01-01', DATE '2024-06-15', - DATE '1970-01-01', DATE '1970-01-02', DATE '1970-01-03', - DATE '1970-01-04', DATE '1970-01-05', DATE '1970-01-06', - DATE '1970-01-07', DATE '1970-01-08', DATE '1970-01-09', - DATE '1970-01-10')) + FOR d IN (DATE '2024-01-01', DATE '2024-06-15')) ORDER BY region diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae9c2cf64b..824d7e8f31 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -2150,8 +2150,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - // Optimized-pivot fast path: with 12+ pivot values Spark's PivotTransformer emits a - // two-phase aggregate whose second phase uses PivotFirst. We assert: + // Optimized-pivot fast path: Spark's PivotTransformer emits a two-phase aggregate whose + // second phase uses PivotFirst whenever every aggregate output type is in + // PivotFirst.supportsDataType. We assert: // 1. The physical plan actually contains a PivotFirst (Spark still picked the fast path). // 2. Comet converted the second-phase aggregate to CometHashAggregateExec (i.e. our // CometPivotFirst serde accepted the aggregate). @@ -2172,9 +2173,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withParquetTable(path.toUri.toString, "sales") { val pivotSql = """SELECT * FROM sales - | PIVOT (sum(earnings) - | FOR course IN ('dotNET', 'Java', - | '1', '2', '3', '4', '5', '6', '7', '8', '9', '10')) + | PIVOT (sum(earnings) FOR course IN ('dotNET', 'Java')) | ORDER BY year""".stripMargin val df = spark.sql(pivotSql) val plan = df.queryExecution.executedPlan