diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..c56082f283 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -54,4 +54,12 @@ - Spark 4.1.1 (audited 2026-06-24): identical to 4.0.1. - `CometPercentile` reports `Incompatible` for the otherwise-supported form because DataFusion's `percentile_cont` quantizes the interpolation weight to 6 decimal places (`INTERPOLATION_PRECISION = 1e6`), so a deeply-interpolated value can differ from Spark by up to roughly `(upper - lower) * 1e-6`. The native path is opt-in via `spark.comet.expression.Percentile.allowIncompatible=true` ([#4719](https://github.com/apache/datafusion-comet/issues/4719)). +## pivot_first + +- Spark 3.4.3 (audited 2026-07-02): `PivotFirst(pivotColumn, valueColumn, pivotColumnValues)` is an internal `ImperativeAggregate` emitted only by the optimized-pivot fast path in `Analyzer.ResolvePivot`. It buckets `valueColumn` into an array of length `pivotColumnValues.size` indexed by matching `pivotColumn`. Null value columns are ignored, unmatched pivot values are dropped. Value types are gated by `PivotFirst.supportsDataType` (Boolean, Byte, Short, Int, Long, Float, Double, Decimal). +- Spark 3.5.8 (audited 2026-07-02): swaps `StructType.fromAttributes` for `DataTypeUtils.fromAttributes`; no behavior change. +- Spark 4.0.1 (audited 2026-07-02): identical to 3.5.8. +- Spark 4.1.1 (audited 2026-07-02): identical to 4.0.1. (Spark `master` adds a defensive `findPivotIndex` wrapper that returns `-1` for null keys on the non-`AtomicType` `TreeMap` path; not present in any released version Comet builds against, and Comet's HashMap-backed lookup handles `ScalarValue::Null` safely on all types.) 
+- `CometPivotFirst` (in `spark/src/main/scala/org/apache/comet/serde/aggregates.scala`) forwards the aggregate to the native `SparkPivotFirst` UDAF (`native/spark-expr/src/agg_funcs/pivot_first.rs`) when the value type is in the supported set. State layout matches Spark's `aggBufferAttributes` (one scalar column per pivot slot) so the shuffle schema between Partial and Final stays consistent. `evaluate()` reassembles the slots into a `ListArray` matching `PivotFirst.dataType = ArrayType(valueDataType)`. + [Spark Expression Support]: ../../user-guide/latest/expressions.md diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 680422992c..78386a8a6f 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -104,6 +104,7 @@ The tables below list every Spark built-in expression with its current status. | `percentile` | ✅ | Single literal percentage on numeric input; array of percentages and a frequency argument fall back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_cont` | ✅ | Spark 4.0+ `WITHIN GROUP (ORDER BY ...)`; ascending only, `DESC` falls back to Spark. Falls back by default, opt-in via allowIncompatible ([#4719](https://github.com/apache/datafusion-comet/issues/4719)) | | `percentile_disc` | 🔜 | Percentile aggregate | +| `pivot_first` | ✅ | Internal aggregate for the optimized `PIVOT` fast path. Value type must be in `PivotFirst.supportsDataType` (Boolean, Byte, Short, Int, Long, Float, Double, Decimal); other types keep Spark's standard filtered-aggregate path. 
| | `regr_avgx` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_avgy` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_count` | ✅ | Native: Spark rewrites to `Count` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..2ab693198b 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -74,7 +74,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + SparkBloomFilterVersion, SparkPivotFirst, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -2653,6 +2653,30 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::PivotFirst(expr) => { + let pivot_col = + self.create_expr(expr.pivot_column.as_ref().unwrap(), Arc::clone(&schema))?; + let value_col = + self.create_expr(expr.value_column.as_ref().unwrap(), Arc::clone(&schema))?; + let value_type = to_arrow_datatype(expr.value_datatype.as_ref().unwrap()); + // Reconstruct pivot values as ScalarValues from the serialized Literal exprs. + // The Scala serde builds them as Literal(v, pivot_column.dataType), so extracting + // via Literal downcast gives us the exact ScalarValue the update path compares + // against per input row. 
+ let mut pivot_values = Vec::with_capacity(expr.pivot_values.len()); + for lit_expr in &expr.pivot_values { + let physical = self.create_expr(lit_expr, Arc::clone(&schema))?; + let literal = physical.downcast_ref::().ok_or_else(|| { + GeneralError( + "PivotFirst pivot_values must be Literal expressions".to_string(), + ) + })?; + pivot_values.push(literal.value().clone()); + } + let func = + AggregateUDF::new_from_impl(SparkPivotFirst::new(value_type, pivot_values)); + Self::create_aggr_func_expr("pivot_first", schema, vec![pivot_col, value_col], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..1055c32c16 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + PivotFirst pivotFirst = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -276,6 +277,21 @@ message CollectSet { DataType datatype = 2; } +// Optimized Pivot second-phase aggregate. Given a `pivot_column` and a +// pre-computed list of `pivot_values`, buckets the incoming `value_column` +// into an array whose length is `pivot_values.size()`. Rows whose pivot +// value is not in `pivot_values` are ignored; if multiple rows in the same +// group map to the same bucket, the last non-null value wins (matches +// Spark's `PivotFirst.update`/`merge`). Output type is +// ArrayType(value_datatype). The pivot_values are serialized as Literal +// Exprs so the value bytes and data type reach native code faithfully. 
+message PivotFirst { + Expr pivot_column = 1; + Expr value_column = 2; + repeated Expr pivot_values = 3; + DataType value_datatype = 4; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..2955e5ef54 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod pivot_first; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use pivot_first::SparkPivotFirst; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/pivot_first.rs b/native/spark-expr/src/agg_funcs/pivot_first.rs new file mode 100644 index 0000000000..a80259fa89 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/pivot_first.rs @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark's `PivotFirst` aggregate. Used only by the second phase of the optimized pivot plan +//! generated by `PivotTransformer`. 
For each group, `PivotFirst` maintains an array of +//! `pivot_values.len()` slots; on each input row it evaluates the pivot column, looks up its +//! index in `pivot_values`, and writes the value column into that slot when a match is found +//! and the value is non-null. Rows with unmatched pivot values are ignored; matched rows with +//! a null value column leave the slot unchanged (matches Spark). +//! +//! State layout is one column per pivot slot, matching Spark's `aggBufferAttributes` (which +//! declares `indexSize` `AttributeReference`s, one per pivot value). This keeps the shuffle +//! schema between Partial and Final consistent with what Spark catalyst declared; otherwise +//! the shuffle exchange rejects the batch. `evaluate()` reassembles the slots into a +//! `ListArray` matching `PivotFirst.dataType = ArrayType(value_type)`. + +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::utils::SingleRowListArrayBuilder; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature}; +use datafusion::physical_expr::expressions::format_state_name; +use std::collections::HashMap; +use std::sync::Arc; + +/// UDAF implementation of Spark's `PivotFirst`. +/// +/// `pivot_values` is a fixed, plan-time list of the pivot column values that occupy each +/// output slot; `pivot_index[v] = i` means an input row whose pivot column equals `v` writes +/// into slot `i`. Both the vector and the map are wrapped in `Arc` because `accumulator()` +/// fires once per group in a grouped aggregate and we want that path to bump a refcount +/// rather than deep-clone. 
+#[derive(Debug)] +pub struct SparkPivotFirst { + signature: Signature, + value_type: DataType, + // Kept for `PartialEq`/`Hash` (identity of the aggregate for plan comparison) and for the + // deterministic slot ordering `state_fields` needs. `HashMap` alone would give us the map + // but not a stable order or a `Hash` impl. + pivot_values: Arc>, + pivot_index: Arc>, +} + +impl PartialEq for SparkPivotFirst { + fn eq(&self, other: &Self) -> bool { + self.value_type == other.value_type && self.pivot_values == other.pivot_values + } +} + +impl Eq for SparkPivotFirst {} + +impl std::hash::Hash for SparkPivotFirst { + fn hash(&self, state: &mut H) { + self.value_type.hash(state); + self.pivot_values.hash(state); + } +} + +impl SparkPivotFirst { + pub fn new(value_type: DataType, pivot_values: Vec) -> Self { + let mut pivot_index = HashMap::with_capacity(pivot_values.len()); + // Spark's PivotFirst uses the FIRST occurrence's index (HashMap/TreeMap semantics), so + // when duplicates are somehow present we mirror that by only inserting the first one. + for (i, v) in pivot_values.iter().enumerate() { + pivot_index.entry(v.clone()).or_insert(i); + } + Self { + signature: Signature::user_defined(Immutable), + value_type, + pivot_values: Arc::new(pivot_values), + pivot_index: Arc::new(pivot_index), + } + } +} + +impl AggregateUDFImpl for SparkPivotFirst { + fn name(&self) -> &str { + "pivot_first" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::List(Arc::new(Field::new_list_field( + self.value_type.clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> DFResult> { + // One field per pivot slot, matching Spark's aggBufferAttributes so the shuffle + // exchange sees the same schema catalyst declared. `format_state_name` is the same + // helper other aggregates in this crate use (see `avg.rs`, `stddev.rs`). 
+ Ok((0..self.pivot_values.len()) + .map(|i| { + Arc::new(Field::new( + format_state_name(args.name, &i.to_string()), + self.value_type.clone(), + true, + )) + }) + .collect()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(PivotFirstAccumulator::new( + self.value_type.clone(), + Arc::clone(&self.pivot_index), + ))) + } +} + +/// Per-group state: `slots[i]` holds the latest non-null value assigned to pivot slot `i`, or +/// `None` when nothing has written to that slot yet. +#[derive(Debug)] +struct PivotFirstAccumulator { + value_type: DataType, + pivot_index: Arc>, + slots: Vec>, +} + +impl PivotFirstAccumulator { + fn new(value_type: DataType, pivot_index: Arc>) -> Self { + let slots = vec![None; pivot_index.len()]; + Self { + value_type, + pivot_index, + slots, + } + } + + /// Turn slot `i` into a `ScalarValue`, substituting a typed null when the slot is empty. + fn slot_or_null(&self, i: usize) -> DFResult { + Ok(match &self.slots[i] { + Some(v) => v.clone(), + None => ScalarValue::try_from(&self.value_type)?, + }) + } +} + +impl Accumulator for PivotFirstAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + if values.len() != 2 { + return Err(DataFusionError::Internal(format!( + "PivotFirst expects 2 inputs (pivot, value); got {}", + values.len() + ))); + } + let pivot_arr = &values[0]; + let value_arr = &values[1]; + if pivot_arr.len() != value_arr.len() { + return Err(DataFusionError::Internal( + "PivotFirst pivot and value arrays have different lengths".into(), + )); + } + for row in 0..pivot_arr.len() { + // Spark ignores the row entirely if either the pivot value is unmatched (index<0) + // or the value is null. Matching Spark exactly here is important because + // `PivotFirst.update` never writes for a null value, so a mid-batch null does not + // clobber an earlier non-null. 
+ let pivot_scalar = ScalarValue::try_from_array(pivot_arr, row)?; + if let Some(&slot_idx) = self.pivot_index.get(&pivot_scalar) { + if !value_arr.is_null(row) { + self.slots[slot_idx] = Some(ScalarValue::try_from_array(value_arr, row)?); + } + } + } + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + // Collapse the per-slot state into a single ScalarValue::List with typed nulls for + // empty slots. This matches Spark's `PivotFirst.dataType = ArrayType(value_type)`. + let flat = ScalarValue::iter_to_array(self.state()?)?; + Ok(SingleRowListArrayBuilder::new(flat).build_list_scalar()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.slots.capacity() * std::mem::size_of::>() + } + + fn state(&mut self) -> DFResult> { + // One ScalarValue per pivot slot; matches `state_fields`. + (0..self.slots.len()) + .map(|i| self.slot_or_null(i)) + .collect() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + if states.len() != self.slots.len() { + return Err(DataFusionError::Internal(format!( + "PivotFirst merge expects {} state columns; got {}", + self.slots.len(), + states.len() + ))); + } + if states.is_empty() { + return Ok(()); + } + // Each column is one slot; each row is one incoming partial state. Any non-null cell + // overwrites our current slot with "last write wins" (matches Spark's + // `PivotFirst.merge`). Walk backwards and break on the first non-null so we only build + // at most one `ScalarValue` per slot per merge, not one per row per slot. 
+ let n_rows = states[0].len(); + for (slot_idx, col) in states.iter().enumerate() { + for row in (0..n_rows).rev() { + if !col.is_null(row) { + self.slots[slot_idx] = Some(ScalarValue::try_from_array(col, row)?); + break; + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + + fn scalar(v: i32) -> ScalarValue { + ScalarValue::Int32(Some(v)) + } + fn s(v: &str) -> ScalarValue { + ScalarValue::Utf8(Some(v.to_string())) + } + + fn make_acc() -> PivotFirstAccumulator { + let pivot_values = [s("a"), s("b"), s("c")]; + let pivot_index: HashMap = pivot_values + .iter() + .enumerate() + .map(|(i, v)| (v.clone(), i)) + .collect(); + PivotFirstAccumulator::new(DataType::Int32, Arc::new(pivot_index)) + } + + #[test] + fn update_writes_by_pivot_index() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "c", "a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 10])); + acc.update_batch(&[pivots, values]).unwrap(); + // Slot 'a' = 10 (last non-null write wins for the same pivot value), slot 'b' = null, + // slot 'c' = 3. + assert_eq!(acc.slots, vec![Some(scalar(10)), None, Some(scalar(3))]); + } + + #[test] + fn update_ignores_unmatched_pivot() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "z", "b"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 99, 2])); + acc.update_batch(&[pivots, values]).unwrap(); + // 'z' is not in the pivot list - its row is ignored entirely. + assert_eq!(acc.slots, vec![Some(scalar(1)), Some(scalar(2)), None]); + } + + #[test] + fn update_ignores_null_value() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(5), None])); + acc.update_batch(&[pivots, values]).unwrap(); + // The second row's null must not clobber the first row's 5. 
+ assert_eq!(acc.slots, vec![Some(scalar(5)), None, None]); + } + + #[test] + fn evaluate_returns_list_with_nulls_for_empty_slots() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["c"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![7])); + acc.update_batch(&[pivots, values]).unwrap(); + + let result = acc.evaluate().unwrap(); + match result { + ScalarValue::List(list) => { + assert_eq!(list.len(), 1); + let inner = list.value(0); + let inner = inner.as_any().downcast_ref::().unwrap(); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + assert_eq!(inner.value(2), 7); + } + other => panic!("expected ScalarValue::List, got {other:?}"), + } + } + + #[test] + fn state_returns_one_scalar_per_slot() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a", "c"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 3])); + acc.update_batch(&[pivots, values]).unwrap(); + + let state = acc.state().unwrap(); + assert_eq!(state.len(), 3); + assert_eq!(state[0], scalar(1)); + assert!(matches!(state[1], ScalarValue::Int32(None))); + assert_eq!(state[2], scalar(3)); + } + + #[test] + fn merge_pastes_non_null_slots() { + let mut acc = make_acc(); + let pivots: ArrayRef = Arc::new(StringArray::from(vec!["a"])); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1])); + acc.update_batch(&[pivots, values]).unwrap(); + + // Build the other side's state and merge - slot 'a' stays 1, 'b' becomes 2, 'c' 3. + let mut other = make_acc(); + let pivots2: ArrayRef = Arc::new(StringArray::from(vec!["b", "c"])); + let values2: ArrayRef = Arc::new(Int32Array::from(vec![2, 3])); + other.update_batch(&[pivots2, values2]).unwrap(); + let state = other.state().unwrap(); + // Feed each state scalar as a single-row column, mirroring what the shuffle produces. 
+ let state_arrays: Vec = state + .into_iter() + .map(|sv| sv.to_array_of_size(1).unwrap()) + .collect(); + acc.merge_batch(&state_arrays).unwrap(); + + assert_eq!( + acc.slots, + vec![Some(scalar(1)), Some(scalar(2)), Some(scalar(3))] + ); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..07d461b084 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -400,6 +400,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Max] -> CometMax, classOf[Min] -> CometMin, classOf[Percentile] -> CometPercentile, + classOf[PivotFirst] -> CometPivotFirst, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, classOf[Sum] -> CometSum, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 5710232cb4..c71b1dada9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -22,9 +22,9 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, Percentile, PivotFirst, StddevPop, StddevSamp, Sum, 
VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -823,6 +823,69 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] { } } +object CometPivotFirst extends CometAggregateExpressionSerde[PivotFirst] { + + // Fallback message for a value type rejected by the PivotFirst.supportsDataType gate that + // getSupportLevel delegates to, so Comet's list cannot drift from Spark's fast-path gate. + private def unsupportedValueTypeReason(dt: DataType): String = + s"Unsupported value data type: $dt" + + private val emptyPivotValuesReason = "Pivot values list is empty" + + override def getUnsupportedReasons(): Seq[String] = Seq( + "Value data type outside PivotFirst.supportsDataType " + + "(Boolean, Byte, Short, Int, Long, Float, Double, Decimal)", + emptyPivotValuesReason) + + override def getSupportLevel(expr: PivotFirst): SupportLevel = { + if (!PivotFirst.supportsDataType(expr.valueDataType)) { + Unsupported(Some(unsupportedValueTypeReason(expr.valueDataType))) + } else if (expr.pivotColumnValues.isEmpty) { + Unsupported(Some(emptyPivotValuesReason)) + } else { + Compatible() + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: PivotFirst, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val pivotColExpr = exprToProto(expr.pivotColumn, inputs, binding) + val valueColExpr = exprToProto(expr.valueColumn, inputs, binding) + val valueDt = serializeDataType(expr.valueDataType) + val pivotDt = expr.pivotColumn.dataType + + // Spark stores the pivot values as 
already-evaluated raw Scala values. Rebuild them as + // Literal expressions carrying the pivot column data type so the native side receives + // the exact bytes it will later compare against. Preserve list order - the position in + // this list is the output array index. + val pivotValueExprs = + expr.pivotColumnValues.map(v => exprToProto(Literal(v, pivotDt), inputs, binding)) + + if (pivotColExpr.isDefined && valueColExpr.isDefined && valueDt.isDefined && + pivotValueExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.PivotFirst.newBuilder() + builder.setPivotColumn(pivotColExpr.get) + builder.setValueColumn(valueColExpr.get) + builder.setValueDatatype(valueDt.get) + pivotValueExprs.foreach(v => builder.addPivotValues(v.get)) + Some(ExprOuterClass.AggExpr.newBuilder().setPivotFirst(builder).build()) + } else if (valueDt.isEmpty) { + withFallbackReason( + aggExpr, + unsupportedValueTypeReason(expr.valueDataType), + expr.valueColumn) + None + } else { + withFallbackReason(aggExpr, expr.pivotColumn, expr.valueColumn) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql new file mode 100644 index 0000000000..bf9c45279a --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/PivotFirst.sql @@ -0,0 +1,321 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- PivotFirst is Spark's internal aggregate emitted by the two-phase optimized pivot plan +-- (PivotTransformer). It runs whenever every aggregate output type is in +-- PivotFirst.supportsDataType (Boolean, Byte, Short, Int, Long, Float, Double, Decimal). +-- These SQL tests exercise that path via SQL PIVOT clauses, one per supported value type +-- plus a handful of edge cases (unmatched pivot values, all-null groups, null pivot column). +-- They intentionally use `spark_answer_only` because the pivot plan's final projection is a +-- Spark expression (ExtractValue on the aggregate output) so the whole query does not always +-- run natively even when PivotFirst itself does. + +-- ============================================================ +-- Setup: sales-like table used for most cases +-- ============================================================ + +statement +CREATE TABLE pf_sales (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales VALUES + (2012, 'dotNET', 15000), + (2012, 'Java', 20000), + (2013, 'dotNET', 48000), + (2013, 'Java', 30000) + +-- ============================================================ +-- Int input: canonical optimized-pivot query. sum(earnings) on an +-- int column widens to bigint, so the PivotFirst value type here is +-- Long; the plan is chosen by value type, not by pivot-value count. 
+-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Long values: sum of a bigint aggregates to bigint. +-- ============================================================ + +statement +CREATE TABLE pf_sales_bi (year int, course string, earnings bigint) USING parquet + +statement +INSERT INTO pf_sales_bi VALUES + (2012, 'dotNET', 15000000000), + (2012, 'Java', 20000000000), + (2013, 'dotNET', 48000000000), + (2013, 'Java', 30000000000) + +query spark_answer_only +SELECT * FROM pf_sales_bi + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Double values: avg(earnings) is double. +-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (avg(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Float values: min over a float column keeps the output type float. +-- ============================================================ + +statement +CREATE TABLE pf_sales_f (year int, course string, earnings float) USING parquet + +statement +INSERT INTO pf_sales_f VALUES + (2012, 'dotNET', CAST(15000 AS FLOAT)), + (2012, 'Java', CAST(20000 AS FLOAT)), + (2013, 'dotNET', CAST(48000 AS FLOAT)), + (2013, 'Java', CAST(30000 AS FLOAT)) + +query spark_answer_only +SELECT * FROM pf_sales_f + PIVOT (min(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Short and Byte value types (sum promotes them to bigint, so use min +-- to keep the aggregate output type equal to the input type). 
+-- ============================================================ + +statement +CREATE TABLE pf_sales_small ( + year int, course string, s smallint, b tinyint) USING parquet + +statement +INSERT INTO pf_sales_small VALUES + (2012, 'dotNET', 150, 10), + (2012, 'Java', 200, 20), + (2013, 'dotNET', 480, 30), + (2013, 'Java', 300, 40) + +query spark_answer_only +SELECT * FROM pf_sales_small + PIVOT (min(s) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +query spark_answer_only +SELECT * FROM pf_sales_small + PIVOT (min(b) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Boolean values: bool_or's output type is boolean. +-- ============================================================ + +statement +CREATE TABLE pf_flags (year int, course string, flag boolean) USING parquet + +statement +INSERT INTO pf_flags VALUES + (2012, 'dotNET', true), + (2012, 'Java', false), + (2013, 'dotNET', false), + (2013, 'Java', true) + +query spark_answer_only +SELECT * FROM pf_flags + PIVOT (bool_or(flag) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Decimal values. +-- ============================================================ + +statement +CREATE TABLE pf_sales_dec (year int, course string, earnings decimal(10,2)) USING parquet + +statement +INSERT INTO pf_sales_dec VALUES + (2012, 'dotNET', 15000.00), + (2012, 'Java', 20000.00), + (2013, 'dotNET', 48000.00), + (2013, 'Java', 30000.00) + +query spark_answer_only +SELECT * FROM pf_sales_dec + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Multiple aggregates in a single pivot. 
+-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) AS s, avg(earnings) AS a + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Pivot values that do not appear in the data: the corresponding +-- output slots must be NULL. +-- ============================================================ + +query spark_answer_only +SELECT * FROM pf_sales + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java', 'MissingA', 'MissingB')) + ORDER BY year + +-- ============================================================ +-- Rows whose pivot column is not in the pivot list must be ignored: +-- 'Python' below should not contribute to any output slot. +-- ============================================================ + +statement +CREATE TABLE pf_sales_extra (year int, course string, earnings int) USING parquet + +statement +INSERT INTO pf_sales_extra VALUES + (2012, 'dotNET', 15000), + (2012, 'Java', 20000), + (2012, 'Python', 12345), + (2013, 'dotNET', 48000), + (2013, 'Java', 30000), + (2013, 'Python', 67890) + +query spark_answer_only +SELECT * FROM pf_sales_extra + PIVOT (sum(earnings) + FOR course IN ('dotNET', 'Java')) + ORDER BY year + +-- ============================================================ +-- Group with no matching rows: every output slot is NULL for that +-- pivot combination. Ensures NULL initialization matches Spark. 
+-- ============================================================
+
+statement
+CREATE TABLE pf_sales_partial (year int, course string, earnings int) USING parquet
+
+statement
+INSERT INTO pf_sales_partial VALUES
+  (2012, 'dotNET', 15000),
+  (2012, 'Java', 20000),
+  (2013, 'Python', 42000)
+
+query spark_answer_only
+SELECT * FROM pf_sales_partial
+  PIVOT (sum(earnings)
+    FOR course IN ('dotNET', 'Java'))
+  ORDER BY year
+
+-- ============================================================
+-- NULL values in the aggregated column: 2012/dotNET mixes 15000 with NULL and 2012/Java is
+-- all-NULL; a null must never overwrite a written slot (Spark PivotFirst.update ignores nulls).
+-- ============================================================
+
+statement
+CREATE TABLE pf_sales_nullable (year int, course string, earnings int) USING parquet
+
+statement
+INSERT INTO pf_sales_nullable VALUES
+  (2012, 'dotNET', 15000),
+  (2012, 'dotNET', NULL),
+  (2012, 'Java', NULL),
+  (2013, 'dotNET', 48000),
+  (2013, 'Java', 30000)
+
+query spark_answer_only
+SELECT * FROM pf_sales_nullable
+  PIVOT (sum(earnings)
+    FOR course IN ('dotNET', 'Java'))
+  ORDER BY year
+
+-- ============================================================
+-- NULL in the pivot COLUMN. Spark tracks this with a dedicated
+-- test ("pivot with null should not throw NPE") because early
+-- versions of PivotFirst NPE'd on null keys with a TreeMap.
+-- Comet uses a HashMap where ScalarValue::Null hashes safely,
+-- so this must produce the same rows as Spark.
+-- ============================================================
+
+statement
+CREATE TABLE pf_sales_nullpivot (year int, course string, earnings int) USING parquet
+
+statement
+INSERT INTO pf_sales_nullpivot VALUES
+  (2012, 'dotNET', 15000),
+  (2012, NULL, 9999),
+  (2013, 'Java', 30000),
+  (2013, NULL, 42)
+
+query spark_answer_only
+SELECT * FROM pf_sales_nullpivot
+  PIVOT (sum(earnings)
+    FOR course IN ('dotNET', 'Java'))
+  ORDER BY year
+
+-- ============================================================
+-- Non-string pivot column: integer (`year`). Exercises the pivot
+-- value serialization path with numeric literals.
+-- ============================================================
+
+statement
+CREATE TABLE pf_sales_intpivot (region string, year int, earnings int) USING parquet
+
+statement
+INSERT INTO pf_sales_intpivot VALUES
+  ('NA', 2012, 15000),
+  ('NA', 2013, 48000),
+  ('EU', 2012, 20000),
+  ('EU', 2013, 30000)
+
+query spark_answer_only
+SELECT * FROM pf_sales_intpivot
+  PIVOT (sum(earnings)
+    FOR year IN (2012, 2013))
+  ORDER BY region
+
+-- ============================================================
+-- Non-string pivot column: date. Exercises the pivot value
+-- serialization path with a DateType literal.
+-- ============================================================
+
+statement
+CREATE TABLE pf_sales_datepivot (region string, d date, earnings int) USING parquet
+
+statement
+INSERT INTO pf_sales_datepivot VALUES
+  ('NA', DATE '2024-01-01', 15000),
+  ('NA', DATE '2024-06-15', 48000),
+  ('EU', DATE '2024-01-01', 20000),
+  ('EU', DATE '2024-06-15', 30000)
+
+query spark_answer_only
+SELECT * FROM pf_sales_datepivot
+  PIVOT (sum(earnings)
+    FOR d IN (DATE '2024-01-01', DATE '2024-06-15'))
+  ORDER BY region
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index ae14c68207..824d7e8f31 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -2150,4 +2150,57 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  // Optimized-pivot fast path: Spark's PivotTransformer emits a two-phase aggregate whose
+  // second phase uses PivotFirst whenever every aggregate output type is in
+  // PivotFirst.supportsDataType. We assert:
+  //   1. The physical plan actually contains a PivotFirst (Spark still picked the fast path).
+  //   2. Comet converted the second-phase aggregate to CometHashAggregateExec (i.e. our
+  //      CometPivotFirst serde accepted the aggregate).
+  //   3. Results match Spark exactly.
+  // Each check is its own assertion, so a failure surfaces the specific breakage.
+  test("PivotFirst runs natively on the optimized pivot fast path") {
+    // Scoped import: keeps the plan matchers below readable without touching file-level imports.
+    import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "sales")
+      spark
+        .sql("""SELECT * FROM VALUES
+            | (2012, 'dotNET', 15000),
+            | (2012, 'Java', 20000),
+            | (2013, 'dotNET', 48000),
+            | (2013, 'Java', 30000)
+            |AS t(year, course, earnings)""".stripMargin)
+        .write
+        .parquet(path.toUri.toString)
+      withParquetTable(path.toUri.toString, "sales") {
+        val pivotSql =
+          """SELECT * FROM sales
+            | PIVOT (sum(earnings) FOR course IN ('dotNET', 'Java'))
+            | ORDER BY year""".stripMargin
+        val df = spark.sql(pivotSql)
+        val plan = df.queryExecution.executedPlan
+        // Check 1: any physical operator (Spark or Comet) that carries a PivotFirst expression
+        // shows the fast path was planned, regardless of whether Comet converted it.
+        val pivotFirstNodes = collectWithSubqueries(plan) {
+          case node if node.expressions.exists(_.find(_.isInstanceOf[PivotFirst]).isDefined) =>
+            node
+        }
+        assert(
+          pivotFirstNodes.nonEmpty,
+          s"Expected Spark to plan PivotFirst (optimized pivot fast path) in:\n$plan")
+        // Check 2: at least one of those operators must be Comet's native hash aggregate.
+        val pivotFirstAggs = pivotFirstNodes.collect {
+          case a: CometHashAggregateExec
+              if a.aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[PivotFirst]) =>
+            a
+        }
+        assert(
+          pivotFirstAggs.nonEmpty,
+          s"Expected a CometHashAggregateExec containing PivotFirst in:\n$plan")
+        // Check 3: the result must match vanilla Spark.
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
 }