From 04f78db8f19e88cf9a368c81e1a09360b607660d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 17:39:52 -0600 Subject: [PATCH 01/12] feat: add DataSketches HLL wrapper and Spark compatibility spike Wrap the pure-Rust datasketches crate's HLL_8 sketch/union behind SparkHllSketch/SparkHllUnion, hashing inputs via the crate's hash_value wrappers (raw_bytes, sign_extend) so MurmurHash3-x64-128 input matches DataSketches-Java. Verified cross-engine: Comet-produced sketches are byte-identical to Spark hll_sketch_agg output for HLL-array mode and mutually readable for low-cardinality List/Set mode. --- native/Cargo.lock | 7 + native/Cargo.toml | 1 + native/spark-expr/Cargo.toml | 1 + native/spark-expr/src/agg_funcs/hll_sketch.rs | 209 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../testdata/hll_sketch_spark_lgk12.bin | Bin 0 -> 4136 bytes 6 files changed, 220 insertions(+) create mode 100644 native/spark-expr/src/agg_funcs/hll_sketch.rs create mode 100644 native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin diff --git a/native/Cargo.lock b/native/Cargo.lock index adb764fbfb..4fc27f0299 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2141,6 +2141,7 @@ dependencies = [ "datafusion", "datafusion-comet-common", "datafusion-comet-jni-bridge", + "datasketches", "futures", "hex", "jni 0.22.4", @@ -2735,6 +2736,12 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datasketches" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c4cf71a36b46dcfc00e5014c0c20ccad2b1b6a008304d7d57d2749b2d41b3d" + [[package]] name = "debugid" version = "0.8.0" diff --git a/native/Cargo.toml b/native/Cargo.toml index 3e797eb968..0c79dad57d 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -49,6 +49,7 @@ datafusion-comet-proto = { path = "proto" } datafusion-comet-shuffle = { path = "shuffle" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = 
{ version = "0.10" } +datasketches = { version = "0.3.0", features = ["hll"] } futures = "0.3.32" num = "0.4" rand = "0.10" diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 800fe3ecb1..d54094ff40 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -44,6 +44,7 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +datasketches = { workspace = true } [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/agg_funcs/hll_sketch.rs b/native/spark-expr/src/agg_funcs/hll_sketch.rs new file mode 100644 index 0000000000..5e40c54aae --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_sketch.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Thin wrapper over the `datasketches` crate's HLL sketch, isolating all +//! crate-specific API so Comet's aggregate/scalar code depends on a stable +//! surface. Every sketch uses `HllType::Hll8` and DataSketches' +//! `DEFAULT_UPDATE_SEED` (9001), matching Spark's `HllSketchAgg`. +//! +//! Input hashing goes through the crate's `hash_value` wrappers +//! (`raw_bytes` for strings/binary without Rust's length prefix, `sign_extend` +//! 
for narrow integers) so the MurmurHash3-x64-128 input bytes are identical to +//! DataSketches-Java. This makes the sketches mutually readable with Spark. +//! +//! Note: the crate serializes List/Set (low-cardinality) modes in DataSketches +//! *compact* form, whereas Spark emits the *updatable* form. The bytes are +//! therefore not byte-identical to Spark's output for small inputs, but +//! DataSketches `deserialize` reads both forms, so estimates round-trip in both +//! directions. Comet must own both Partial and Final aggregation +//! (`supportsMixedPartialFinal = false`) so this compact intermediate is only +//! ever read back by Comet. + +use datafusion::error::DataFusionError; +use datasketches::hash_value::{raw_bytes, sign_extend}; +use datasketches::hll::{HllSketch, HllType, HllUnion}; + +/// A DataSketches HLL_8 sketch configured to match Spark's `HllSketchAgg`. +#[derive(Debug)] +pub struct SparkHllSketch { + inner: HllSketch, +} + +impl SparkHllSketch { + /// Create an empty HLL_8 sketch with the given `lgConfigK`. + pub fn new(lg_config_k: u8) -> Self { + Self { + inner: HllSketch::new(lg_config_k, HllType::Hll8), + } + } + + /// Update with a 64-bit integer. Spark widens narrower integrals to `long` + /// before hashing; callers should pass the already-widened value here. + /// Rust's `Hash` for `i64` writes 8 little-endian bytes with no prefix, + /// matching DataSketches-Java `update(long)`. + pub fn update_i64(&mut self, v: i64) { + self.inner.update(v); + } + + /// Update with a narrow signed integer, sign-extending to 64 bits exactly as + /// Spark's `toLong` does before hashing. 
+ pub fn update_i32(&mut self, v: i32) { + self.inner.update(sign_extend::from_i32(v)); + } + pub fn update_i16(&mut self, v: i16) { + self.inner.update(sign_extend::from_i16(v)); + } + pub fn update_i8(&mut self, v: i8) { + self.inner.update(sign_extend::from_i8(v)); + } + + /// Update with raw bytes (used for both StringType UTF-8 bytes and + /// BinaryType), hashing without Rust's slice length prefix. Empty inputs are + /// skipped, matching DataSketches (and Spark), which ignore empty values. + pub fn update_bytes(&mut self, v: &[u8]) { + if v.is_empty() { + return; + } + self.inner.update(raw_bytes::from_slice(v)); + } + + /// Serialize to DataSketches bytes (compact for List/Set modes, full for HLL + /// array modes). Readable by Spark's `hll_sketch_estimate` / `hll_union_agg`. + pub fn to_sketch_bytes(&self) -> Vec { + self.inner.serialize() + } + + /// Deserialize a DataSketches sketch (either compact or updatable form). + pub fn from_bytes(bytes: &[u8]) -> Result { + HllSketch::deserialize(bytes) + .map(|inner| Self { inner }) + .map_err(|e| DataFusionError::Internal(format!("invalid HLL sketch bytes: {e}"))) + } + + /// The configured `lgConfigK`. + pub fn lg_config_k(&self) -> u8 { + self.inner.lg_config_k() + } + + /// Raw cardinality estimate (caller rounds to `i64` for Spark). + pub fn estimate(&self) -> f64 { + self.inner.estimate() + } + + /// Merge another sketch into this one via a union, keeping HLL_8 output. + pub fn merge_sketch(&mut self, other: &SparkHllSketch) { + let mut u = HllUnion::new(self.lg_config_k()); + u.update(&self.inner); + u.update(&other.inner); + self.inner = u.to_sketch(HllType::Hll8); + } +} + +/// A DataSketches HLL union configured to match Spark's `HllUnionAgg`. +#[derive(Debug)] +pub struct SparkHllUnion { + inner: HllUnion, +} + +impl SparkHllUnion { + /// Create an empty union with the given `lgMaxK` (Spark fixes this at 12). 
+ pub fn new(lg_max_k: u8) -> Self { + Self { + inner: HllUnion::new(lg_max_k), + } + } + + /// Merge a sketch into the union. + pub fn merge(&mut self, sketch: &SparkHllSketch) { + self.inner.update(&sketch.inner); + } + + /// The union result as an HLL_8 sketch's serialized bytes. + pub fn to_sketch_bytes(&self) -> Vec { + self.inner.to_sketch(HllType::Hll8).serialize() + } +} + +/// Estimate the distinct count from serialized sketch bytes, rounded to the +/// nearest `i64` (Spark's `hll_sketch_estimate` returns a `Long`). +pub fn estimate_from_bytes(bytes: &[u8]) -> Result { + let sketch = SparkHllSketch::from_bytes(bytes)?; + Ok(sketch.estimate().round() as i64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sketch_roundtrips_and_estimates() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_i64(i); + } + let bytes = s.to_sketch_bytes(); + let est = estimate_from_bytes(&bytes).unwrap(); + assert!((est - 1000).abs() <= 30, "estimate {est} not within 3% of 1000"); + } + + #[test] + fn union_merges_two_sketches() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let mut u = SparkHllUnion::new(12); + u.merge(&a); + u.merge(&b); + let est = estimate_from_bytes(&u.to_sketch_bytes()).unwrap(); + assert!((est - 1500).abs() <= 45, "union estimate {est} not within 3% of 1500"); + } + + /// A sketch built from raw bytes (StringType/BinaryType path) round-trips and + /// estimates. Empty inputs are skipped, so they do not affect the estimate. 
+ #[test] + fn byte_input_roundtrips_and_estimates() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_bytes(format!("val-{i}").as_bytes()); + } + s.update_bytes(b""); // skipped, no effect + let est = estimate_from_bytes(&s.to_sketch_bytes()).unwrap(); + assert!((est - 1000).abs() <= 30, "estimate {est} not within 3% of 1000"); + } + + /// Cross-engine regression guard: `testdata/hll_sketch_spark_lgk12.bin` was + /// produced by Spark 3.5's `hll_sketch_agg(id)` over `range(0, 1000)`. Comet + /// must read it and estimate the distinct count, proving the crate's + /// serialization stays DataSketches-Java compatible across crate upgrades. + /// (For this HLL_8 input the Comet-produced bytes are byte-identical to + /// Spark's; low-cardinality List/Set sketches differ in bytes but remain + /// mutually readable.) + #[test] + fn reads_spark_produced_sketch() { + let bytes = include_bytes!("testdata/hll_sketch_spark_lgk12.bin"); + let est = estimate_from_bytes(bytes).unwrap(); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} of Spark-produced sketch not within 3% of 1000" + ); + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..57af25b35e 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod hll_sketch; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use hll_sketch::{estimate_from_bytes, SparkHllSketch, SparkHllUnion}; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin b/native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin new file mode 100644 index 
0000000000000000000000000000000000000000..761477d5f69f2b44be248c34f56667948df43310 GIT binary patch literal 4136 zcmZWs+mX~j4D@UFs4KqtClewd1geOE1jv9YK1hNj>6bL}GH~{c9<7!}V|(BCw~yPl zy??(xef{zI*B8ux{`vNoca7iFVjR2gz8>2-mo>M|!R1x5W21E6jwhwEkO2T15)=HF zc)JR@wAZpL%yloL23y___#q}OPK1wz9o~%4-tj7DYn1{??fkb}d4{G2y?0w4}Yy)g%jM&6G^1U(Xa4I1dCO-8g@d1i6F9! z+J+NIBGe5R7N`2NOjWW5j{~DqmgS6IGM6LkdrK0R)W^JkwQaePuM7BK47d^?ua}V2rve0txIpt1Eb8TmBkB$+g@HH`9YES zR!d^4(>emjQe|Ryu|m4wjAAUt*kraTb^{Dns24g^!AqwE=5;41eMn?ffucY)jwU&= z*6|7;DuyrCcOR%ho*ZjLZ^tB>cubS&0r}j{aNrH1xOx7;GF-=a3T1|Rc{u>7QQ(JO zWkA)IFi>S4-KzBG08FpqkL_PoXTN)MA zOcI;SvEC8Nd~}obGLju7 zC&+28I1bgtB8Ttv^93Xj`LWrb>MN4iE*iRQlX$xrrxnLQP zNdvDqGIPG!A(h>@4PRyaQ!Mig$K05i%ca&|Ypgy&5FP<@J}6 z@Is%^qNy=qO4A2_#CO zQt7b7D2@G`GUx8Q^_#egVXx^aiOlu(Ahus4K*6}&NdAsvCWz=?Z2PI8G`@YJnw<&J c+9;Xmtur1CRPLc&ktAExT_@w26MgXa4=a5PVE_OC literal 0 HcmV?d00001 From 501e8ab7a0f3b2cbfc7d0beecc45a0c093adf3b0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 17:45:15 -0600 Subject: [PATCH 02/12] feat: add version-specific aggregate serde registration hook --- .../apache/comet/serde/QueryPlanSerde.scala | 45 ++++++++++--------- .../apache/comet/shims/CometExprShim.scala | 4 +- .../apache/comet/shims/CometExprShim.scala | 4 +- .../comet/shims/Spark4xCometExprShim.scala | 4 +- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..379e695192 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -384,27 +384,30 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { /** * Mapping of Spark aggregate expression class to Comet expression handler. 
*/ - val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( - classOf[Average] -> CometAverage, - classOf[BitAndAgg] -> CometBitAndAgg, - classOf[BitOrAgg] -> CometBitOrAgg, - classOf[BitXorAgg] -> CometBitXOrAgg, - classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, - classOf[CollectSet] -> CometCollectSet, - classOf[Corr] -> CometCorr, - classOf[Count] -> CometCount, - classOf[CovPopulation] -> CometCovPopulation, - classOf[CovSample] -> CometCovSample, - classOf[First] -> CometFirst, - classOf[Last] -> CometLast, - classOf[Max] -> CometMax, - classOf[Min] -> CometMin, - classOf[Percentile] -> CometPercentile, - classOf[StddevPop] -> CometStddevPop, - classOf[StddevSamp] -> CometStddevSamp, - classOf[Sum] -> CometSum, - classOf[VariancePop] -> CometVariancePop, - classOf[VarianceSamp] -> CometVarianceSamp) + val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = { + val base: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( + classOf[Average] -> CometAverage, + classOf[BitAndAgg] -> CometBitAndAgg, + classOf[BitOrAgg] -> CometBitOrAgg, + classOf[BitXorAgg] -> CometBitXOrAgg, + classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, + classOf[CollectSet] -> CometCollectSet, + classOf[Corr] -> CometCorr, + classOf[Count] -> CometCount, + classOf[CovPopulation] -> CometCovPopulation, + classOf[CovSample] -> CometCovSample, + classOf[First] -> CometFirst, + classOf[Last] -> CometLast, + classOf[Max] -> CometMax, + classOf[Min] -> CometMin, + classOf[Percentile] -> CometPercentile, + classOf[StddevPop] -> CometStddevPop, + classOf[StddevSamp] -> CometStddevSamp, + classOf[Sum] -> CometSum, + classOf[VariancePop] -> CometVariancePop, + classOf[VarianceSamp] -> CometVarianceSamp) + base ++ sparkVersionSpecificAggregates + } /** * Returns true if all aggregate expressions in the list have intermediate buffer formats that diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 1ad9ec75bf..43837f266f 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map.empty def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index 0be1185f59..42290ab565 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map(classOf[ToPrettyString] -> CometToPrettyString) def sparkVersionSpecificMapExpressions: Map[Class[_ <: 
Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 7efd17f68a..9f9290920b 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -49,6 +49,8 @@ trait Spark4xCometExprShim extends CometExprShim4x { Map(classOf[ToPrettyString] -> CometToPrettyString) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, From 12ccf610e65cf2c49326faad233d4a3f885b62b2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 17:49:23 -0600 Subject: [PATCH 03/12] feat: add HllSketchAgg and HllUnionAgg proto messages Adds the two protobuf messages needed to serialize Spark's hll_sketch_agg and hll_union_agg aggregate functions, wired into the AggExpr oneof as field numbers 19 and 20. 
--- native/proto/src/proto/expr.proto | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..978c0e0186 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,8 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + HllSketchAgg hllSketchAgg = 19; + HllUnionAgg hllUnionAgg = 20; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -271,6 +273,20 @@ enum BloomFilterVersion { BLOOM_FILTER_VERSION_V2 = 2; } +message HllSketchAgg { + // Child value expression (integral, string, or binary). + Expr child = 1; + // DataSketches lgConfigK (log2 of the number of buckets), Spark default 12. + int32 lg_config_k = 2; +} + +message HllUnionAgg { + // Child sketch expression (Binary column of serialized HLL sketches). + Expr child = 1; + // When false, Spark errors if input sketches have differing lgConfigK. + bool allow_different_lg_config_k = 2; +} + message CollectSet { Expr child = 1; DataType datatype = 2; From 5dc29b0045f2912405183f59f9895b1f8ca03b83 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 18:04:45 -0600 Subject: [PATCH 04/12] feat: add native hll_sketch_agg accumulator and planner arm Wire up HllSketchAgg as an AggregateUDFImpl backed by SparkHllSketch, accepting Int8/16/32/64, Utf8, and Binary inputs and returning a serialized HLL sketch as Binary. Null groups evaluate to NULL, matching Spark's HllSketchAgg. Wires the new AggExprStruct::HllSketchAgg arm into the native planner's create_agg_expr. Also add a placeholder HllUnionAgg planner arm returning a clear "not yet supported" error, since that oneof variant already exists in the proto but its native accumulator lands in a follow-on task. 
--- native/core/src/execution/planner.rs | 13 +- .../src/agg_funcs/hll_sketch_agg.rs | 196 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + 3 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/hll_sketch_agg.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..c4bbdca223 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, HllSketchAgg, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, + Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2648,11 +2648,20 @@ impl PhysicalPlanner { )); Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func) } + AggExprStruct::HllSketchAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(HllSketchAgg::new(expr.lg_config_k)); + Self::create_aggr_func_expr("hll_sketch_agg", schema, vec![child], func) + } AggExprStruct::CollectSet(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + // hll_union_agg's native accumulator + planner arm is wired up in a follow-on task. 
+ AggExprStruct::HllUnionAgg(_) => Err(ExecutionError::GeneralError( + "hll_union_agg is not yet supported".to_string(), + )), } } diff --git a/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs new file mode 100644 index 0000000000..e962f984f1 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::agg_funcs::hll_sketch::SparkHllSketch; +use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::array::BinaryArray; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{downcast_value, ScalarValue}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion::physical_plan::Accumulator; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HllSketchAgg { + signature: Signature, + lg_config_k: i32, +} + +impl HllSketchAgg { + pub fn new(lg_config_k: i32) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Utf8, + DataType::Binary, + ], + Volatility::Immutable, + ), + lg_config_k, + } + } +} + +impl AggregateUDFImpl for HllSketchAgg { + fn name(&self) -> &str { + "hll_sketch_agg" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Binary) + } + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::new(HllSketchAccumulator::new(self.lg_config_k as u8))) + } + fn state_fields(&self, _: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new("sketch", DataType::Binary, true))]) + } + fn groups_accumulator_supported(&self, _: AccumulatorArgs) -> bool { + false + } +} + +#[derive(Debug)] +pub struct HllSketchAccumulator { + sketch: SparkHllSketch, + saw_input: bool, +} + +impl HllSketchAccumulator { + pub fn new(lg_config_k: u8) -> Self { + Self { + sketch: SparkHllSketch::new(lg_config_k), + saw_input: false, + } + } +} + +impl Accumulator for HllSketchAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|i| { + match 
ScalarValue::try_from_array(arr, i)? { + ScalarValue::Int8(Some(v)) => { + self.sketch.update_i64(v as i64); + self.saw_input = true; + } + ScalarValue::Int16(Some(v)) => { + self.sketch.update_i64(v as i64); + self.saw_input = true; + } + ScalarValue::Int32(Some(v)) => { + self.sketch.update_i64(v as i64); + self.saw_input = true; + } + ScalarValue::Int64(Some(v)) => { + self.sketch.update_i64(v); + self.saw_input = true; + } + ScalarValue::Utf8(Some(v)) => { + self.sketch.update_bytes(v.as_bytes()); + self.saw_input = true; + } + ScalarValue::Binary(Some(v)) => { + self.sketch.update_bytes(&v); + self.saw_input = true; + } + // Spark's HllSketchAgg ignores null inputs. + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Utf8(None) + | ScalarValue::Binary(None) => {} + other => { + return Err(DataFusionError::Internal(format!( + "hll_sketch_agg received an unsupported input type: {other:?}" + ))) + } + } + Ok(()) + }) + } + + fn evaluate(&mut self) -> Result { + // Spark returns a non-null sketch even for empty groups only when it saw input; + // an empty group yields NULL. + if !self.saw_input { + return Ok(ScalarValue::Binary(None)); + } + Ok(ScalarValue::Binary(Some(self.sketch.to_sketch_bytes()))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + if !self.saw_input { + return Ok(vec![ScalarValue::Binary(None)]); + } + Ok(vec![ScalarValue::Binary(Some( + self.sketch.to_sketch_bytes(), + ))]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(states[0], BinaryArray); + for i in 0..arr.len() { + if arr.is_null(i) { + continue; + } + let peer = SparkHllSketch::from_bytes(arr.value(i))?; + // Merge peer into self by unioning; reuse SparkHllUnion via sketch merge. 
+ self.sketch.merge_sketch(&peer); + self.saw_input = true; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use datafusion::physical_plan::Accumulator; + use std::sync::Arc; + + #[test] + fn accumulates_and_estimates() { + let mut acc = HllSketchAccumulator::new(12); + let arr = Arc::new(Int64Array::from((0..1000i64).collect::>())); + acc.update_batch(&[arr]).unwrap(); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected binary") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert!((est - 1000).abs() <= 30, "estimate {est}"); + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 57af25b35e..3211de7db2 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -20,6 +20,7 @@ mod avg_decimal; mod correlation; mod covariance; mod hll_sketch; +mod hll_sketch_agg; mod stddev; mod sum_decimal; mod sum_int; @@ -31,6 +32,7 @@ pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; pub use hll_sketch::{estimate_from_bytes, SparkHllSketch, SparkHllUnion}; +pub use hll_sketch_agg::HllSketchAgg; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; From a52944d18e23325e742e59af6f5bf644badd2207 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 18:17:12 -0600 Subject: [PATCH 05/12] feat: add hll_sketch_agg Scala serde for Spark 4.x --- .../comet/serde/CometHllSketchAgg.scala | 91 +++++++++++++++++++ .../comet/shims/Spark4xCometExprShim.scala | 5 +- 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala new file mode 100644 index 
0000000000..7d5c907df6 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HllSketchAgg} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BinaryType, IntegerType, LongType, StringType} + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +// In Spark 4.0, HllSketchAgg's fields are `left` (the child value expression) and `right` (the +// lgConfigK expression), not `child`/`lgConfigKExpression`. Accepted input types are Integer, +// Long, String, and Binary (Byte/Short are not accepted by Spark's HllSketchAgg). +object CometHllSketchAgg extends CometAggregateExpressionSerde[HllSketchAgg] { + + // DataSketches valid lgConfigK range; outside this Spark itself throws, so we + // fall back rather than forward an out-of-range value to native. 
+ private val MinLgConfigK = 4 + private val MaxLgConfigK = 21 + + private val nonLiteralLgConfigKReason = + "The lgConfigK argument must be a foldable literal." + private val inputTypeReason = + "Only int, long, string, and binary input types are supported." + + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralLgConfigKReason, inputTypeReason) + + override def getSupportLevel(expr: HllSketchAgg): SupportLevel = { + if (!expr.right.foldable) { + return Unsupported(Some(nonLiteralLgConfigKReason)) + } + val lgConfigK = expr.right.eval() match { + case i: Int => i + case l: Long => l.toInt + case _ => return Unsupported(Some(nonLiteralLgConfigKReason)) + } + if (lgConfigK < MinLgConfigK || lgConfigK > MaxLgConfigK) { + return Unsupported(Some(s"lgConfigK must be in [$MinLgConfigK, $MaxLgConfigK]")) + } + expr.left.dataType match { + case IntegerType | LongType | StringType | BinaryType => + Compatible(None) + case _ => Unsupported(Some(inputTypeReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: HllSketchAgg, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val childExpr = exprToProto(expr.left, inputs, binding) + val lgConfigK = expr.right.eval() match { + case i: Int => i + case l: Long => l.toInt + case other => + withFallbackReason(aggExpr, s"Unsupported lgConfigK literal: $other", expr.left) + return None + } + if (childExpr.isDefined) { + val builder = ExprOuterClass.HllSketchAgg.newBuilder() + builder.setChild(childExpr.get) + builder.setLgConfigK(lgConfigK) + Some(ExprOuterClass.AggExpr.newBuilder().setHllSketchAgg(builder).build()) + } else { + withFallbackReason(aggExpr, expr.left) + None + } + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 9f9290920b..1f11b04c3a 100644 --- 
a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.HllSketchAgg import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -50,7 +51,7 @@ trait Spark4xCometExprShim extends CometExprShim4x { def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = - Map.empty + Map(classOf[HllSketchAgg] -> CometHllSketchAgg) def sparkVersionSpecificExprToProtoInternal( expr: Expression, From f31cbf82c594ffa128548b20c3037f19c8bc8ba3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 18:25:19 -0600 Subject: [PATCH 06/12] feat: add hll_sketch_estimate scalar function --- native/spark-expr/src/comet_scalar_funcs.rs | 5 ++ native/spark-expr/src/hll_scalar.rs | 60 +++++++++++++++++++ native/spark-expr/src/lib.rs | 2 + 
.../comet/serde/CometHllSketchEstimate.scala | 40 +++++++++++++ .../comet/shims/Spark4xCometExprShim.scala | 6 +- 5 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 native/spark-expr/src/hll_scalar.rs create mode 100644 spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 42ee72c82a..d52212873d 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. use crate::hash_funcs::*; +use crate::hll_scalar::spark_hll_sketch_estimate; use crate::json_funcs::JsonArrayLength; use crate::map_funcs::spark_map_sort; use crate::math_funcs::abs::abs; @@ -221,6 +222,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_map_sort); make_comet_scalar_udf!("spark_map_sort", func, without data_type) } + "hll_sketch_estimate" => { + let func = Arc::new(|args: &[ColumnarValue]| spark_hll_sketch_estimate(args)); + make_comet_scalar_udf!("hll_sketch_estimate", func, without data_type) + } "to_time" => { make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error) } diff --git a/native/spark-expr/src/hll_scalar.rs b/native/spark-expr/src/hll_scalar.rs new file mode 100644 index 0000000000..755b5df5fa --- /dev/null +++ b/native/spark-expr/src/hll_scalar.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::agg_funcs::estimate_from_bytes; +use arrow::array::{Array, BinaryArray, Int64Array}; +use datafusion::common::Result; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark hll_sketch_estimate: Binary sketch -> Long distinct-count estimate. +pub fn spark_hll_sketch_estimate(args: &[ColumnarValue]) -> Result<ColumnarValue> { + let arrays = ColumnarValue::values_to_arrays(args)?; + let input = arrays[0].as_any().downcast_ref::<BinaryArray>().unwrap(); + let mut out = Int64Array::builder(input.len()); + for i in 0..input.len() { + if input.is_null(i) { + out.append_null(); + } else { + out.append_value(estimate_from_bytes(input.value(i))?); + } + } + Ok(ColumnarValue::Array(Arc::new(out.finish()))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agg_funcs::SparkHllSketch; + + #[test] + fn estimates_from_sketch_column() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![Some( + s.to_sketch_bytes().as_slice(), + )])); + let out = spark_hll_sketch_estimate(&[ColumnarValue::Array(arr)]).unwrap(); + let ColumnarValue::Array(a) = out else { + panic!() + }; + let est = a.as_any().downcast_ref::<Int64Array>().unwrap().value(0); + assert!((est - 1000).abs() <= 30, "estimate {est}"); + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 174a4ada9a..baca331dc9 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -59,6 +59,8 @@ pub mod jvm_udf; mod conditional_funcs; mod conversion_funcs; +mod hll_scalar; +pub
use hll_scalar::spark_hll_sketch_estimate; mod map_funcs; pub use map_funcs::spark_map_sort; mod math_funcs; diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala new file mode 100644 index 0000000000..1f032abdc3 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, HllSketchEstimate} +import org.apache.spark.sql.types.LongType + +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} + +object CometHllSketchEstimate extends CometExpressionSerde[HllSketchEstimate] { + override def convert( + expr: HllSketchEstimate, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val estimateExpr = scalarFunctionExprToProtoWithReturnType( + "hll_sketch_estimate", + LongType, + failOnError = false, + childExpr) + optExprWithFallbackReason(estimateExpr, expr, expr.child) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 1f11b04c3a..bf8c86b193 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -47,7 +47,9 @@ trait Spark4xCometExprShim extends CometExprShim4x { def sparkVersionSpecificMathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = 
Map(classOf[WidthBucket] -> CometWidthBucket) def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = - Map(classOf[ToPrettyString] -> CometToPrettyString) + Map( + classOf[ToPrettyString] -> CometToPrettyString, + classOf[HllSketchEstimate] -> CometHllSketchEstimate) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = From 4f0f289939f49c1cecbda25de8b89807f595d81b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 18:56:03 -0600 Subject: [PATCH 07/12] feat: mark HLL expressions incompatible and add opt-in end-to-end test [skip ci] --- .../comet/serde/CometHllSketchAgg.scala | 7 +++++- .../comet/serde/CometHllSketchEstimate.scala | 8 ++++++ .../apache/comet/CometExpressionSuite.scala | 25 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala index 7d5c907df6..5141e0c952 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala @@ -41,10 +41,15 @@ object CometHllSketchAgg extends CometAggregateExpressionSerde[HllSketchAgg] { "The lgConfigK argument must be a foldable literal." private val inputTypeReason = "Only int, long, string, and binary input types are supported." + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ " + + "slightly from Spark." 
override def getUnsupportedReasons(): Seq[String] = Seq(nonLiteralLgConfigKReason, inputTypeReason) + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + override def getSupportLevel(expr: HllSketchAgg): SupportLevel = { if (!expr.right.foldable) { return Unsupported(Some(nonLiteralLgConfigKReason)) @@ -59,7 +64,7 @@ object CometHllSketchAgg extends CometAggregateExpressionSerde[HllSketchAgg] { } expr.left.dataType match { case IntegerType | LongType | StringType | BinaryType => - Compatible(None) + Incompatible(Some(incompatReason)) case _ => Unsupported(Some(inputTypeReason)) } } diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala index 1f032abdc3..454b53eed8 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala @@ -25,6 +25,14 @@ import org.apache.spark.sql.types.LongType import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} object CometHllSketchEstimate extends CometExpressionSerde[HllSketchEstimate] { + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL estimates may differ slightly from Spark." 
+ + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllSketchEstimate): SupportLevel = + Incompatible(Some(incompatReason)) + override def convert( expr: HllSketchEstimate, inputs: Seq[Attribute], diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 390a7c4908..8faecb24a5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3380,4 +3380,29 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("hll_sketch_agg and hll_sketch_estimate (incompatible, opt-in)") { + assume(isSpark40Plus) + // HLL is approximate: Comet's Rust DataSketches estimator differs slightly from + // Spark's after a merge, so these functions are Incompatible. Opt in, assert the + // query runs natively (no fallback), and that the estimate is within HLL error of + // the TRUE distinct count (700). Do NOT compare bit-exactly to Spark. 
+ withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true") { + withParquetTable((0 until 1000).map(i => (i % 700, i)), "tbl") { + def checkEstimate(query: String): Unit = { + val df = sql(query) + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) + val est = df.collect().head.getLong(0) + assert( + math.abs(est - 700).toDouble / 700 <= 0.05, + s"estimate $est not within 5% of the true distinct count 700 for: $query") + } + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(_1)) FROM tbl") + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(_1, 14)) FROM tbl") + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(cast(_1 as string))) FROM tbl") + } + } + } + } From 342bd6e18def1c80c40bef5f84ab2528fa560549 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 19:05:00 -0600 Subject: [PATCH 08/12] feat: add native and serde for hll_union_agg [skip ci] --- native/core/src/execution/planner.rs | 15 +- .../spark-expr/src/agg_funcs/hll_union_agg.rs | 177 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../apache/comet/serde/CometHllUnionAgg.scala | 74 ++++++++ .../comet/shims/Spark4xCometExprShim.scala | 6 +- 5 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/hll_union_agg.rs create mode 100644 spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c4bbdca223..e63e09c8ee 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,9 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, HllSketchAgg, IfExpr, 
ListExtract, NormalizeNaNAndZero, SparkCastOptions, - Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, HllSketchAgg, HllUnionAgg, IfExpr, ListExtract, NormalizeNaNAndZero, + SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, + WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2658,10 +2659,12 @@ impl PhysicalPlanner { let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } - // hll_union_agg's native accumulator + planner arm is wired up in a follow-on task. - AggExprStruct::HllUnionAgg(_) => Err(ExecutionError::GeneralError( - "hll_union_agg is not yet supported".to_string(), - )), + AggExprStruct::HllUnionAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let func = + AggregateUDF::new_from_impl(HllUnionAgg::new(expr.allow_different_lg_config_k)); + Self::create_aggr_func_expr("hll_union_agg", schema, vec![child], func) + } } } diff --git a/native/spark-expr/src/agg_funcs/hll_union_agg.rs b/native/spark-expr/src/agg_funcs/hll_union_agg.rs new file mode 100644 index 0000000000..9d3658cef7 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_union_agg.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::agg_funcs::hll_sketch::{SparkHllSketch, SparkHllUnion}; +use arrow::array::{Array, ArrayRef, BinaryArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{downcast_value, ScalarValue}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion::physical_plan::Accumulator; +use std::sync::Arc; + +// NOTE: matches bloom_filter_agg.rs for DataFusion 54.0.0 - no `as_any` method on +// AggregateUDFImpl, and PartialEq/Eq/Hash are required (DynEq/DynHash). 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HllUnionAgg { + signature: Signature, + allow_different_lg_config_k: bool, +} + +impl HllUnionAgg { + pub fn new(allow_different_lg_config_k: bool) -> Self { + Self { + signature: Signature::uniform(1, vec![DataType::Binary], Volatility::Immutable), + allow_different_lg_config_k, + } + } +} + +impl AggregateUDFImpl for HllUnionAgg { + fn name(&self) -> &str { + "hll_union_agg" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Binary) + } + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::new(HllUnionAccumulator::new( + self.allow_different_lg_config_k, + ))) + } + fn state_fields(&self, _: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new("sketch", DataType::Binary, true))]) + } + fn groups_accumulator_supported(&self, _: AccumulatorArgs) -> bool { + false + } +} + +#[derive(Debug)] +pub struct HllUnionAccumulator { + // Spark's HllUnionAgg defers creating the Union until the first sketch is seen, + // then builds `new Union(sketch.getLgConfigK)` - so lgMaxK is NOT a fixed 12. + union: Option, + allow_different_lg_config_k: bool, + seen_lg_config_k: Option, +} + +impl HllUnionAccumulator { + pub fn new(allow_different_lg_config_k: bool) -> Self { + Self { + union: None, + allow_different_lg_config_k, + seen_lg_config_k: None, + } + } + + fn absorb(&mut self, bytes: &[u8]) -> Result<()> { + let sketch = SparkHllSketch::from_bytes(bytes)?; + let k = sketch.lg_config_k(); + match self.seen_lg_config_k { + None => { + // Lazily instantiate the union from the first sketch's lgConfigK. + self.seen_lg_config_k = Some(k); + self.union = Some(SparkHllUnion::new(k)); + } + Some(prev) if prev != k && !self.allow_different_lg_config_k => { + return Err(DataFusionError::Execution(format!( + "Sketches have different lgConfigK values: {prev} and {k}. 
\ + Set allowDifferentLgConfigK to true to enable unions of different lgConfigK." + ))); + } + _ => {} + } + self.union.as_mut().unwrap().merge(&sketch); + Ok(()) + } +} + +impl Accumulator for HllUnionAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = downcast_value!(values[0], BinaryArray); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.absorb(arr.value(i))?; + } + } + Ok(()) + } + fn evaluate(&mut self) -> Result<ScalarValue> { + match &self.union { + Some(u) => Ok(ScalarValue::Binary(Some(u.to_sketch_bytes()))), + None => Ok(ScalarValue::Binary(None)), + } + } + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + fn state(&mut self) -> Result<Vec<ScalarValue>> { + match &self.union { + Some(u) => Ok(vec![ScalarValue::Binary(Some(u.to_sketch_bytes()))]), + None => Ok(vec![ScalarValue::Binary(None)]), + } + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(states[0], BinaryArray); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.absorb(arr.value(i))?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agg_funcs::hll_sketch::SparkHllSketch; + use arrow::array::BinaryArray; + use datafusion::physical_plan::Accumulator; + use std::sync::Arc; + + #[test] + fn unions_sketch_column() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![ + Some(a.to_sketch_bytes().as_slice()), + Some(b.to_sketch_bytes().as_slice()), + ])); + let mut acc = HllUnionAccumulator::new(false); + acc.update_batch(&[arr]).unwrap(); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!() + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert!((est - 1500).abs() <= 45, "union estimate {est}"); + } +} diff --git
a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 3211de7db2..f74e88b848 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -21,6 +21,7 @@ mod correlation; mod covariance; mod hll_sketch; mod hll_sketch_agg; +mod hll_union_agg; mod stddev; mod sum_decimal; mod sum_int; @@ -33,6 +34,7 @@ pub use correlation::Correlation; pub use covariance::Covariance; pub use hll_sketch::{estimate_from_bytes, SparkHllSketch, SparkHllUnion}; pub use hll_sketch_agg::HllSketchAgg; +pub use hll_union_agg::HllUnionAgg; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala new file mode 100644 index 0000000000..156f26e675 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HllUnionAgg} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +// IMPORTANT: In Spark 4.0, HllUnionAgg's fields are `left` (the child binary sketch +// expression) and `right` (the allowDifferentLgConfigK boolean expression) - NOT +// `child`/`allowDifferentLgConfigKExpression`. (Verified against Spark 4.0 source.) +object CometHllUnionAgg extends CometAggregateExpressionSerde[HllUnionAgg] { + + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ slightly " + + "from Spark." + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllUnionAgg): SupportLevel = { + if (!expr.right.foldable) { + Unsupported(Some("The allowDifferentLgConfigK argument must be a foldable literal.")) + } else { + Incompatible(Some(incompatReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: HllUnionAgg, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val childExpr = exprToProto(expr.left, inputs, binding) + val allow = expr.right.eval() match { + case b: Boolean => b + case other => + withFallbackReason( + aggExpr, + s"Unsupported allowDifferentLgConfigK literal: $other", + expr.left) + return None + } + if (childExpr.isDefined) { + val builder = ExprOuterClass.HllUnionAgg.newBuilder() + builder.setChild(childExpr.get) + builder.setAllowDifferentLgConfigK(allow) + Some(ExprOuterClass.AggExpr.newBuilder().setHllUnionAgg(builder).build()) + } else { + withFallbackReason(aggExpr, expr.left) + None + } + } +} diff --git 
a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index bf8c86b193..cfab80033a 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -20,7 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.HllSketchAgg +import org.apache.spark.sql.catalyst.expressions.aggregate.{HllSketchAgg, HllUnionAgg} import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometHllUnionAgg, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -53,7 +53,7 @@ trait Spark4xCometExprShim extends CometExprShim4x { def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = - Map(classOf[HllSketchAgg] -> CometHllSketchAgg) + Map(classOf[HllSketchAgg] -> CometHllSketchAgg, classOf[HllUnionAgg] -> 
CometHllUnionAgg) def sparkVersionSpecificExprToProtoInternal( expr: Expression, From a9f88933175ef18f3ec86a495f724fb595b15d83 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 19:12:16 -0600 Subject: [PATCH 09/12] feat: add hll_union scalar function [skip ci] --- native/spark-expr/src/comet_scalar_funcs.rs | 6 +- native/spark-expr/src/hll_scalar.rs | 77 ++++++++++++++++++- native/spark-expr/src/lib.rs | 1 + .../apache/comet/serde/CometHllUnion.scala | 59 ++++++++++++++ .../comet/shims/Spark4xCometExprShim.scala | 5 +- 5 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index d52212873d..15de6bac49 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,7 +16,7 @@ // under the License. use crate::hash_funcs::*; -use crate::hll_scalar::spark_hll_sketch_estimate; +use crate::hll_scalar::{spark_hll_sketch_estimate, spark_hll_union}; use crate::json_funcs::JsonArrayLength; use crate::map_funcs::spark_map_sort; use crate::math_funcs::abs::abs; @@ -226,6 +226,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(|args: &[ColumnarValue]| spark_hll_sketch_estimate(args)); make_comet_scalar_udf!("hll_sketch_estimate", func, without data_type) } + "hll_union" => { + let func = Arc::new(|args: &[ColumnarValue]| spark_hll_union(args)); + make_comet_scalar_udf!("hll_union", func, without data_type) + } "to_time" => { make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error) } diff --git a/native/spark-expr/src/hll_scalar.rs b/native/spark-expr/src/hll_scalar.rs index 755b5df5fa..f9a9648b20 100644 --- a/native/spark-expr/src/hll_scalar.rs +++ b/native/spark-expr/src/hll_scalar.rs @@ -17,7 +17,7 @@ use crate::agg_funcs::estimate_from_bytes; use arrow::array::{Array, BinaryArray, 
Int64Array}; -use datafusion::common::Result; +use datafusion::common::{DataFusionError, Result}; use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; @@ -36,6 +36,43 @@ pub fn spark_hll_sketch_estimate(args: &[ColumnarValue]) -> Result Result { + use crate::agg_funcs::{SparkHllSketch, SparkHllUnion}; + use arrow::array::BooleanArray; + let arrays = ColumnarValue::values_to_arrays(args)?; + let a = arrays[0].as_any().downcast_ref::().unwrap(); + let b = arrays[1].as_any().downcast_ref::().unwrap(); + let allow = arrays[2].as_any().downcast_ref::().unwrap(); + let mut out = arrow::array::BinaryBuilder::new(); + for i in 0..a.len() { + if a.is_null(i) || b.is_null(i) { + out.append_null(); + continue; + } + let sa = SparkHllSketch::from_bytes(a.value(i))?; + let sb = SparkHllSketch::from_bytes(b.value(i))?; + let allow_i = !allow.is_null(i) && allow.value(i); + if !allow_i && sa.lg_config_k() != sb.lg_config_k() { + return Err(DataFusionError::Execution(format!( + "Sketches have different lgConfigK values: {} and {}. \ + Set allowDifferentLgConfigK to true to enable unions of different lgConfigK.", + sa.lg_config_k(), + sb.lg_config_k() + ))); + } + // Spark builds `new Union(min(k1, k2))`. 
+ let mut u = SparkHllUnion::new(sa.lg_config_k().min(sb.lg_config_k())); + u.merge(&sa); + u.merge(&sb); + out.append_value(u.to_sketch_bytes()); + } + Ok(ColumnarValue::Array(Arc::new(out.finish()))) +} + #[cfg(test)] mod tests { use super::*; @@ -58,3 +95,41 @@ mod tests { assert!((est - 1000).abs() <= 30, "estimate {est}"); } } + +#[cfg(test)] +mod union_tests { + use super::*; + use crate::agg_funcs::SparkHllSketch; + use arrow::array::BooleanArray; + + #[test] + fn unions_two_sketch_columns() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let aa = Arc::new(BinaryArray::from(vec![Some( + a.to_sketch_bytes().as_slice(), + )])); + let bb = Arc::new(BinaryArray::from(vec![Some( + b.to_sketch_bytes().as_slice(), + )])); + let allow = Arc::new(BooleanArray::from(vec![false])); + let out = spark_hll_union(&[ + ColumnarValue::Array(aa), + ColumnarValue::Array(bb), + ColumnarValue::Array(allow), + ]) + .unwrap(); + let ColumnarValue::Array(arr) = out else { + panic!() + }; + let est = estimate_from_bytes(arr.as_any().downcast_ref::<BinaryArray>().unwrap().value(0)) + .unwrap(); + assert!((est - 1500).abs() <= 45, "estimate {est}"); + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index baca331dc9..2260953155 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -61,6 +61,7 @@ mod conditional_funcs; mod conversion_funcs; mod hll_scalar; pub use hll_scalar::spark_hll_sketch_estimate; +pub use hll_scalar::spark_hll_union; mod map_funcs; pub use map_funcs::spark_map_sort; mod math_funcs; diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala new file mode 100644 index 0000000000..af7b330966 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala @@ -0,0 +1,59 @@ +/* + *
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, HllUnion} +import org.apache.spark.sql.types.BinaryType + +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} + +// Spark 4.0 HllUnion is a TernaryExpression: `first`, `second` (binary sketches), +// `third` (allowDifferentLgConfigK boolean, default Literal(false)). All three are +// passed to the native `hll_union`, which enforces the flag and unions with min lgConfigK. +object CometHllUnion extends CometExpressionSerde[HllUnion] { + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ slightly from Spark." 
+ + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllUnion): SupportLevel = { + if (!expr.third.foldable) { + Unsupported(Some("The allowDifferentLgConfigK argument must be a foldable literal.")) + } else { + Incompatible(Some(incompatReason)) + } + } + override def convert( + expr: HllUnion, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val first = exprToProtoInternal(expr.first, inputs, binding) + val second = exprToProtoInternal(expr.second, inputs, binding) + val third = exprToProtoInternal(expr.third, inputs, binding) + val unionExpr = scalarFunctionExprToProtoWithReturnType( + "hll_union", + BinaryType, + failOnError = false, + first, + second, + third) + optExprWithFallbackReason(unionExpr, expr, expr.first, expr.second, expr.third) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index cfab80033a..93cc2ef184 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometHllUnionAgg, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometHllUnion, CometHllUnionAgg, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -49,7 +49,8 @@ trait 
Spark4xCometExprShim extends CometExprShim4x { def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[ToPrettyString] -> CometToPrettyString, - classOf[HllSketchEstimate] -> CometHllSketchEstimate) + classOf[HllSketchEstimate] -> CometHllSketchEstimate, + classOf[HllUnion] -> CometHllUnion) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = From 93dedb3fdf3d06c709f90d8167b99f97a659b41e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 19:19:42 -0600 Subject: [PATCH 10/12] test: add hll_union end-to-end and error tests, apply formatting [skip ci] --- native/spark-expr/src/agg_funcs/hll_sketch.rs | 15 ++++-- .../apache/comet/CometExpressionSuite.scala | 47 +++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/hll_sketch.rs b/native/spark-expr/src/agg_funcs/hll_sketch.rs index 5e40c54aae..68c8d03943 100644 --- a/native/spark-expr/src/agg_funcs/hll_sketch.rs +++ b/native/spark-expr/src/agg_funcs/hll_sketch.rs @@ -157,7 +157,10 @@ mod tests { } let bytes = s.to_sketch_bytes(); let est = estimate_from_bytes(&bytes).unwrap(); - assert!((est - 1000).abs() <= 30, "estimate {est} not within 3% of 1000"); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} not within 3% of 1000" + ); } #[test] @@ -174,7 +177,10 @@ mod tests { u.merge(&a); u.merge(&b); let est = estimate_from_bytes(&u.to_sketch_bytes()).unwrap(); - assert!((est - 1500).abs() <= 45, "union estimate {est} not within 3% of 1500"); + assert!( + (est - 1500).abs() <= 45, + "union estimate {est} not within 3% of 1500" + ); } /// A sketch built from raw bytes (StringType/BinaryType path) round-trips and @@ -187,7 +193,10 @@ mod tests { } s.update_bytes(b""); // skipped, no effect let est = 
estimate_from_bytes(&s.to_sketch_bytes()).unwrap(); - assert!((est - 1000).abs() <= 30, "estimate {est} not within 3% of 1000"); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} not within 3% of 1000" + ); } /// Cross-engine regression guard: `testdata/hll_sketch_spark_lgk12.bin` was diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 8faecb24a5..2b58a9e791 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3405,4 +3405,51 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("hll_union_agg and hll_union (incompatible, opt-in)") { + assume(isSpark40Plus) + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true", + "spark.comet.expression.HllUnionAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllUnion.allowIncompatible" -> "true") { + withParquetTable((0 until 1000).map(i => (i % 3, i)), "tbl") { + // hll_union_agg: union the per-group sketches -> ~1000 distinct. + val aggDf = sql( + "SELECT hll_sketch_estimate(hll_union_agg(s)) FROM " + + "(SELECT _1 AS g, hll_sketch_agg(_2) AS s FROM tbl GROUP BY _1)") + checkCometOperators(stripAQEPlan(aggDf.queryExecution.executedPlan)) + val aggEst = aggDf.collect().head.getLong(0) + assert(math.abs(aggEst - 1000).toDouble / 1000 <= 0.05, s"union_agg estimate $aggEst") + + // hll_union: union two disjoint group sketches -> ~667 distinct. 
+ val unionDf = sql( + "SELECT hll_sketch_estimate(hll_union(a.s, b.s)) FROM " + + "(SELECT hll_sketch_agg(_2) AS s FROM tbl WHERE _1 = 0) a, " + + "(SELECT hll_sketch_agg(_2) AS s FROM tbl WHERE _1 = 1) b") + checkCometOperators(stripAQEPlan(unionDf.queryExecution.executedPlan)) + val unionEst = unionDf.collect().head.getLong(0) + assert(math.abs(unionEst - 667).toDouble / 667 <= 0.05, s"union estimate $unionEst") + } + } + } + + test("hll_union_agg rejects different lgConfigK when not allowed") { + assume(isSpark40Plus) + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllUnionAgg.allowIncompatible" -> "true") { + withParquetTable((0 until 100).map(i => Tuple1(i)), "tbl") { + // A lgConfigK=10 sketch unioned with a lgConfigK=12 sketch (allowDifferentLgConfigK + // defaults false) must throw in BOTH Spark and Comet. + val df = sql( + "SELECT hll_union_agg(s) FROM (" + + " SELECT hll_sketch_agg(_1, 10) AS s FROM tbl UNION ALL" + + " SELECT hll_sketch_agg(_1, 12) AS s FROM tbl)") + val (sparkErr, cometErr) = checkSparkAnswerMaybeThrows(df) + assert(sparkErr.isDefined, "expected Spark to throw on different lgConfigK") + assert(cometErr.isDefined, "expected Comet to throw on different lgConfigK") + } + } + } + } From ad39ec54d2f55260e99a010a86abf012b654f4b8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 19:37:58 -0600 Subject: [PATCH 11/12] fix: HLL empty-group returns empty sketch and size() accounts for sketch heap [skip ci] --- .../src/agg_funcs/hll_sketch_agg.rs | 47 ++++++++++------- .../spark-expr/src/agg_funcs/hll_union_agg.rs | 51 ++++++++++++++++++- .../apache/comet/CometExpressionSuite.scala | 20 ++++++++ 3 files changed, 98 insertions(+), 20 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs index e962f984f1..3453b64657 100644 --- a/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs +++ 
b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs @@ -77,14 +77,12 @@ impl AggregateUDFImpl for HllSketchAgg { #[derive(Debug)] pub struct HllSketchAccumulator { sketch: SparkHllSketch, - saw_input: bool, } impl HllSketchAccumulator { pub fn new(lg_config_k: u8) -> Self { Self { sketch: SparkHllSketch::new(lg_config_k), - saw_input: false, } } } @@ -99,27 +97,21 @@ impl Accumulator for HllSketchAccumulator { match ScalarValue::try_from_array(arr, i)? { ScalarValue::Int8(Some(v)) => { self.sketch.update_i64(v as i64); - self.saw_input = true; } ScalarValue::Int16(Some(v)) => { self.sketch.update_i64(v as i64); - self.saw_input = true; } ScalarValue::Int32(Some(v)) => { self.sketch.update_i64(v as i64); - self.saw_input = true; } ScalarValue::Int64(Some(v)) => { self.sketch.update_i64(v); - self.saw_input = true; } ScalarValue::Utf8(Some(v)) => { self.sketch.update_bytes(v.as_bytes()); - self.saw_input = true; } ScalarValue::Binary(Some(v)) => { self.sketch.update_bytes(&v); - self.saw_input = true; } // Spark's HllSketchAgg ignores null inputs. ScalarValue::Int8(None) @@ -139,22 +131,18 @@ impl Accumulator for HllSketchAccumulator { } fn evaluate(&mut self) -> Result { - // Spark returns a non-null sketch even for empty groups only when it saw input; - // an empty group yields NULL. - if !self.saw_input { - return Ok(ScalarValue::Binary(None)); - } + // Spark's HllSketchAgg is declared non-nullable: an empty/all-null group + // still returns a serialized empty sketch (which estimates to 0), never NULL. Ok(ScalarValue::Binary(Some(self.sketch.to_sketch_bytes()))) } fn size(&self) -> usize { - std::mem::size_of_val(self) + // An HLL_8 sketch at lgConfigK=k can heap-allocate up to 1 << k bytes; + // account for that so memory reservation reflects actual usage. 
+ std::mem::size_of_val(self) + (1usize << self.sketch.lg_config_k() as usize) } fn state(&mut self) -> Result> { - if !self.saw_input { - return Ok(vec![ScalarValue::Binary(None)]); - } Ok(vec![ScalarValue::Binary(Some( self.sketch.to_sketch_bytes(), ))]) @@ -169,7 +157,6 @@ impl Accumulator for HllSketchAccumulator { let peer = SparkHllSketch::from_bytes(arr.value(i))?; // Merge peer into self by unioning; reuse SparkHllUnion via sketch merge. self.sketch.merge_sketch(&peer); - self.saw_input = true; } Ok(()) } @@ -193,4 +180,28 @@ mod tests { let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); assert!((est - 1000).abs() <= 30, "estimate {est}"); } + + /// Spark's `HllSketchAgg` is non-nullable: an empty/all-null group still + /// produces a serialized empty sketch (estimate 0), not NULL. + #[test] + fn empty_group_evaluates_to_empty_sketch_not_null() { + let mut acc = HllSketchAccumulator::new(12); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected Binary(Some(_)) for an empty group, got NULL") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert_eq!(est, 0, "empty sketch should estimate to 0, got {est}"); + } + + #[test] + fn size_accounts_for_sketch_heap() { + let mut acc = HllSketchAccumulator::new(12); + let arr = Arc::new(Int64Array::from((0..10000i64).collect::>())); + acc.update_batch(&[arr]).unwrap(); + assert!( + acc.size() > 1000, + "size() should account for the sketch heap allocation, got {}", + acc.size() + ); + } } diff --git a/native/spark-expr/src/agg_funcs/hll_union_agg.rs b/native/spark-expr/src/agg_funcs/hll_union_agg.rs index 9d3658cef7..aee7c044a1 100644 --- a/native/spark-expr/src/agg_funcs/hll_union_agg.rs +++ b/native/spark-expr/src/agg_funcs/hll_union_agg.rs @@ -65,6 +65,10 @@ impl AggregateUDFImpl for HllUnionAgg { } } +/// Default `lgMaxK` used by Spark's `new Union()` when constructing the empty +/// union returned for a group that never absorbed any 
sketch. +const DEFAULT_LG_K: u8 = 12; + #[derive(Debug)] pub struct HllUnionAccumulator { // Spark's HllUnionAgg defers creating the Union until the first sketch is seen, @@ -119,18 +123,31 @@ impl Accumulator for HllUnionAccumulator { Ok(()) } fn evaluate(&mut self) -> Result { + // Spark's HllUnionAgg is declared non-nullable: an empty/all-null group + // still returns the serialized bytes of an empty `new Union()` (default + // lgMaxK), which estimates to 0, never NULL. match &self.union { Some(u) => Ok(ScalarValue::Binary(Some(u.to_sketch_bytes()))), - None => Ok(ScalarValue::Binary(None)), + None => Ok(ScalarValue::Binary(Some( + SparkHllUnion::new(DEFAULT_LG_K).to_sketch_bytes(), + ))), } } fn size(&self) -> usize { + // An HLL_8 sketch at lgConfigK=k can heap-allocate up to 1 << k bytes; + // account for that so memory reservation reflects actual usage. std::mem::size_of_val(self) + + self + .seen_lg_config_k + .map(|k| 1usize << k as usize) + .unwrap_or(0) } fn state(&mut self) -> Result> { match &self.union { Some(u) => Ok(vec![ScalarValue::Binary(Some(u.to_sketch_bytes()))]), - None => Ok(vec![ScalarValue::Binary(None)]), + None => Ok(vec![ScalarValue::Binary(Some( + SparkHllUnion::new(DEFAULT_LG_K).to_sketch_bytes(), + ))]), } } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -174,4 +191,34 @@ mod tests { let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); assert!((est - 1500).abs() <= 45, "union estimate {est}"); } + + /// Spark's `HllUnionAgg` is non-nullable: an empty/all-null group still + /// produces the serialized bytes of an empty union (estimate 0), not NULL. 
+ #[test] + fn empty_group_evaluates_to_empty_sketch_not_null() { + let mut acc = HllUnionAccumulator::new(false); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected Binary(Some(_)) for an empty group, got NULL") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert_eq!(est, 0, "empty union should estimate to 0, got {est}"); + } + + #[test] + fn size_accounts_for_sketch_heap() { + let mut a = SparkHllSketch::new(12); + for i in 0..10000i64 { + a.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![Some( + a.to_sketch_bytes().as_slice(), + )])); + let mut acc = HllUnionAccumulator::new(false); + acc.update_batch(&[arr]).unwrap(); + assert!( + acc.size() > 1000, + "size() should account for the sketch heap allocation, got {}", + acc.size() + ); + } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2b58a9e791..9605f5dd16 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3452,4 +3452,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("hll_sketch_agg over all-null input estimates to 0, not NULL") { + assume(isSpark40Plus) + // Spark's HllSketchAgg/HllSketchEstimate are declared non-nullable: an empty or + // all-null group still produces a serialized empty sketch, and hll_sketch_estimate + // reads that as 0, never NULL. Guard against regressing to Binary(None) here. 
+ withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true") { + withParquetTable((0 until 100).map(_ => Tuple1(null.asInstanceOf[Integer])), "tbl") { + val df = sql("SELECT hll_sketch_estimate(hll_sketch_agg(_1)) FROM tbl") + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) + val row = df.collect().head + assert(!row.isNullAt(0), "expected a non-null estimate for an all-null group") + assert( + row.getLong(0) == 0, + s"expected estimate 0 for an all-null group, got ${row.getLong(0)}") + } + } + } + } From df37620d910bd817cb94c3f47da1a9a8b3579d26 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 2 Jul 2026 19:47:07 -0600 Subject: [PATCH 12/12] ci: run CI for HLL sketch functions