diff --git a/native/Cargo.lock b/native/Cargo.lock index adb764fbfb..4fc27f0299 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2141,6 +2141,7 @@ dependencies = [ "datafusion", "datafusion-comet-common", "datafusion-comet-jni-bridge", + "datasketches", "futures", "hex", "jni 0.22.4", @@ -2735,6 +2736,12 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datasketches" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c4cf71a36b46dcfc00e5014c0c20ccad2b1b6a008304d7d57d2749b2d41b3d" + [[package]] name = "debugid" version = "0.8.0" diff --git a/native/Cargo.toml b/native/Cargo.toml index 3e797eb968..0c79dad57d 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -49,6 +49,7 @@ datafusion-comet-proto = { path = "proto" } datafusion-comet-shuffle = { path = "shuffle" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.10" } +datasketches = { version = "0.3.0", features = ["hll"] } futures = "0.3.32" num = "0.4" rand = "0.10" diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..e63e09c8ee 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -130,8 +130,9 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, HllSketchAgg, HllUnionAgg, IfExpr, ListExtract, NormalizeNaNAndZero, + SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, + WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2648,11 +2649,22 @@ impl 
PhysicalPlanner { )); Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func) } + AggExprStruct::HllSketchAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(HllSketchAgg::new(expr.lg_config_k)); + Self::create_aggr_func_expr("hll_sketch_agg", schema, vec![child], func) + } AggExprStruct::CollectSet(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let func = AggregateUDF::new_from_impl(SparkCollectSet::new()); Self::create_aggr_func_expr("collect_set", schema, vec![child], func) } + AggExprStruct::HllUnionAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; + let func = + AggregateUDF::new_from_impl(HllUnionAgg::new(expr.allow_different_lg_config_k)); + Self::create_aggr_func_expr("hll_union_agg", schema, vec![child], func) + } } } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5b2a6ce9ee..978c0e0186 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -145,6 +145,8 @@ message AggExpr { BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; Percentile percentile = 18; + HllSketchAgg hllSketchAgg = 19; + HllUnionAgg hllUnionAgg = 20; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -271,6 +273,20 @@ enum BloomFilterVersion { BLOOM_FILTER_VERSION_V2 = 2; } +message HllSketchAgg { + // Child value expression (integral, string, or binary). + Expr child = 1; + // DataSketches lgConfigK (log2 of the number of buckets), Spark default 12. + int32 lg_config_k = 2; +} + +message HllUnionAgg { + // Child sketch expression (Binary column of serialized HLL sketches). + Expr child = 1; + // When false, Spark errors if input sketches have differing lgConfigK. 
+ bool allow_different_lg_config_k = 2; +} + message CollectSet { Expr child = 1; DataType datatype = 2; diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 800fe3ecb1..d54094ff40 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -44,6 +44,7 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +datasketches = { workspace = true } [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/agg_funcs/hll_sketch.rs b/native/spark-expr/src/agg_funcs/hll_sketch.rs new file mode 100644 index 0000000000..68c8d03943 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_sketch.rs @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Thin wrapper over the `datasketches` crate's HLL sketch, isolating all +//! crate-specific API so Comet's aggregate/scalar code depends on a stable +//! surface. Every sketch uses `HllType::Hll8` and DataSketches' +//! `DEFAULT_UPDATE_SEED` (9001), matching Spark's `HllSketchAgg`. +//! +//! Input hashing goes through the crate's `hash_value` wrappers +//! (`raw_bytes` for strings/binary without Rust's length prefix, `sign_extend` +//! 
for narrow integers) so the MurmurHash3-x64-128 input bytes are identical to +//! DataSketches-Java. This makes the sketches mutually readable with Spark. +//! +//! Note: the crate serializes List/Set (low-cardinality) modes in DataSketches +//! *compact* form, whereas Spark emits the *updatable* form. The bytes are +//! therefore not byte-identical to Spark's output for small inputs, but +//! DataSketches `deserialize` reads both forms, so estimates round-trip in both +//! directions. Comet must own both Partial and Final aggregation +//! (`supportsMixedPartialFinal = false`) so this compact intermediate is only +//! ever read back by Comet. + +use datafusion::error::DataFusionError; +use datasketches::hash_value::{raw_bytes, sign_extend}; +use datasketches::hll::{HllSketch, HllType, HllUnion}; + +/// A DataSketches HLL_8 sketch configured to match Spark's `HllSketchAgg`. +#[derive(Debug)] +pub struct SparkHllSketch { + inner: HllSketch, +} + +impl SparkHllSketch { + /// Create an empty HLL_8 sketch with the given `lgConfigK`. + pub fn new(lg_config_k: u8) -> Self { + Self { + inner: HllSketch::new(lg_config_k, HllType::Hll8), + } + } + + /// Update with a 64-bit integer. Spark widens narrower integrals to `long` + /// before hashing; callers should pass the already-widened value here. + /// Rust's `Hash` for `i64` writes 8 little-endian bytes with no prefix, + /// matching DataSketches-Java `update(long)`. + pub fn update_i64(&mut self, v: i64) { + self.inner.update(v); + } + + /// Update with a narrow signed integer, sign-extending to 64 bits exactly as + /// Spark's `toLong` does before hashing. 
+ pub fn update_i32(&mut self, v: i32) { + self.inner.update(sign_extend::from_i32(v)); + } + pub fn update_i16(&mut self, v: i16) { + self.inner.update(sign_extend::from_i16(v)); + } + pub fn update_i8(&mut self, v: i8) { + self.inner.update(sign_extend::from_i8(v)); + } + + /// Update with raw bytes (used for both StringType UTF-8 bytes and + /// BinaryType), hashing without Rust's slice length prefix. Empty inputs are + /// skipped, matching DataSketches (and Spark), which ignore empty values. + pub fn update_bytes(&mut self, v: &[u8]) { + if v.is_empty() { + return; + } + self.inner.update(raw_bytes::from_slice(v)); + } + + /// Serialize to DataSketches bytes (compact for List/Set modes, full for HLL + /// array modes). Readable by Spark's `hll_sketch_estimate` / `hll_union_agg`. + pub fn to_sketch_bytes(&self) -> Vec { + self.inner.serialize() + } + + /// Deserialize a DataSketches sketch (either compact or updatable form). + pub fn from_bytes(bytes: &[u8]) -> Result { + HllSketch::deserialize(bytes) + .map(|inner| Self { inner }) + .map_err(|e| DataFusionError::Internal(format!("invalid HLL sketch bytes: {e}"))) + } + + /// The configured `lgConfigK`. + pub fn lg_config_k(&self) -> u8 { + self.inner.lg_config_k() + } + + /// Raw cardinality estimate (caller rounds to `i64` for Spark). + pub fn estimate(&self) -> f64 { + self.inner.estimate() + } + + /// Merge another sketch into this one via a union, keeping HLL_8 output. + pub fn merge_sketch(&mut self, other: &SparkHllSketch) { + let mut u = HllUnion::new(self.lg_config_k()); + u.update(&self.inner); + u.update(&other.inner); + self.inner = u.to_sketch(HllType::Hll8); + } +} + +/// A DataSketches HLL union configured to match Spark's `HllUnionAgg`. +#[derive(Debug)] +pub struct SparkHllUnion { + inner: HllUnion, +} + +impl SparkHllUnion { + /// Create an empty union with the given `lgMaxK` (Spark fixes this at 12). 
+ pub fn new(lg_max_k: u8) -> Self { + Self { + inner: HllUnion::new(lg_max_k), + } + } + + /// Merge a sketch into the union. + pub fn merge(&mut self, sketch: &SparkHllSketch) { + self.inner.update(&sketch.inner); + } + + /// The union result as an HLL_8 sketch's serialized bytes. + pub fn to_sketch_bytes(&self) -> Vec { + self.inner.to_sketch(HllType::Hll8).serialize() + } +} + +/// Estimate the distinct count from serialized sketch bytes, rounded to the +/// nearest `i64` (Spark's `hll_sketch_estimate` returns a `Long`). +pub fn estimate_from_bytes(bytes: &[u8]) -> Result { + let sketch = SparkHllSketch::from_bytes(bytes)?; + Ok(sketch.estimate().round() as i64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sketch_roundtrips_and_estimates() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_i64(i); + } + let bytes = s.to_sketch_bytes(); + let est = estimate_from_bytes(&bytes).unwrap(); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} not within 3% of 1000" + ); + } + + #[test] + fn union_merges_two_sketches() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let mut u = SparkHllUnion::new(12); + u.merge(&a); + u.merge(&b); + let est = estimate_from_bytes(&u.to_sketch_bytes()).unwrap(); + assert!( + (est - 1500).abs() <= 45, + "union estimate {est} not within 3% of 1500" + ); + } + + /// A sketch built from raw bytes (StringType/BinaryType path) round-trips and + /// estimates. Empty inputs are skipped, so they do not affect the estimate. 
+ #[test] + fn byte_input_roundtrips_and_estimates() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_bytes(format!("val-{i}").as_bytes()); + } + s.update_bytes(b""); // skipped, no effect + let est = estimate_from_bytes(&s.to_sketch_bytes()).unwrap(); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} not within 3% of 1000" + ); + } + + /// Cross-engine regression guard: `testdata/hll_sketch_spark_lgk12.bin` was + /// produced by Spark 3.5's `hll_sketch_agg(id)` over `range(0, 1000)`. Comet + /// must read it and estimate the distinct count, proving the crate's + /// serialization stays DataSketches-Java compatible across crate upgrades. + /// (For this HLL_8 input the Comet-produced bytes are byte-identical to + /// Spark's; low-cardinality List/Set sketches differ in bytes but remain + /// mutually readable.) + #[test] + fn reads_spark_produced_sketch() { + let bytes = include_bytes!("testdata/hll_sketch_spark_lgk12.bin"); + let est = estimate_from_bytes(bytes).unwrap(); + assert!( + (est - 1000).abs() <= 30, + "estimate {est} of Spark-produced sketch not within 3% of 1000" + ); + } +} diff --git a/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs new file mode 100644 index 0000000000..3453b64657 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_sketch_agg.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::agg_funcs::hll_sketch::SparkHllSketch; +use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::array::BinaryArray; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{downcast_value, ScalarValue}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion::physical_plan::Accumulator; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HllSketchAgg { + signature: Signature, + lg_config_k: i32, +} + +impl HllSketchAgg { + pub fn new(lg_config_k: i32) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Utf8, + DataType::Binary, + ], + Volatility::Immutable, + ), + lg_config_k, + } + } +} + +impl AggregateUDFImpl for HllSketchAgg { + fn name(&self) -> &str { + "hll_sketch_agg" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Binary) + } + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::new(HllSketchAccumulator::new(self.lg_config_k as u8))) + } + fn state_fields(&self, _: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new("sketch", DataType::Binary, true))]) + } + fn groups_accumulator_supported(&self, _: AccumulatorArgs) -> bool { + false + } +} + +#[derive(Debug)] +pub struct HllSketchAccumulator { + sketch: 
SparkHllSketch, +} + +impl HllSketchAccumulator { + pub fn new(lg_config_k: u8) -> Self { + Self { + sketch: SparkHllSketch::new(lg_config_k), + } + } +} + +impl Accumulator for HllSketchAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|i| { + match ScalarValue::try_from_array(arr, i)? { + ScalarValue::Int8(Some(v)) => { + self.sketch.update_i64(v as i64); + } + ScalarValue::Int16(Some(v)) => { + self.sketch.update_i64(v as i64); + } + ScalarValue::Int32(Some(v)) => { + self.sketch.update_i64(v as i64); + } + ScalarValue::Int64(Some(v)) => { + self.sketch.update_i64(v); + } + ScalarValue::Utf8(Some(v)) => { + self.sketch.update_bytes(v.as_bytes()); + } + ScalarValue::Binary(Some(v)) => { + self.sketch.update_bytes(&v); + } + // Spark's HllSketchAgg ignores null inputs. + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Utf8(None) + | ScalarValue::Binary(None) => {} + other => { + return Err(DataFusionError::Internal(format!( + "hll_sketch_agg received an unsupported input type: {other:?}" + ))) + } + } + Ok(()) + }) + } + + fn evaluate(&mut self) -> Result { + // Spark's HllSketchAgg is declared non-nullable: an empty/all-null group + // still returns a serialized empty sketch (which estimates to 0), never NULL. + Ok(ScalarValue::Binary(Some(self.sketch.to_sketch_bytes()))) + } + + fn size(&self) -> usize { + // An HLL_8 sketch at lgConfigK=k can heap-allocate up to 1 << k bytes; + // account for that so memory reservation reflects actual usage. 
+ std::mem::size_of_val(self) + (1usize << self.sketch.lg_config_k() as usize) + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Binary(Some( + self.sketch.to_sketch_bytes(), + ))]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(states[0], BinaryArray); + for i in 0..arr.len() { + if arr.is_null(i) { + continue; + } + let peer = SparkHllSketch::from_bytes(arr.value(i))?; + // Merge peer into self by unioning; reuse SparkHllUnion via sketch merge. + self.sketch.merge_sketch(&peer); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use datafusion::physical_plan::Accumulator; + use std::sync::Arc; + + #[test] + fn accumulates_and_estimates() { + let mut acc = HllSketchAccumulator::new(12); + let arr = Arc::new(Int64Array::from((0..1000i64).collect::>())); + acc.update_batch(&[arr]).unwrap(); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected binary") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert!((est - 1000).abs() <= 30, "estimate {est}"); + } + + /// Spark's `HllSketchAgg` is non-nullable: an empty/all-null group still + /// produces a serialized empty sketch (estimate 0), not NULL. 
+ #[test] + fn empty_group_evaluates_to_empty_sketch_not_null() { + let mut acc = HllSketchAccumulator::new(12); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected Binary(Some(_)) for an empty group, got NULL") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert_eq!(est, 0, "empty sketch should estimate to 0, got {est}"); + } + + #[test] + fn size_accounts_for_sketch_heap() { + let mut acc = HllSketchAccumulator::new(12); + let arr = Arc::new(Int64Array::from((0..10000i64).collect::>())); + acc.update_batch(&[arr]).unwrap(); + assert!( + acc.size() > 1000, + "size() should account for the sketch heap allocation, got {}", + acc.size() + ); + } +} diff --git a/native/spark-expr/src/agg_funcs/hll_union_agg.rs b/native/spark-expr/src/agg_funcs/hll_union_agg.rs new file mode 100644 index 0000000000..aee7c044a1 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/hll_union_agg.rs @@ -0,0 +1,224 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::agg_funcs::hll_sketch::{SparkHllSketch, SparkHllUnion}; +use arrow::array::{Array, ArrayRef, BinaryArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{downcast_value, ScalarValue}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion::physical_plan::Accumulator; +use std::sync::Arc; + +// NOTE: matches bloom_filter_agg.rs for DataFusion 54.0.0 - no `as_any` method on +// AggregateUDFImpl, and PartialEq/Eq/Hash are required (DynEq/DynHash). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HllUnionAgg { + signature: Signature, + allow_different_lg_config_k: bool, +} + +impl HllUnionAgg { + pub fn new(allow_different_lg_config_k: bool) -> Self { + Self { + signature: Signature::uniform(1, vec![DataType::Binary], Volatility::Immutable), + allow_different_lg_config_k, + } + } +} + +impl AggregateUDFImpl for HllUnionAgg { + fn name(&self) -> &str { + "hll_union_agg" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Binary) + } + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::new(HllUnionAccumulator::new( + self.allow_different_lg_config_k, + ))) + } + fn state_fields(&self, _: StateFieldsArgs) -> Result> { + Ok(vec![Arc::new(Field::new("sketch", DataType::Binary, true))]) + } + fn groups_accumulator_supported(&self, _: AccumulatorArgs) -> bool { + false + } +} + +/// Default `lgMaxK` used by Spark's `new Union()` when constructing the empty +/// union returned for a group that never absorbed any sketch. +const DEFAULT_LG_K: u8 = 12; + +#[derive(Debug)] +pub struct HllUnionAccumulator { + // Spark's HllUnionAgg defers creating the Union until the first sketch is seen, + // then builds `new Union(sketch.getLgConfigK)` - so lgMaxK is NOT a fixed 12. 
+ union: Option, + allow_different_lg_config_k: bool, + seen_lg_config_k: Option, +} + +impl HllUnionAccumulator { + pub fn new(allow_different_lg_config_k: bool) -> Self { + Self { + union: None, + allow_different_lg_config_k, + seen_lg_config_k: None, + } + } + + fn absorb(&mut self, bytes: &[u8]) -> Result<()> { + let sketch = SparkHllSketch::from_bytes(bytes)?; + let k = sketch.lg_config_k(); + match self.seen_lg_config_k { + None => { + // Lazily instantiate the union from the first sketch's lgConfigK. + self.seen_lg_config_k = Some(k); + self.union = Some(SparkHllUnion::new(k)); + } + Some(prev) if prev != k && !self.allow_different_lg_config_k => { + return Err(DataFusionError::Execution(format!( + "Sketches have different lgConfigK values: {prev} and {k}. \ + Set allowDifferentLgConfigK to true to enable unions of different lgConfigK." + ))); + } + _ => {} + } + self.union.as_mut().unwrap().merge(&sketch); + Ok(()) + } +} + +impl Accumulator for HllUnionAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = downcast_value!(values[0], BinaryArray); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.absorb(arr.value(i))?; + } + } + Ok(()) + } + fn evaluate(&mut self) -> Result { + // Spark's HllUnionAgg is declared non-nullable: an empty/all-null group + // still returns the serialized bytes of an empty `new Union()` (default + // lgMaxK), which estimates to 0, never NULL. + match &self.union { + Some(u) => Ok(ScalarValue::Binary(Some(u.to_sketch_bytes()))), + None => Ok(ScalarValue::Binary(Some( + SparkHllUnion::new(DEFAULT_LG_K).to_sketch_bytes(), + ))), + } + } + fn size(&self) -> usize { + // An HLL_8 sketch at lgConfigK=k can heap-allocate up to 1 << k bytes; + // account for that so memory reservation reflects actual usage. 
+ std::mem::size_of_val(self) + + self + .seen_lg_config_k + .map(|k| 1usize << k as usize) + .unwrap_or(0) + } + fn state(&mut self) -> Result> { + match &self.union { + Some(u) => Ok(vec![ScalarValue::Binary(Some(u.to_sketch_bytes()))]), + None => Ok(vec![ScalarValue::Binary(Some( + SparkHllUnion::new(DEFAULT_LG_K).to_sketch_bytes(), + ))]), + } + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(states[0], BinaryArray); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.absorb(arr.value(i))?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agg_funcs::hll_sketch::SparkHllSketch; + use arrow::array::BinaryArray; + use datafusion::physical_plan::Accumulator; + use std::sync::Arc; + + #[test] + fn unions_sketch_column() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![ + Some(a.to_sketch_bytes().as_slice()), + Some(b.to_sketch_bytes().as_slice()), + ])); + let mut acc = HllUnionAccumulator::new(false); + acc.update_batch(&[arr]).unwrap(); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!() + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert!((est - 1500).abs() <= 45, "union estimate {est}"); + } + + /// Spark's `HllUnionAgg` is non-nullable: an empty/all-null group still + /// produces the serialized bytes of an empty union (estimate 0), not NULL. 
+ #[test] + fn empty_group_evaluates_to_empty_sketch_not_null() { + let mut acc = HllUnionAccumulator::new(false); + let ScalarValue::Binary(Some(bytes)) = acc.evaluate().unwrap() else { + panic!("expected Binary(Some(_)) for an empty group, got NULL") + }; + let est = crate::agg_funcs::estimate_from_bytes(&bytes).unwrap(); + assert_eq!(est, 0, "empty union should estimate to 0, got {est}"); + } + + #[test] + fn size_accounts_for_sketch_heap() { + let mut a = SparkHllSketch::new(12); + for i in 0..10000i64 { + a.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![Some( + a.to_sketch_bytes().as_slice(), + )])); + let mut acc = HllUnionAccumulator::new(false); + acc.update_batch(&[arr]).unwrap(); + assert!( + acc.size() > 1000, + "size() should account for the sketch heap allocation, got {}", + acc.size() + ); + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..f74e88b848 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,9 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod hll_sketch; +mod hll_sketch_agg; +mod hll_union_agg; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +32,9 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use hll_sketch::{estimate_from_bytes, SparkHllSketch, SparkHllUnion}; +pub use hll_sketch_agg::HllSketchAgg; +pub use hll_union_agg::HllUnionAgg; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin b/native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin new file mode 100644 index 0000000000..761477d5f6 Binary files /dev/null and b/native/spark-expr/src/agg_funcs/testdata/hll_sketch_spark_lgk12.bin differ diff --git a/native/spark-expr/src/comet_scalar_funcs.rs 
b/native/spark-expr/src/comet_scalar_funcs.rs index 42ee72c82a..15de6bac49 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. use crate::hash_funcs::*; +use crate::hll_scalar::{spark_hll_sketch_estimate, spark_hll_union}; use crate::json_funcs::JsonArrayLength; use crate::map_funcs::spark_map_sort; use crate::math_funcs::abs::abs; @@ -221,6 +222,14 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_map_sort); make_comet_scalar_udf!("spark_map_sort", func, without data_type) } + "hll_sketch_estimate" => { + let func = Arc::new(|args: &[ColumnarValue]| spark_hll_sketch_estimate(args)); + make_comet_scalar_udf!("hll_sketch_estimate", func, without data_type) + } + "hll_union" => { + let func = Arc::new(|args: &[ColumnarValue]| spark_hll_union(args)); + make_comet_scalar_udf!("hll_union", func, without data_type) + } "to_time" => { make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error) } diff --git a/native/spark-expr/src/hll_scalar.rs b/native/spark-expr/src/hll_scalar.rs new file mode 100644 index 0000000000..f9a9648b20 --- /dev/null +++ b/native/spark-expr/src/hll_scalar.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::agg_funcs::estimate_from_bytes; +use arrow::array::{Array, BinaryArray, Int64Array}; +use datafusion::common::{DataFusionError, Result}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark hll_sketch_estimate: Binary sketch -> Long distinct-count estimate. +pub fn spark_hll_sketch_estimate(args: &[ColumnarValue]) -> Result { + let arrays = ColumnarValue::values_to_arrays(args)?; + let input = arrays[0].as_any().downcast_ref::().unwrap(); + let mut out = Int64Array::builder(input.len()); + for i in 0..input.len() { + if input.is_null(i) { + out.append_null(); + } else { + out.append_value(estimate_from_bytes(input.value(i))?); + } + } + Ok(ColumnarValue::Array(Arc::new(out.finish()))) +} + +// Spark's HllUnion is a TernaryExpression (first, second, third=allowDifferentLgConfigK). +// It builds `new Union(min(k1, k2))`, throws when the two sketches have different +// lgConfigK and the flag is false, and returns an HLL_8 sketch. +/// Spark hll_union(first, second, allowDifferentLgConfigK): union two sketch columns. 
+pub fn spark_hll_union(args: &[ColumnarValue]) -> Result { + use crate::agg_funcs::{SparkHllSketch, SparkHllUnion}; + use arrow::array::BooleanArray; + let arrays = ColumnarValue::values_to_arrays(args)?; + let a = arrays[0].as_any().downcast_ref::().unwrap(); + let b = arrays[1].as_any().downcast_ref::().unwrap(); + let allow = arrays[2].as_any().downcast_ref::().unwrap(); + let mut out = arrow::array::BinaryBuilder::new(); + for i in 0..a.len() { + if a.is_null(i) || b.is_null(i) { + out.append_null(); + continue; + } + let sa = SparkHllSketch::from_bytes(a.value(i))?; + let sb = SparkHllSketch::from_bytes(b.value(i))?; + let allow_i = !allow.is_null(i) && allow.value(i); + if !allow_i && sa.lg_config_k() != sb.lg_config_k() { + return Err(DataFusionError::Execution(format!( + "Sketches have different lgConfigK values: {} and {}. \ + Set allowDifferentLgConfigK to true to enable unions of different lgConfigK.", + sa.lg_config_k(), + sb.lg_config_k() + ))); + } + // Spark builds `new Union(min(k1, k2))`. 
+ let mut u = SparkHllUnion::new(sa.lg_config_k().min(sb.lg_config_k())); + u.merge(&sa); + u.merge(&sb); + out.append_value(u.to_sketch_bytes()); + } + Ok(ColumnarValue::Array(Arc::new(out.finish()))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agg_funcs::SparkHllSketch; + + #[test] + fn estimates_from_sketch_column() { + let mut s = SparkHllSketch::new(12); + for i in 0..1000i64 { + s.update_i64(i); + } + let arr = Arc::new(BinaryArray::from(vec![Some( + s.to_sketch_bytes().as_slice(), + )])); + let out = spark_hll_sketch_estimate(&[ColumnarValue::Array(arr)]).unwrap(); + let ColumnarValue::Array(a) = out else { + panic!() + }; + let est = a.as_any().downcast_ref::().unwrap().value(0); + assert!((est - 1000).abs() <= 30, "estimate {est}"); + } +} + +#[cfg(test)] +mod union_tests { + use super::*; + use crate::agg_funcs::SparkHllSketch; + use arrow::array::BooleanArray; + + #[test] + fn unions_two_sketch_columns() { + let mut a = SparkHllSketch::new(12); + for i in 0..1000i64 { + a.update_i64(i); + } + let mut b = SparkHllSketch::new(12); + for i in 500..1500i64 { + b.update_i64(i); + } + let aa = Arc::new(BinaryArray::from(vec![Some( + a.to_sketch_bytes().as_slice(), + )])); + let bb = Arc::new(BinaryArray::from(vec![Some( + b.to_sketch_bytes().as_slice(), + )])); + let allow = Arc::new(BooleanArray::from(vec![false])); + let out = spark_hll_union(&[ + ColumnarValue::Array(aa), + ColumnarValue::Array(bb), + ColumnarValue::Array(allow), + ]) + .unwrap(); + let ColumnarValue::Array(arr) = out else { + panic!() + }; + let est = estimate_from_bytes(arr.as_any().downcast_ref::().unwrap().value(0)) + .unwrap(); + assert!((est - 1500).abs() <= 45, "estimate {est}"); + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 174a4ada9a..2260953155 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -59,6 +59,9 @@ pub mod jvm_udf; mod conditional_funcs; mod conversion_funcs; +mod hll_scalar; +pub use 
hll_scalar::spark_hll_sketch_estimate; +pub use hll_scalar::spark_hll_union; mod map_funcs; pub use map_funcs::spark_map_sort; mod math_funcs; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b752f41d74..379e695192 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -384,27 +384,30 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { /** * Mapping of Spark aggregate expression class to Comet expression handler. */ - val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( - classOf[Average] -> CometAverage, - classOf[BitAndAgg] -> CometBitAndAgg, - classOf[BitOrAgg] -> CometBitOrAgg, - classOf[BitXorAgg] -> CometBitXOrAgg, - classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, - classOf[CollectSet] -> CometCollectSet, - classOf[Corr] -> CometCorr, - classOf[Count] -> CometCount, - classOf[CovPopulation] -> CometCovPopulation, - classOf[CovSample] -> CometCovSample, - classOf[First] -> CometFirst, - classOf[Last] -> CometLast, - classOf[Max] -> CometMax, - classOf[Min] -> CometMin, - classOf[Percentile] -> CometPercentile, - classOf[StddevPop] -> CometStddevPop, - classOf[StddevSamp] -> CometStddevSamp, - classOf[Sum] -> CometSum, - classOf[VariancePop] -> CometVariancePop, - classOf[VarianceSamp] -> CometVarianceSamp) + val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = { + val base: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( + classOf[Average] -> CometAverage, + classOf[BitAndAgg] -> CometBitAndAgg, + classOf[BitOrAgg] -> CometBitOrAgg, + classOf[BitXorAgg] -> CometBitXOrAgg, + classOf[BloomFilterAggregate] -> CometBloomFilterAggregate, + classOf[CollectSet] -> CometCollectSet, + classOf[Corr] -> CometCorr, + classOf[Count] -> CometCount, + classOf[CovPopulation] -> CometCovPopulation, + 
classOf[CovSample] -> CometCovSample, + classOf[First] -> CometFirst, + classOf[Last] -> CometLast, + classOf[Max] -> CometMax, + classOf[Min] -> CometMin, + classOf[Percentile] -> CometPercentile, + classOf[StddevPop] -> CometStddevPop, + classOf[StddevSamp] -> CometStddevSamp, + classOf[Sum] -> CometSum, + classOf[VariancePop] -> CometVariancePop, + classOf[VarianceSamp] -> CometVarianceSamp) + base ++ sparkVersionSpecificAggregates + } /** * Returns true if all aggregate expressions in the list have intermediate buffer formats that diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 1ad9ec75bf..43837f266f 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map.empty def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index 0be1185f59..42290ab565 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -44,6 +44,8 @@ trait CometExprShim { Map(classOf[ToPrettyString] -> CometToPrettyString) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map.empty + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map.empty def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala new file mode 100644 index 0000000000..5141e0c952 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchAgg.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HllSketchAgg} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BinaryType, IntegerType, LongType, StringType} + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +// In Spark 4.0, HllSketchAgg's fields are `left` (the child value expression) and `right` (the +// lgConfigK expression), not `child`/`lgConfigKExpression`. Accepted input types are Integer, +// Long, String, and Binary (Byte/Short are not accepted by Spark's HllSketchAgg). +object CometHllSketchAgg extends CometAggregateExpressionSerde[HllSketchAgg] { + + // DataSketches valid lgConfigK range; outside this Spark itself throws, so we + // fall back rather than forward an out-of-range value to native. + private val MinLgConfigK = 4 + private val MaxLgConfigK = 21 + + private val nonLiteralLgConfigKReason = + "The lgConfigK argument must be a foldable literal." + private val inputTypeReason = + "Only int, long, string, and binary input types are supported." + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ " + + "slightly from Spark." 
+ + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralLgConfigKReason, inputTypeReason) + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllSketchAgg): SupportLevel = { + if (!expr.right.foldable) { + return Unsupported(Some(nonLiteralLgConfigKReason)) + } + val lgConfigK = expr.right.eval() match { + case i: Int => i + case l: Long if l.isValidInt => l.toInt // a wrapping narrow could land inside the valid range + case _ => return Unsupported(Some(nonLiteralLgConfigKReason)) + } + if (lgConfigK < MinLgConfigK || lgConfigK > MaxLgConfigK) { + return Unsupported(Some(s"lgConfigK must be in [$MinLgConfigK, $MaxLgConfigK]")) + } + expr.left.dataType match { + case IntegerType | LongType | StringType | BinaryType => + Incompatible(Some(incompatReason)) + case _ => Unsupported(Some(inputTypeReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: HllSketchAgg, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val childExpr = exprToProto(expr.left, inputs, binding) + val lgConfigK = expr.right.eval() match { + case i: Int => i + case l: Long if l.isValidInt => l.toInt // out-of-Int-range values take the fallback arm + case other => + withFallbackReason(aggExpr, s"Unsupported lgConfigK literal: $other", expr.left) + return None + } + if (childExpr.isDefined) { + val builder = ExprOuterClass.HllSketchAgg.newBuilder() + builder.setChild(childExpr.get) + builder.setLgConfigK(lgConfigK) + Some(ExprOuterClass.AggExpr.newBuilder().setHllSketchAgg(builder).build()) + } else { + withFallbackReason(aggExpr, expr.left) + None + } + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala new file mode 100644 index 0000000000..454b53eed8 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllSketchEstimate.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, HllSketchEstimate} +import org.apache.spark.sql.types.LongType + +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} + +object CometHllSketchEstimate extends CometExpressionSerde[HllSketchEstimate] { + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL estimates may differ slightly from Spark." 
+ + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllSketchEstimate): SupportLevel = + Incompatible(Some(incompatReason)) + + override def convert( + expr: HllSketchEstimate, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val estimateExpr = scalarFunctionExprToProtoWithReturnType( + "hll_sketch_estimate", + LongType, + failOnError = false, + childExpr) + optExprWithFallbackReason(estimateExpr, expr, expr.child) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala new file mode 100644 index 0000000000..af7b330966 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnion.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, HllUnion} +import org.apache.spark.sql.types.BinaryType + +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} + +// Spark 4.0 HllUnion is a TernaryExpression: `first`, `second` (binary sketches), +// `third` (allowDifferentLgConfigK boolean, default Literal(false)). All three are +// passed to the native `hll_union`, which enforces the flag and unions with min lgConfigK. +object CometHllUnion extends CometExpressionSerde[HllUnion] { + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ slightly from Spark." + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllUnion): SupportLevel = { + if (!expr.third.foldable) { + Unsupported(Some("The allowDifferentLgConfigK argument must be a foldable literal.")) + } else { + Incompatible(Some(incompatReason)) + } + } + override def convert( + expr: HllUnion, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val first = exprToProtoInternal(expr.first, inputs, binding) + val second = exprToProtoInternal(expr.second, inputs, binding) + val third = exprToProtoInternal(expr.third, inputs, binding) + val unionExpr = scalarFunctionExprToProtoWithReturnType( + "hll_union", + BinaryType, + failOnError = false, + first, + second, + third) + optExprWithFallbackReason(unionExpr, expr, expr.first, expr.second, expr.third) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala new file mode 100644 index 0000000000..156f26e675 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometHllUnionAgg.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
+ * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HllUnionAgg} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +// IMPORTANT: In Spark 4.0, HllUnionAgg's fields are `left` (the child binary sketch +// expression) and `right` (the allowDifferentLgConfigK boolean expression) - NOT +// `child`/`allowDifferentLgConfigKExpression`. (Verified against Spark 4.0 source.) +object CometHllUnionAgg extends CometAggregateExpressionSerde[HllUnionAgg] { + + private val incompatReason = + "Comet uses a Rust DataSketches port; HLL sketch bytes and estimates may differ slightly " + + "from Spark." 
+ + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: HllUnionAgg): SupportLevel = { + if (!expr.right.foldable) { + Unsupported(Some("The allowDifferentLgConfigK argument must be a foldable literal.")) + } else { + Incompatible(Some(incompatReason)) + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: HllUnionAgg, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val childExpr = exprToProto(expr.left, inputs, binding) + val allow = expr.right.eval() match { + case b: Boolean => b + case other => + withFallbackReason( + aggExpr, + s"Unsupported allowDifferentLgConfigK literal: $other", + expr.left) + return None + } + if (childExpr.isDefined) { + val builder = ExprOuterClass.HllUnionAgg.newBuilder() + builder.setChild(childExpr.get) + builder.setAllowDifferentLgConfigK(allow) + Some(ExprOuterClass.AggExpr.newBuilder().setHllUnionAgg(builder).build()) + } else { + withFallbackReason(aggExpr, expr.left) + None + } + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 7efd17f68a..93cc2ef184 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{HllSketchAgg, HllUnionAgg} import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import 
org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} +import org.apache.comet.serde.{CometAggregateExpressionSerde, CometExpressionSerde, CometHllSketchAgg, CometHllSketchEstimate, CometHllUnion, CometHllUnionAgg, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -46,9 +47,14 @@ trait Spark4xCometExprShim extends CometExprShim4x { def sparkVersionSpecificMathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[WidthBucket] -> CometWidthBucket) def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = - Map(classOf[ToPrettyString] -> CometToPrettyString) + Map( + classOf[ToPrettyString] -> CometToPrettyString, + classOf[HllSketchEstimate] -> CometHllSketchEstimate, + classOf[HllUnion] -> CometHllUnion) def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[MapSort] -> CometMapSort) + def sparkVersionSpecificAggregates: Map[Class[_], CometAggregateExpressionSerde[_]] = + Map(classOf[HllSketchAgg] -> CometHllSketchAgg, classOf[HllUnionAgg] -> CometHllUnionAgg) def sparkVersionSpecificExprToProtoInternal( expr: Expression, diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 390a7c4908..9605f5dd16 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3380,4 +3380,96 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("hll_sketch_agg and hll_sketch_estimate (incompatible, opt-in)") { + assume(isSpark40Plus) + // HLL is approximate: 
Comet's Rust DataSketches estimator differs slightly from + // Spark's after a merge, so these functions are Incompatible. Opt in, assert the + // query runs natively (no fallback), and that the estimate is within HLL error of + // the TRUE distinct count (700). Do NOT compare bit-exactly to Spark. + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true") { + withParquetTable((0 until 1000).map(i => (i % 700, i)), "tbl") { + def checkEstimate(query: String): Unit = { + val df = sql(query) + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) + val est = df.collect().head.getLong(0) + assert( + math.abs(est - 700).toDouble / 700 <= 0.05, + s"estimate $est not within 5% of the true distinct count 700 for: $query") + } + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(_1)) FROM tbl") + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(_1, 14)) FROM tbl") + checkEstimate("SELECT hll_sketch_estimate(hll_sketch_agg(cast(_1 as string))) FROM tbl") + } + } + } + + test("hll_union_agg and hll_union (incompatible, opt-in)") { + assume(isSpark40Plus) + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true", + "spark.comet.expression.HllUnionAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllUnion.allowIncompatible" -> "true") { + withParquetTable((0 until 1000).map(i => (i % 3, i)), "tbl") { + // hll_union_agg: union the per-group sketches -> ~1000 distinct. 
+ val aggDf = sql( + "SELECT hll_sketch_estimate(hll_union_agg(s)) FROM " + + "(SELECT _1 AS g, hll_sketch_agg(_2) AS s FROM tbl GROUP BY _1)") + checkCometOperators(stripAQEPlan(aggDf.queryExecution.executedPlan)) + val aggEst = aggDf.collect().head.getLong(0) + assert(math.abs(aggEst - 1000).toDouble / 1000 <= 0.05, s"union_agg estimate $aggEst") + + // hll_union: union two disjoint group sketches -> ~667 distinct. + val unionDf = sql( + "SELECT hll_sketch_estimate(hll_union(a.s, b.s)) FROM " + + "(SELECT hll_sketch_agg(_2) AS s FROM tbl WHERE _1 = 0) a, " + + "(SELECT hll_sketch_agg(_2) AS s FROM tbl WHERE _1 = 1) b") + checkCometOperators(stripAQEPlan(unionDf.queryExecution.executedPlan)) + val unionEst = unionDf.collect().head.getLong(0) + assert(math.abs(unionEst - 667).toDouble / 667 <= 0.05, s"union estimate $unionEst") + } + } + } + + test("hll_union_agg rejects different lgConfigK when not allowed") { + assume(isSpark40Plus) + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllUnionAgg.allowIncompatible" -> "true") { + withParquetTable((0 until 100).map(i => Tuple1(i)), "tbl") { + // A lgConfigK=10 sketch unioned with a lgConfigK=12 sketch (allowDifferentLgConfigK + // defaults false) must throw in BOTH Spark and Comet. 
+ val df = sql( + "SELECT hll_union_agg(s) FROM (" + + " SELECT hll_sketch_agg(_1, 10) AS s FROM tbl UNION ALL" + + " SELECT hll_sketch_agg(_1, 12) AS s FROM tbl)") + val (sparkErr, cometErr) = checkSparkAnswerMaybeThrows(df) + assert(sparkErr.isDefined, "expected Spark to throw on different lgConfigK") + assert(cometErr.isDefined, "expected Comet to throw on different lgConfigK") + } + } + } + + test("hll_sketch_agg over all-null input estimates to 0, not NULL") { + assume(isSpark40Plus) + // Spark's HllSketchAgg/HllSketchEstimate are declared non-nullable: an empty or + // all-null group still produces a serialized empty sketch, and hll_sketch_estimate + // reads that as 0, never NULL. Guard against regressing to Binary(None) here. + withSQLConf( + "spark.comet.expression.HllSketchAgg.allowIncompatible" -> "true", + "spark.comet.expression.HllSketchEstimate.allowIncompatible" -> "true") { + withParquetTable((0 until 100).map(_ => Tuple1(null.asInstanceOf[Integer])), "tbl") { + val df = sql("SELECT hll_sketch_estimate(hll_sketch_agg(_1)) FROM tbl") + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) + val row = df.collect().head + assert(!row.isNullAt(0), "expected a non-null estimate for an all-null group") + assert( + row.getLong(0) == 0, + s"expected estimate 0 for an all-null group, got ${row.getLong(0)}") + } + } + } + }