Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/contributor-guide/expression-audits/agg_funcs.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@
- Spark 3.5.8 (2026-05-26)
- Spark 4.0.1 (2026-05-26)

## listagg

- Spark 3.4.3 (audited 2026-07-03): does not exist. `ListAgg` was added in Spark 4.0.
- Spark 3.5.8 (audited 2026-07-03): does not exist.
- Spark 4.0.1 (audited 2026-07-03): `ListAgg(child, delimiter, orderExpressions)` in `aggregate/collect.scala`. Accepts `StringType` or `BinaryType` inputs; result type matches child. Skips nulls; empty or all-null groups return `NULL`. A `NULL` delimiter is treated as an empty string. `CometListAgg` maps only the simple form: `StringType` child with a literal `StringType`/`NullType` delimiter and no `WITHIN GROUP`. `BinaryType` inputs, `WITHIN GROUP (ORDER BY ...)`, non-literal delimiters, and non-default collations fall back to Spark. `DISTINCT` falls back because Comet rejects multi-column distinct aggregates (`ListAgg` has two children).
- Spark 4.1.1 (audited 2026-07-03): byte-identical to 4.0.1.
- Native accumulator (`SparkListAgg`) returns `Utf8` but keeps its intermediate state as `Binary`, matching Spark's `TypedImperativeAggregate` buffer schema so the Comet shuffle layer does not insert a `Utf8` → `Binary` cast the merge side cannot read back.

## median

- Spark 3.4.3 (audited 2026-06-24): `Median(child)` is a `RuntimeReplaceableAggregate` with `replacement = Percentile(child, Literal(0.5))`. Catalyst rewrites `median(x)` to `percentile(x, 0.5)` before Comet sees the plan, so it is served by `CometPercentile`.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ The tables below list every Spark built-in expression with its current status.
| `kurtosis` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) |
| `last` | ✅ | |
| `last_value` | ✅ | |
| `listagg` | 🔜 | String aggregation |
| `listagg` | ✅ | Spark 4.0+. `StringType` input with a literal delimiter; `WITHIN GROUP (ORDER BY ...)` and `BinaryType` inputs fall back to Spark. |
| `max` | ✅ | |
| `max_by` | 🔜 | [#3841](https://github.com/apache/datafusion-comet/issues/3841) |
| `mean` | ✅ | |
Expand All @@ -119,7 +119,7 @@ The tables below list every Spark built-in expression with its current status.
| `stddev` | ✅ | |
| `stddev_pop` | ✅ | |
| `stddev_samp` | ✅ | |
| `string_agg` | 🔜 | String aggregation (alias of `listagg`) |
| `string_agg` | ✅ | Alias of `listagg`; same restrictions apply. |
| `sum` | ✅ | |
| `try_avg` | ✅ | Interval types fall back |
| `try_sum` | ✅ | |
Expand Down
9 changes: 8 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ use datafusion::{
use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc,
SparkBloomFilterVersion, SumInteger, ToCsv,
SparkBloomFilterVersion, SparkListAgg, SumInteger, ToCsv,
};
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
use iceberg::expr::Bind;
Expand Down Expand Up @@ -2653,6 +2653,13 @@ impl PhysicalPlanner {
let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
}
AggExprStruct::ListAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let delimiter =
self.create_expr(expr.delimiter.as_ref().unwrap(), Arc::clone(&schema))?;
let func = AggregateUDF::new_from_impl(SparkListAgg::new());
Self::create_aggr_func_expr("listagg", schema, vec![child, delimiter], func)
}
}
}

Expand Down
14 changes: 14 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ message AggExpr {
BloomFilterAgg bloomFilterAgg = 16;
CollectSet collectSet = 17;
Percentile percentile = 18;
ListAgg listAgg = 19;
}

// Optional filter expression for SQL FILTER (WHERE ...) clause.
Expand Down Expand Up @@ -277,6 +278,19 @@ message CollectSet {
DataType datatype = 2;
}

// Spark 4.0+ LISTAGG / STRING_AGG aggregate.
//
// Comet only serializes the simple form: a StringType child with a literal
// (or NULL) delimiter and no WITHIN GROUP ORDER BY. DISTINCT is handled by
// Spark's multi-stage plan rewrite before the aggregate reaches Comet.
//
// NOTE(review): the agg_funcs.md audit entry for listagg states that DISTINCT
// falls back to Spark instead; confirm which description is accurate.
message ListAgg {
// Expression producing the values to concatenate.
Expr child = 1;
// Literal delimiter expression. NULL delimiter is normalized to empty string
// by Spark's semantics.
Expr delimiter = 2;
// NOTE(review): presumably the aggregate's result type. The native planner
// arm for ListAgg reads only `child` and `delimiter` - confirm whether this
// field is consumed anywhere.
DataType datatype = 3;
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
Expand Down
284 changes: 284 additions & 0 deletions native/spark-expr/src/agg_funcs/list_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Spark-compatible `listagg` / `string_agg` aggregate function.
//!
//! Implements the simple form of Spark 4.0's `LISTAGG(expr, delimiter)` (no
//! `WITHIN GROUP (ORDER BY ...)`, no DISTINCT — DISTINCT is rewritten into a
//! multi-stage plan by Spark before it reaches Comet). Differences from
//! DataFusion's `string_agg`:
//!
//! * Returns `Utf8` to match Spark's `StringType` result type; DataFusion's
//! `string_agg` returns `LargeUtf8`.
//! * A `NULL` delimiter is treated as the empty string (Spark treats `NULL` as
//! the default empty delimiter; the JVM serde forwards the literal as-is).
//! * The delimiter is read once from the accumulator args (a literal is
//! enforced by Spark's analyzer).
//!
//! The intermediate state is exposed as `Binary` because Spark's `ListAgg` is
//! a `TypedImperativeAggregate` whose Catalyst buffer schema is `BinaryType`.
//! Emitting `Utf8` here would force a Comet shuffle-side cast (`Utf8` →
//! `Binary`) that the merge side then can no longer read.

use std::hash::Hash;
use std::mem::size_of_val;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::common::cast::{as_binary_array, as_string_array};
use datafusion::common::{internal_datafusion_err, not_impl_err, Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::utils::format_state_name;
use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion::physical_expr::expressions::Literal;

/// Aggregate UDF definition for the Spark-compatible `listagg` function.
///
/// The only state it carries is the argument signature built by
/// `SparkListAgg::new`; per-group state lives in the accumulator it creates.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkListAgg {
// Accepted argument type combinations for (value, delimiter).
signature: Signature,
}

impl Default for SparkListAgg {
fn default() -> Self {
Self::new()
}
}

impl SparkListAgg {
    /// Builds the UDF with its fixed signature.
    ///
    /// Two argument shapes are accepted: `(Utf8, Utf8)` for a string
    /// delimiter and `(Utf8, Null)` for a `NULL` delimiter literal.
    pub fn new() -> Self {
        let accepted_shapes = vec![
            TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
            TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]),
        ];
        let signature = Signature::one_of(accepted_shapes, Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for SparkListAgg {
fn name(&self) -> &str {
"listagg"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
// Spark's ListAgg is a TypedImperativeAggregate — Catalyst declares its
// intermediate buffer as `BinaryType`. Match that so Comet's shuffle
// layer doesn't have to insert a Utf8 -> Binary cast that the merge
// side then can't read back.
Ok(vec![Field::new(
format_state_name(args.name, "listagg"),
DataType::Binary,
true,
)
.into()])
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let Some(lit) = (*acc_args.exprs[1]).downcast_ref::<Literal>() else {
return not_impl_err!(
"listagg delimiter must be a literal; got {:?}",
acc_args.exprs[1]
);
};
let delimiter = if lit.value().is_null() {
String::new()
} else if let Some(s) = lit.value().try_as_str() {
s.unwrap_or("").to_string()
} else {
return not_impl_err!(
"listagg delimiter literal must be Utf8; got {:?}",
lit.value()
);
};
Ok(Box::new(ListAggAccumulator::new(delimiter)))
}
}

/// Per-group state for `listagg`: the values joined so far plus the delimiter
/// placed between them.
#[derive(Debug)]
struct ListAggAccumulator {
// Text inserted between consecutive appended values.
delimiter: String,
// Concatenation of every value appended so far.
accumulated: String,
// True once at least one value has been appended; lets `evaluate` and
// `state` report NULL for "no values" even though `accumulated` would be
// an empty string in that case.
has_value: bool,
}

impl ListAggAccumulator {
    /// Starts an empty accumulator that will join values with `delimiter`.
    fn new(delimiter: String) -> Self {
        let accumulated = String::new();
        Self {
            delimiter,
            accumulated,
            has_value: false,
        }
    }

    /// Appends every non-null entry of `array`, in order, placing the
    /// delimiter before each entry except the very first one seen.
    #[inline]
    fn append_values(&mut self, array: &StringArray) {
        for slot in array.iter() {
            // Null entries are skipped entirely.
            let Some(text) = slot else {
                continue;
            };
            if self.has_value {
                self.accumulated.push_str(&self.delimiter);
            }
            self.accumulated.push_str(text);
            self.has_value = true;
        }
    }
}

impl Accumulator for ListAggAccumulator {
    /// Appends the non-null strings of the first input column.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let Some(input) = values.first() else {
            return Err(internal_datafusion_err!(
                "listagg update_batch expected the values array"
            ));
        };
        let strings = as_string_array(input)?;
        self.append_values(strings);
        Ok(())
    }

    /// Folds partial states into this accumulator.
    ///
    /// Each non-null entry of the first state column is treated as one
    /// already-joined value and appended after a delimiter.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let Some(partials) = states.first() else {
            return Err(internal_datafusion_err!(
                "listagg merge_batch expected the state array"
            ));
        };
        // Partial state arrives as `Binary` (see `state_fields`); every entry
        // holds the UTF-8 bytes written by another accumulator's `state`.
        for entry in as_binary_array(partials)?.iter() {
            let Some(bytes) = entry else {
                continue;
            };
            let text = match std::str::from_utf8(bytes) {
                Ok(text) => text,
                Err(e) => {
                    return Err(internal_datafusion_err!(
                        "listagg merge_batch got non-UTF-8 partial state: {e}"
                    ))
                }
            };
            if self.has_value {
                self.accumulated.push_str(&self.delimiter);
            }
            self.accumulated.push_str(text);
            self.has_value = true;
        }
        Ok(())
    }

    /// Emits the partial state as a single `Binary` scalar (NULL when nothing
    /// was appended) and resets the accumulator.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let bytes = if self.has_value {
            Some(std::mem::take(&mut self.accumulated).into_bytes())
        } else {
            None
        };
        self.has_value = false;
        Ok(vec![ScalarValue::Binary(bytes)])
    }

    /// Returns the joined string, or a NULL `Utf8` when no value was appended.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let joined = self.has_value.then(|| self.accumulated.clone());
        Ok(ScalarValue::Utf8(joined))
    }

    /// Struct size plus the capacity of both owned strings.
    fn size(&self) -> usize {
        let inline_bytes = size_of_val(self);
        let delimiter_bytes = self.delimiter.capacity();
        let accumulated_bytes = self.accumulated.capacity();
        inline_bytes + delimiter_bytes + accumulated_bytes
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{BinaryArray, StringArray};
    use std::sync::Arc;

    /// Builds a `Utf8` column that may contain nulls.
    fn utf8(items: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(items.to_vec()))
    }

    /// Builds a `Utf8` column without nulls.
    fn some(items: &[&str]) -> ArrayRef {
        Arc::new(StringArray::from(items.to_vec()))
    }

    /// Fresh accumulator using `,` as the delimiter.
    fn comma_acc() -> ListAggAccumulator {
        ListAggAccumulator::new(",".to_string())
    }

    /// Evaluates `acc` and unwraps the non-null string result.
    fn joined(acc: &mut ListAggAccumulator) -> Result<String> {
        match acc.evaluate()? {
            ScalarValue::Utf8(Some(text)) => Ok(text),
            _ => panic!("expected Utf8"),
        }
    }

    #[test]
    fn joins_non_null_values_with_delimiter() -> Result<()> {
        let mut acc = comma_acc();
        acc.update_batch(&[some(&["a", "b", "c"])])?;
        assert_eq!(joined(&mut acc)?, "a,b,c");
        Ok(())
    }

    #[test]
    fn empty_delimiter_concatenates() -> Result<()> {
        let mut acc = ListAggAccumulator::new(String::new());
        acc.update_batch(&[some(&["a", "b", "c"])])?;
        assert_eq!(joined(&mut acc)?, "abc");
        Ok(())
    }

    #[test]
    fn skips_null_inputs() -> Result<()> {
        let mut acc = comma_acc();
        acc.update_batch(&[utf8(&[Some("a"), None, Some("b"), None])])?;
        assert_eq!(joined(&mut acc)?, "a,b");
        Ok(())
    }

    #[test]
    fn returns_null_on_all_null_or_empty_input() -> Result<()> {
        // Only nulls were fed in.
        let mut all_null = comma_acc();
        all_null.update_batch(&[utf8(&[None, None])])?;
        assert!(matches!(all_null.evaluate()?, ScalarValue::Utf8(None)));

        // Nothing was fed in at all.
        let mut untouched = comma_acc();
        assert!(matches!(untouched.evaluate()?, ScalarValue::Utf8(None)));
        Ok(())
    }

    #[test]
    fn merge_state_across_partitions() -> Result<()> {
        // Partition A produces a partial state.
        let mut part_a = comma_acc();
        part_a.update_batch(&[some(&["a", "b"])])?;
        let state_bytes = match part_a.state()?.remove(0) {
            ScalarValue::Binary(Some(b)) => b,
            other => panic!("unexpected state {other:?}"),
        };

        // Partition B has its own values and then merges A's state.
        let mut part_b = comma_acc();
        part_b.update_batch(&[some(&["c", "d"])])?;
        let partial_state: ArrayRef =
            Arc::new(BinaryArray::from(vec![Some(state_bytes.as_slice())]));
        part_b.merge_batch(&[partial_state])?;

        // partition A's already-joined "a,b" is appended as one value.
        assert_eq!(joined(&mut part_b)?, "c,d,a,b");
        Ok(())
    }
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/agg_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod avg;
mod avg_decimal;
mod correlation;
mod covariance;
mod list_agg;
mod stddev;
mod sum_decimal;
mod sum_int;
Expand All @@ -29,6 +30,7 @@ pub use avg::Avg;
pub use avg_decimal::AvgDecimal;
pub use correlation::Correlation;
pub use covariance::Covariance;
pub use list_agg::SparkListAgg;
pub use stddev::Stddev;
pub use sum_decimal::SumDecimal;
pub use sum_int::SumInteger;
Expand Down
Loading
Loading