From 0782b09031686ba386be82270249ca310d2227e5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 09:19:29 -0600 Subject: [PATCH 1/2] feat: support listagg / string_agg aggregate (Spark 4.0+) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `SparkListAgg` UDAF and a `CometListAgg` serde so that Comet can natively execute the simple form of Spark 4.0's `LISTAGG(child, delimiter)` / `string_agg` on `StringType` inputs with a literal delimiter. `WITHIN GROUP (ORDER BY ...)`, `BinaryType` inputs, non-literal delimiters, and non-default collations fall back to Spark. `DISTINCT` falls back because Comet already rejects multi-column distinct aggregates. The native accumulator returns `Utf8` but keeps its intermediate state as `Binary` to match Spark's `TypedImperativeAggregate` buffer schema so the Comet shuffle layer does not insert a `Utf8` β†’ `Binary` cast the merge side cannot read back. Scaffolding produced by the `implement-comet-expression` skill. --- .../expression-audits/agg_funcs.md | 8 + docs/source/user-guide/latest/expressions.md | 4 +- native/core/src/execution/planner.rs | 9 +- native/proto/src/proto/expr.proto | 14 + native/spark-expr/src/agg_funcs/list_agg.rs | 284 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 45 +-- .../apache/comet/shims/CometExprShim.scala | 4 +- .../apache/comet/shims/CometExprShim.scala | 4 +- .../org/apache/comet/serde/CometListAgg.scala | 100 ++++++ .../comet/shims/Spark4xCometExprShim.scala | 5 +- .../expressions/aggregate/listagg.sql | 157 ++++++++++ 12 files changed, 609 insertions(+), 27 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/list_agg.rs create mode 100644 spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala create mode 100644 spark/src/test/resources/sql-tests/expressions/aggregate/listagg.sql diff --git a/docs/source/contributor-guide/expression-audits/agg_funcs.md b/docs/source/contributor-guide/expression-audits/agg_funcs.md index fb27662b49..88c47ae625 100644 --- a/docs/source/contributor-guide/expression-audits/agg_funcs.md +++ b/docs/source/contributor-guide/expression-audits/agg_funcs.md @@ -39,6 +39,14 @@ - Spark 3.5.8 (2026-05-26) - Spark 4.0.1 (2026-05-26) +## listagg + +- Spark 3.4.3 (audited 2026-07-03): does not exist. `ListAgg` was added in Spark 4.0. +- Spark 3.5.8 (audited 2026-07-03): does not exist. +- Spark 4.0.1 (audited 2026-07-03): `ListAgg(child, delimiter, orderExpressions)` in `aggregate/collect.scala`. Accepts `StringType` or `BinaryType` inputs; result type matches child. Skips nulls; empty or all-null groups return `NULL`. A `NULL` delimiter is treated as an empty string. `CometListAgg` maps only the simple form: `StringType` child with a literal `StringType`/`NullType` delimiter and no `WITHIN GROUP`. `BinaryType` inputs, `WITHIN GROUP (ORDER BY ...)`, non-literal delimiters, and non-default collations fall back to Spark. `DISTINCT` falls back because Comet rejects multi-column distinct aggregates (`ListAgg` has two children). +- Spark 4.1.1 (audited 2026-07-03): byte-identical to 4.0.1. +- Native accumulator (`SparkListAgg`) returns `Utf8` but keeps its intermediate state as `Binary`, matching Spark's `TypedImperativeAggregate` buffer schema so the Comet shuffle layer does not insert a `Utf8` β†’ `Binary` cast the merge side cannot read back. + ## median - Spark 3.4.3 (audited 2026-06-24): `Median(child)` is a `RuntimeReplaceableAggregate` with `replacement = Percentile(child, Literal(0.5))`. Catalyst rewrites `median(x)` to `percentile(x, 0.5)` before Comet sees the plan, so it is served by `CometPercentile`. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index d7c86aad1f..54ab9ce824 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -93,7 +93,7 @@ The tables below list every Spark built-in expression with its current status. | `kurtosis` | πŸ”œ | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) | | `last` | βœ… | | | `last_value` | βœ… | | -| `listagg` | πŸ”œ | String aggregation | +| `listagg` | βœ… | Spark 4.0+. `StringType` input with a literal delimiter; `WITHIN GROUP (ORDER BY ...)` and `BinaryType` inputs fall back to Spark. | | `max` | βœ… | | | `max_by` | πŸ”œ | [#3841](https://github.com/apache/datafusion-comet/issues/3841) | | `mean` | βœ… | | @@ -119,7 +119,7 @@ The tables below list every Spark built-in expression with its current status. | `stddev` | βœ… | | | `stddev_pop` | βœ… | | | `stddev_samp` | βœ… | | -| `string_agg` | πŸ”œ | String aggregation (alias of `listagg`) | +| `string_agg` | βœ… | Alias of `listagg`; same restrictions apply. | | `sum` | βœ… | | | `try_avg` | βœ… | Interval types fall back | | `try_sum` | βœ… | | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..56f0198920 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -74,7 +74,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + SparkBloomFilterVersion, SparkListAgg, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -2653,6 +2653,13 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::ListAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let delimiter = + self.create_expr(expr.delimiter.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(SparkListAgg::new()); + Self::create_aggr_func_expr("listagg", schema, vec![child, delimiter], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 32adc16b72..833f3cdc57 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -146,6 +146,7 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + ListAgg listAgg = 19; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -277,6 +278,19 @@ message CollectSet { DataType datatype = 2; } +// Spark 4.0+ LISTAGG / STRING_AGG aggregate. +// +// Comet only serializes the simple form: a StringType child with a literal +// (or NULL) delimiter and no WITHIN GROUP ORDER BY. DISTINCT is handled by +// Spark's multi-stage plan rewrite before the aggregate reaches Comet. +message ListAgg { + Expr child = 1; + // Literal delimiter expression. NULL delimiter is normalized to empty string + // by Spark's semantics. + Expr delimiter = 2; + DataType datatype = 3; +} + enum EvalMode { LEGACY = 0; TRY = 1; diff --git a/native/spark-expr/src/agg_funcs/list_agg.rs b/native/spark-expr/src/agg_funcs/list_agg.rs new file mode 100644 index 0000000000..1bb7411052 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/list_agg.rs @@ -0,0 +1,284 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible `listagg` / `string_agg` aggregate function. +//! +//! Implements the simple form of Spark 4.0's `LISTAGG(expr, delimiter)` (no +//! `WITHIN GROUP (ORDER BY ...)`, no DISTINCT β€” DISTINCT is rewritten into a +//! multi-stage plan by Spark before it reaches Comet). Differences from +//! DataFusion's `string_agg`: +//! +//! * Returns `Utf8` to match Spark's `StringType` result type; DataFusion's +//! `string_agg` returns `LargeUtf8`. +//! * A `NULL` delimiter is treated as the empty string (Spark treats `NULL` as +//! the default empty delimiter; the JVM serde forwards the literal as-is). +//! * The delimiter is read once from the accumulator args (a literal is +//! enforced by Spark's analyzer). +//! +//! The intermediate state is exposed as `Binary` because Spark's `ListAgg` is +//! a `TypedImperativeAggregate` whose Catalyst buffer schema is `BinaryType`. +//! Emitting `Utf8` here would force a Comet shuffle-side cast (`Utf8` β†’ +//! `Binary`) that the merge side then can no longer read. + +use std::hash::Hash; +use std::mem::size_of_val; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::cast::{as_binary_array, as_string_array}; +use datafusion::common::{internal_datafusion_err, not_impl_err, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::utils::format_state_name; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion::physical_expr::expressions::Literal; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkListAgg { + signature: Signature, +} + +impl Default for SparkListAgg { + fn default() -> Self { + Self::new() + } +} + +impl SparkListAgg { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), + ], + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for SparkListAgg { + fn name(&self) -> &str { + "listagg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + // Spark's ListAgg is a TypedImperativeAggregate β€” Catalyst declares its + // intermediate buffer as `BinaryType`. Match that so Comet's shuffle + // layer doesn't have to insert a Utf8 -> Binary cast that the merge + // side then can't read back. + Ok(vec![Field::new( + format_state_name(args.name, "listagg"), + DataType::Binary, + true, + ) + .into()]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let Some(lit) = (*acc_args.exprs[1]).downcast_ref::() else { + return not_impl_err!( + "listagg delimiter must be a literal; got {:?}", + acc_args.exprs[1] + ); + }; + let delimiter = if lit.value().is_null() { + String::new() + } else if let Some(s) = lit.value().try_as_str() { + s.unwrap_or("").to_string() + } else { + return not_impl_err!( + "listagg delimiter literal must be Utf8; got {:?}", + lit.value() + ); + }; + Ok(Box::new(ListAggAccumulator::new(delimiter))) + } +} + +#[derive(Debug)] +struct ListAggAccumulator { + delimiter: String, + accumulated: String, + has_value: bool, +} + +impl ListAggAccumulator { + fn new(delimiter: String) -> Self { + Self { + delimiter, + accumulated: String::new(), + has_value: false, + } + } + + #[inline] + fn append_values(&mut self, array: &StringArray) { + for value in array.iter().flatten() { + if self.has_value { + self.accumulated.push_str(&self.delimiter); + } + self.accumulated.push_str(value); + self.has_value = true; + } + } +} + +impl Accumulator for ListAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = values.first().ok_or_else(|| { + internal_datafusion_err!("listagg update_batch expected the values array") + })?; + self.append_values(as_string_array(array)?); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states.first().ok_or_else(|| { + internal_datafusion_err!("listagg merge_batch expected the state array") + })?; + // Partial state is emitted as `Binary` (see `state_fields`); each + // entry is UTF-8 bytes originally produced by another partition's + // accumulator. + let bin = as_binary_array(array)?; + for value in bin.iter().flatten() { + let s = std::str::from_utf8(value).map_err(|e| { + internal_datafusion_err!("listagg merge_batch got non-UTF-8 partial state: {e}") + })?; + if self.has_value { + self.accumulated.push_str(&self.delimiter); + } + self.accumulated.push_str(s); + self.has_value = true; + } + Ok(()) + } + + fn state(&mut self) -> Result> { + let value = if self.has_value { + ScalarValue::Binary(Some(std::mem::take(&mut self.accumulated).into_bytes())) + } else { + ScalarValue::Binary(None) + }; + self.has_value = false; + Ok(vec![value]) + } + + fn evaluate(&mut self) -> Result { + if self.has_value { + Ok(ScalarValue::Utf8(Some(self.accumulated.clone()))) + } else { + Ok(ScalarValue::Utf8(None)) + } + } + + fn size(&self) -> usize { + size_of_val(self) + self.delimiter.capacity() + self.accumulated.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{BinaryArray, StringArray}; + use std::sync::Arc; + + fn utf8(items: &[Option<&str>]) -> ArrayRef { + Arc::new(StringArray::from(items.to_vec())) + } + + fn some(items: &[&str]) -> ArrayRef { + Arc::new(StringArray::from(items.to_vec())) + } + + #[test] + fn joins_non_null_values_with_delimiter() -> Result<()> { + let mut acc = ListAggAccumulator::new(",".to_string()); + acc.update_batch(&[some(&["a", "b", "c"])])?; + let ScalarValue::Utf8(Some(s)) = acc.evaluate()? else { + panic!("expected Utf8"); + }; + assert_eq!(s, "a,b,c"); + Ok(()) + } + + #[test] + fn empty_delimiter_concatenates() -> Result<()> { + let mut acc = ListAggAccumulator::new(String::new()); + acc.update_batch(&[some(&["a", "b", "c"])])?; + let ScalarValue::Utf8(Some(s)) = acc.evaluate()? else { + panic!("expected Utf8"); + }; + assert_eq!(s, "abc"); + Ok(()) + } + + #[test] + fn skips_null_inputs() -> Result<()> { + let mut acc = ListAggAccumulator::new(",".to_string()); + acc.update_batch(&[utf8(&[Some("a"), None, Some("b"), None])])?; + let ScalarValue::Utf8(Some(s)) = acc.evaluate()? else { + panic!("expected Utf8"); + }; + assert_eq!(s, "a,b"); + Ok(()) + } + + #[test] + fn returns_null_on_all_null_or_empty_input() -> Result<()> { + let mut acc = ListAggAccumulator::new(",".to_string()); + acc.update_batch(&[utf8(&[None, None])])?; + assert!(matches!(acc.evaluate()?, ScalarValue::Utf8(None))); + + let mut empty = ListAggAccumulator::new(",".to_string()); + assert!(matches!(empty.evaluate()?, ScalarValue::Utf8(None))); + Ok(()) + } + + #[test] + fn merge_state_across_partitions() -> Result<()> { + let mut a = ListAggAccumulator::new(",".to_string()); + a.update_batch(&[some(&["a", "b"])])?; + let state_bytes = match a.state()?.remove(0) { + ScalarValue::Binary(Some(b)) => b, + other => panic!("unexpected state {other:?}"), + }; + + let mut b = ListAggAccumulator::new(",".to_string()); + b.update_batch(&[some(&["c", "d"])])?; + let partial_state: ArrayRef = + Arc::new(BinaryArray::from(vec![Some(state_bytes.as_slice())])); + b.merge_batch(&[partial_state])?; + + let ScalarValue::Utf8(Some(s)) = b.evaluate()? else { + panic!("expected Utf8"); + }; + // partition A's already-joined "a,b" is appended as one value. + assert_eq!(s, "c,d,a,b"); + Ok(()) + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..d1d91faee4 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod list_agg; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use list_agg::SparkListAgg; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7146eaec9b..34c6ae350e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -385,27 +385,30 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { /** * Mapping of Spark aggregate expression class to Comet expression handler. */ - val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( - classOf[Average] -> CometAverage, - classOf[BitAndAgg] -> CometBitAndAgg, - classOf[BitOrAgg] -> CometBitOrAgg, - classOf[BitXorAgg] -> CometBitXOrAgg, - classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, - classOf[CollectSet] -> CometCollectSet, - classOf[Corr] -> CometCorr, - classOf[Count] -> CometCount, - classOf[CovPopulation] -> CometCovPopulation, - classOf[CovSample] -> CometCovSample, - classOf[First] -> CometFirst, - classOf[Last] -> CometLast, - classOf[Max] -> CometMax, - classOf[Min] -> CometMin, - classOf[Percentile] -> CometPercentile, - classOf[StddevPop] -> CometStddevPop, - classOf[StddevSamp] -> CometStddevSamp, - classOf[Sum] -> CometSum, - classOf[VariancePop] -> CometVariancePop, - classOf[VarianceSamp] -> CometVarianceSamp) + val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = { + val base: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( + classOf[Average] -> CometAverage, + classOf[BitAndAgg] -> CometBitAndAgg, + classOf[BitOrAgg] -> CometBitOrAgg, + classOf[BitXorAgg] -> CometBitXOrAgg, + classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, + classOf[CollectSet] -> CometCollectSet, + classOf[Corr] -> CometCorr, + classOf[Count] -> CometCount, + classOf[CovPopulation] -> CometCovPopulation, + classOf[CovSample] -> CometCovSample, + classOf[First] -> CometFirst, + classOf[Last] -> CometLast, + classOf[Max] -> CometMax, + classOf[Min] -> CometMin, + classOf[Percentile] -> CometPercentile, + classOf[StddevPop] -> CometStddevPop, + classOf[StddevSamp] -> CometStddevSamp, + classOf[Sum] -> CometSum, + classOf[VariancePop] -> CometVariancePop, + classOf[VarianceSamp] -> CometVarianceSamp) + base ++ sparkVersionSpecificAggregates + } /** * Returns true if all aggregate expressions in the list have intermediate buffer formats that diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 1ad9ec75bf..43837f266f 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map.empty def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index 0be1185f59..42290ab565 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map(classOf[ToPrettyString] -> CometToPrettyString) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala new file mode 100644 index 0000000000..49b0f2a2ac --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ListAgg} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{NullType, StringType} + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.QueryPlanSerde.{exprToProto, hasNonDefaultStringCollation, serializeDataType} + +/** + * Spark 4.0+ `LISTAGG` / `STRING_AGG`. + * + * Comet only supports the simple form: a `StringType` child with a literal delimiter and no + * `WITHIN GROUP (ORDER BY ...)` clause. DISTINCT is handled by Spark's multi-stage plan rewrite + * (grouping by the child before the aggregate), so the native side never sees it. + */ +object CometListAgg extends CometAggregateExpressionSerde[ListAgg] { + + override def getUnsupportedReasons(): Seq[String] = Seq( + "`BinaryType` inputs are not supported.", + "`WITHIN GROUP (ORDER BY ...)` is not supported.", + "Non-literal delimiters are not supported.", + "Non-default string collations are not supported.", + "`DISTINCT` falls back to Spark because Comet rejects multi-column distinct aggregates.") + + override def getSupportLevel(expr: ListAgg): SupportLevel = { + // Spark enforces `delimiter.foldable` at analysis time, so a non-literal delimiter would + // fail before reaching us. Match only the two shapes we actually handle. + if (!expr.child.dataType.isInstanceOf[StringType]) { + return Unsupported(Some(s"Unsupported child data type: ${expr.child.dataType}")) + } + if (hasNonDefaultStringCollation(expr.child.dataType)) { + return Unsupported(Some("Non-default string collations are not supported")) + } + expr.delimiter.dataType match { + case _: StringType if hasNonDefaultStringCollation(expr.delimiter.dataType) => + return Unsupported(Some("Non-default string collations on delimiter are not supported")) + case _: StringType | _: NullType => // ok + case other => + return Unsupported(Some(s"Unsupported delimiter data type: $other")) + } + expr.delimiter match { + case _: Literal => // ok + case _ => return Unsupported(Some("Non-literal delimiters are not supported")) + } + if (expr.orderExpressions.nonEmpty) { + return Unsupported(Some("`WITHIN GROUP (ORDER BY ...)` is not supported")) + } + Compatible() + } + + override def convert( + aggExpr: AggregateExpression, + expr: ListAgg, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val childExpr = exprToProto(expr.child, inputs, binding) + val delimiterExpr = exprToProto(expr.delimiter, inputs, binding) + val dataType = serializeDataType(expr.dataType) + + if (childExpr.isDefined && delimiterExpr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.ListAgg.newBuilder() + builder.setChild(childExpr.get) + builder.setDelimiter(delimiterExpr.get) + builder.setDatatype(dataType.get) + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setListAgg(builder) + .build()) + } else if (dataType.isEmpty) { + withFallbackReason(aggExpr, s"datatype ${expr.dataType} is not supported", expr.child) + None + } else { + withFallbackReason(aggExpr, expr.child, expr.delimiter) + None + } + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 7efd17f68a..a7d13090ee 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ListAgg import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometListAgg, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -49,6 +50,8 @@ trait Spark4xCometExprShim extends CometExprShim4x { Map(classOf[ToPrettyString] -> CometToPrettyString) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map(classOf[ListAgg] -> CometListAgg) def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/listagg.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/listagg.sql new file mode 100644 index 0000000000..37bc43ee89 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/listagg.sql @@ -0,0 +1,157 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- listagg / string_agg is available starting in Spark 4.0. +-- MinSparkVersion: 4.0 +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- ============================================================ +-- Setup +-- ============================================================ + +statement +CREATE TABLE la_src(v string, grp string) USING parquet + +statement +INSERT INTO la_src VALUES + ('a', 'g1'), ('b', 'g1'), ('c', 'g1'), + ('x', 'g2'), (NULL, 'g2'), ('y', 'g2'), + (NULL, 'g3'), (NULL, 'g3'), + ('', 'g4'), ('z', 'g4') + +statement +CREATE TABLE la_empty(v string) USING parquet + +statement +CREATE TABLE la_dupes(v string, grp string) USING parquet + +statement +INSERT INTO la_dupes VALUES + ('a', 'g1'), ('a', 'g1'), ('b', 'g1'), + ('c', 'g2'), ('c', 'g2'), ('c', 'g2') + +statement +CREATE TABLE la_utf8(v string, grp string) USING parquet + +-- Multibyte UTF-8: `cafΓ©` (composed Γ© U+00E9), `naΓ―ve` (composed Γ― U+00EF), +-- `ζ—₯本θͺž` (three CJK codepoints), and an emoji sequence. +statement +INSERT INTO la_utf8 VALUES + ('cafΓ©', 'g1'), ('naΓ―ve', 'g1'), ('ζ—₯本θͺž', 'g1'), + ('ν•œκΈ€', 'g2'), ('β˜•οΈ', 'g2'), (NULL, 'g2') + +statement +CREATE TABLE la_bin(v binary, grp string) USING parquet + +statement +INSERT INTO la_bin VALUES + (X'DEAD', 'g1'), (X'BEEF', 'g1'), (NULL, 'g1'), + (X'CAFE', 'g2') + +-- ============================================================ +-- Basic: literal delimiter, sort the group so results are +-- deterministic across shuffles. +-- ============================================================ + +query +SELECT grp, listagg(v, ',') FROM (SELECT * FROM la_src ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- Alias `string_agg` should route through the same expression class. +query +SELECT grp, string_agg(v, ',') FROM (SELECT * FROM la_src ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- ============================================================ +-- No delimiter (defaults to empty string / NULL delimiter) +-- ============================================================ + +query +SELECT grp, listagg(v) FROM (SELECT * FROM la_src ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- ============================================================ +-- All-NULL group returns NULL; empty table returns NULL. +-- ============================================================ + +query +SELECT listagg(v, ',') FROM la_empty + +query +SELECT grp, listagg(v, '-') FROM (SELECT * FROM la_src WHERE grp = 'g3' ORDER BY v) GROUP BY grp + +-- ============================================================ +-- DISTINCT falls back to Spark: Comet rejects multi-column +-- distinct aggregates (listagg has two children: value and +-- delimiter), so `listagg(DISTINCT v, ',')` is not run natively. +-- ============================================================ + +query expect_fallback(Multi-column distinct aggregate not supported) +SELECT grp, listagg(DISTINCT v, ',') FROM (SELECT * FROM la_dupes ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- ============================================================ +-- Global aggregate (no GROUP BY) +-- ============================================================ + +query +SELECT listagg(v, '|') FROM (SELECT * FROM la_src WHERE grp = 'g1' ORDER BY v) + +-- ============================================================ +-- Multi-char delimiter and empty-string delimiter +-- ============================================================ + +query +SELECT grp, listagg(v, ' -> ') FROM (SELECT * FROM la_src WHERE grp = 'g1' ORDER BY v) GROUP BY grp + +query +SELECT grp, listagg(v, '') FROM (SELECT * FROM la_src WHERE grp = 'g1' ORDER BY v) GROUP BY grp + +-- ============================================================ +-- Empty-string values inside the group +-- ============================================================ + +query +SELECT grp, listagg(v, ',') FROM (SELECT * FROM la_src WHERE grp = 'g4' ORDER BY v) GROUP BY grp + +-- ============================================================ +-- Multibyte UTF-8 values (composed accents, CJK, emoji) +-- ============================================================ + +query +SELECT grp, listagg(v, ',') FROM (SELECT * FROM la_utf8 ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- Multi-byte delimiter +query +SELECT grp, listagg(v, 'Β·') FROM (SELECT * FROM la_utf8 ORDER BY grp, v) GROUP BY grp ORDER BY grp + +-- ============================================================ +-- WITHIN GROUP (ORDER BY ...) is not implemented natively; +-- both ascending and descending must fall back to Spark. +-- ============================================================ + +query expect_fallback(`WITHIN GROUP (ORDER BY ...)` is not supported) +SELECT grp, listagg(v, ',') WITHIN GROUP (ORDER BY v) FROM la_src GROUP BY grp ORDER BY grp + +query expect_fallback(`WITHIN GROUP (ORDER BY ...)` is not supported) +SELECT grp, listagg(v, ',') WITHIN GROUP (ORDER BY v DESC) FROM la_src GROUP BY grp ORDER BY grp + +-- ============================================================ +-- BinaryType inputs are not supported natively; must fall back. +-- ============================================================ + +query expect_fallback(Unsupported child data type: BinaryType) +SELECT grp, listagg(v) FROM la_bin GROUP BY grp ORDER BY grp + +query expect_fallback(Unsupported child data type: BinaryType) +SELECT grp, listagg(v, X'42') FROM la_bin GROUP BY grp ORDER BY grp From 46a3025e9b85c26729228395498b85561949fb6f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jul 2026 09:28:45 -0600 Subject: [PATCH 2/2] refactor: hoist CometListAgg reason strings and use `.foldable` Consolidate the reason strings into `private val`s so `getSupportLevel` and `getUnsupportedReasons` reference the same source of truth, and replace the manual `case _: Literal` delimiter check with the standard `.foldable` gate used elsewhere (e.g. `CometPercentile`). --- .../org/apache/comet/serde/CometListAgg.scala | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala index 49b0f2a2ac..e4ad8a917d 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometListAgg.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ListAgg} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NullType, StringType} @@ -36,37 +36,44 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, hasNonDefaultStringCo */ object CometListAgg extends CometAggregateExpressionSerde[ListAgg] { + private val binaryChildReason = "`BinaryType` inputs are not supported." + private val withinGroupReason = "`WITHIN GROUP (ORDER BY ...)` is not supported." + private val nonFoldableDelimiterReason = "Non-literal delimiters are not supported." + private val collationReason = "Non-default string collations are not supported." + private val distinctReason = + "`DISTINCT` falls back to Spark because Comet rejects multi-column distinct aggregates." + override def getUnsupportedReasons(): Seq[String] = Seq( - "`BinaryType` inputs are not supported.", - "`WITHIN GROUP (ORDER BY ...)` is not supported.", - "Non-literal delimiters are not supported.", - "Non-default string collations are not supported.", - "`DISTINCT` falls back to Spark because Comet rejects multi-column distinct aggregates.") + binaryChildReason, + withinGroupReason, + nonFoldableDelimiterReason, + collationReason, + distinctReason) override def getSupportLevel(expr: ListAgg): SupportLevel = { - // Spark enforces `delimiter.foldable` at analysis time, so a non-literal delimiter would - // fail before reaching us. Match only the two shapes we actually handle. - if (!expr.child.dataType.isInstanceOf[StringType]) { - return Unsupported(Some(s"Unsupported child data type: ${expr.child.dataType}")) - } - if (hasNonDefaultStringCollation(expr.child.dataType)) { - return Unsupported(Some("Non-default string collations are not supported")) - } - expr.delimiter.dataType match { - case _: StringType if hasNonDefaultStringCollation(expr.delimiter.dataType) => - return Unsupported(Some("Non-default string collations on delimiter are not supported")) - case _: StringType | _: NullType => // ok + // Spark's analyzer already enforces `delimiter.foldable`, so this only ever rejects + // non-string / non-null delimiter types. + expr.child.dataType match { + case _: StringType if hasNonDefaultStringCollation(expr.child.dataType) => + Unsupported(Some(collationReason)) + case _: StringType => + expr.delimiter.dataType match { + case _: StringType if hasNonDefaultStringCollation(expr.delimiter.dataType) => + Unsupported(Some(collationReason)) + case _: StringType | _: NullType => + if (!expr.delimiter.foldable) { + Unsupported(Some(nonFoldableDelimiterReason)) + } else if (expr.orderExpressions.nonEmpty) { + Unsupported(Some(withinGroupReason)) + } else { + Compatible() + } + case other => + Unsupported(Some(s"Unsupported delimiter data type: $other")) + } case other => - return Unsupported(Some(s"Unsupported delimiter data type: $other")) - } - expr.delimiter match { - case _: Literal => // ok - case _ => return Unsupported(Some("Non-literal delimiters are not supported")) - } - if (expr.orderExpressions.nonEmpty) { - return Unsupported(Some("`WITHIN GROUP (ORDER BY ...)` is not supported")) + Unsupported(Some(s"Unsupported child data type: $other")) } - Compatible() } override def convert(