From 3d0da09769a8ec6cf16af3a61179017a8ed5cce4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 30 Jun 2026 12:12:45 -0600 Subject: [PATCH 1/2] feat: implement regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy aggregates Add native support for the six simple linear regression aggregates that previously fell back to Spark. regr_avgx, regr_avgy and regr_count were already accelerated because Spark rewrites them to Average/Count. The native accumulators are composed from Comet's existing Spark-compatible covariance and variance accumulators so the partial aggregation state matches the buffer layout Spark's planner expects between partial and final stages: RegrReplacement (regr_sxx/regr_syy) -> 3 fields, Covariance (regr_sxy) -> 4, PearsonCorrelation (regr_r2) -> 6, and the slope/intercept composite -> 7. regr_r2 matches Spark's behavior of returning 1.0 when the dependent variable is constant but the independent variable varies (a perfect horizontal fit), which differs from DataFusion's regr_r2. --- docs/source/user-guide/latest/expressions.md | 12 +- native/core/src/execution/planner.rs | 22 +- native/proto/src/proto/expr.proto | 19 + native/spark-expr/src/agg_funcs/mod.rs | 2 + native/spark-expr/src/agg_funcs/regr.rs | 591 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 5 + .../org/apache/comet/serde/aggregates.scala | 121 +++- .../sql-tests/expressions/aggregate/regr.sql | 82 ++- 8 files changed, 817 insertions(+), 37 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/regr.rs diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index b8addccd32..2e62839475 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -107,12 +107,12 @@ The tables below list every Spark built-in expression with its current status. | `regr_avgx` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_avgy` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | | `regr_count` | ✅ | Native: Spark rewrites to `Count` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) | -| `regr_intercept` | 🔜 | Falls back; can reuse `covar_pop`/`var_pop` accumulators ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | -| `regr_r2` | 🔜 | Falls back; can reuse the `corr` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | -| `regr_slope` | 🔜 | Falls back; can reuse `covar_pop`/`var_pop` accumulators ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | -| `regr_sxx` | 🔜 | Falls back; can reuse `var_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | -| `regr_sxy` | 🔜 | Falls back; can reuse `covar_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | -| `regr_syy` | 🔜 | Falls back; can reuse `var_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) | +| `regr_intercept` | ✅ | | +| `regr_r2` | ✅ | | +| `regr_slope` | ✅ | | +| `regr_sxx` | ✅ | | +| `regr_sxy` | ✅ | | +| `regr_syy` | ✅ | | | `skewness` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) | | `some` | ✅ | | | `std` | ✅ | | diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 7c06e60716..a45196cb1a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -129,8 +129,8 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::{ jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, - GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, - ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, + GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, Regr, RegrType, SparkCastOptions, + Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -2601,6 +2601,24 @@ impl PhysicalPlanner { )); Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func) } + AggExprStruct::Regr(expr) => { + let child1 = + self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?; + let child2 = + self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?; + let (regr_type, name) = match expr.regr_type() { + spark_expression::regr::RegrType::Slope => (RegrType::Slope, "regr_slope"), + spark_expression::regr::RegrType::Intercept => { + (RegrType::Intercept, "regr_intercept") + } + spark_expression::regr::RegrType::R2 => (RegrType::R2, "regr_r2"), + spark_expression::regr::RegrType::Sxx => (RegrType::SXX, "regr_sxx"), + spark_expression::regr::RegrType::Syy => (RegrType::SYY, "regr_syy"), + spark_expression::regr::RegrType::Sxy => (RegrType::SXY, "regr_sxy"), + }; + let func = AggregateUDF::new_from_impl(Regr::new(regr_type, name)); + Self::create_aggr_func_expr(name, schema, vec![child1, child2], func) + } AggExprStruct::BloomFilterAgg(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let num_items = diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 90e3d87032..97e6746296 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -143,6 +143,7 @@ message AggExpr { Correlation correlation = 15; BloomFilterAgg bloomFilterAgg = 16; CollectSet collectSet = 17; + Regr regr = 18; } // Optional filter expression for SQL FILTER (WHERE ...) clause. @@ -244,6 +245,24 @@ message Correlation { DataType datatype = 4; } +// Simple linear regression aggregates (regr_slope, regr_intercept, regr_r2, +// regr_sxx, regr_syy, regr_sxy). child1 is the dependent variable (y) and +// child2 is the independent variable (x). +message Regr { + enum RegrType { + SLOPE = 0; + INTERCEPT = 1; + R2 = 2; + SXX = 3; + SYY = 4; + SXY = 5; + } + Expr child1 = 1; + Expr child2 = 2; + RegrType regr_type = 3; + DataType datatype = 4; +} + message BloomFilterAgg { Expr child = 1; Expr numItems = 2; diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 2a0322e46c..4830532b5f 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod covariance; +mod regr; mod stddev; mod sum_decimal; mod sum_int; @@ -29,6 +30,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use covariance::Covariance; +pub use regr::{Regr, RegrType}; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use sum_int::SumInteger; diff --git a/native/spark-expr/src/agg_funcs/regr.rs b/native/spark-expr/src/agg_funcs/regr.rs new file mode 100644 index 0000000000..12977496a4 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/regr.rs @@ -0,0 +1,591 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible `regr_*` simple linear regression aggregates. +//! +//! For `regr_*(y, x)`, the first argument `y` is the dependent variable and the +//! second argument `x` is the independent variable. Only rows where both `y` +//! and `x` are non-null take part in the aggregation. +//! +//! Each function is composed from Comet's Spark-compatible +//! [`CovarianceAccumulator`] and [`VarianceAccumulator`] so that the partial +//! aggregation state Comet emits matches the buffer layout Spark's planner +//! expects between the partial and final aggregation stages: +//! +//! | function | Spark aggregate | state fields | +//! |---------------------------|------------------------|--------------| +//! | `regr_sxx` / `regr_syy` | `RegrReplacement` | 3 | +//! | `regr_sxy` | `Covariance` | 4 | +//! | `regr_r2` | `PearsonCorrelation` | 6 | +//! | `regr_slope`/`_intercept` | composite cov + var | 7 | + +use arrow::array::{Array, ArrayRef}; +use arrow::compute::{and, filter, is_not_null}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion::physical_expr::expressions::format_state_name; +use datafusion::physical_expr::expressions::StatsType; +use std::sync::Arc; + +use crate::agg_funcs::covariance::CovarianceAccumulator; +use crate::agg_funcs::variance::VarianceAccumulator; + +/// The kind of linear-regression statistic to compute. +/// +/// `regr_count`, `regr_avgx` and `regr_avgy` are rewritten by Spark to +/// `Count`/`Average`, so they never reach this accumulator. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RegrType { + /// Slope of the least-squares regression line: `cov_pop(x, y) / var_pop(x)`. + Slope, + /// Intercept of the regression line: `mean_y - slope * mean_x`. + Intercept, + /// Coefficient of determination (R squared). + R2, + /// Sum of squares of the independent variable: `m2(x)`. + SXX, + /// Sum of squares of the dependent variable: `m2(y)`. + SYY, + /// Sum of products of the paired deviations: `ck`. + SXY, +} + +/// Comet's Spark-compatible `regr_*` aggregate UDF. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct Regr { + name: String, + signature: Signature, + regr_type: RegrType, +} + +impl Regr { + pub fn new(regr_type: RegrType, name: impl Into) -> Self { + Self { + name: name.into(), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), + regr_type, + } + } + + /// Number of `m2` / covariance sub-states this statistic carries, which + /// determines the partial aggregation buffer width. + fn num_state_fields(&self) -> usize { + match self.regr_type { + // RegrReplacement (CentralMomentAgg): count, mean, m2 + RegrType::SXX | RegrType::SYY => 3, + // Covariance: count, mean1, mean2, algo_const + RegrType::SXY => 4, + // PearsonCorrelation: count, mean1, mean2, algo_const, m2_1, m2_2 + RegrType::R2 => 6, + // CovPopulation (4) ++ VariancePop (3) + RegrType::Slope | RegrType::Intercept => 7, + } + } +} + +impl AggregateUDFImpl for Regr { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + let acc: Box = match self.regr_type { + RegrType::SXX | RegrType::SYY => Box::new(RegrMomentAccumulator::try_new()?), + RegrType::SXY => Box::new(RegrCovAccumulator::try_new()?), + RegrType::R2 => Box::new(RegrR2Accumulator::try_new()?), + RegrType::Slope => Box::new(RegrLineAccumulator::try_new(false)?), + RegrType::Intercept => Box::new(RegrLineAccumulator::try_new(true)?), + }; + Ok(acc) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let names: &[&str] = match self.num_state_fields() { + 3 => &["count", "mean", "m2"], + 4 => &["count", "mean1", "mean2", "algo_const"], + 6 => &["count", "mean1", "mean2", "algo_const", "m2_1", "m2_2"], + _ => &[ + "count", + "mean1", + "mean2", + "algo_const", + "var_count", + "var_mean", + "var_m2", + ], + }; + Ok(names + .iter() + .map(|n| { + Arc::new(Field::new( + format_state_name(args.name, n), + DataType::Float64, + true, + )) + }) + .collect()) + } +} + +/// Keep only the rows where both inputs are non-null, mirroring Spark's +/// "regr functions operate on non-null pairs" rule. +fn filter_pairs(values: &[ArrayRef]) -> Result> { + if values[0].null_count() == 0 && values[1].null_count() == 0 { + return Ok(values.to_vec()); + } + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + Ok(vec![filter(&values[0], &mask)?, filter(&values[1], &mask)?]) +} + +/// `regr_sxx` / `regr_syy`: `m2` of the (already null-filtered) child column. +/// Mirrors Spark's `RegrReplacement` (a `CentralMomentAgg` whose evaluate is +/// `m2`). State layout matches `VarianceAccumulator` (count, mean, m2). +#[derive(Debug)] +struct RegrMomentAccumulator { + var: VarianceAccumulator, +} + +impl RegrMomentAccumulator { + fn try_new() -> Result { + Ok(Self { + var: VarianceAccumulator::try_new(StatsType::Population, false)?, + }) + } +} + +impl Accumulator for RegrMomentAccumulator { + fn state(&mut self) -> Result> { + self.var.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // child1 == child2 (Spark's RegrReplacement is single-arg); use one. + self.var.update_batch(&values[0..1]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.var.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.var.get_count() == 0.0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.var.get_m2()))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.var) + self.var.size() + } +} + +/// `regr_sxy`: the co-moment `ck` of the non-null pairs. Mirrors Spark's +/// `RegrSXY` (a population `Covariance` whose evaluate is `ck`). State layout +/// matches `CovarianceAccumulator`. +#[derive(Debug)] +struct RegrCovAccumulator { + covar: CovarianceAccumulator, +} + +impl RegrCovAccumulator { + fn try_new() -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population, false)?, + }) + } +} + +impl Accumulator for RegrCovAccumulator { + fn state(&mut self) -> Result> { + self.covar.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // CovarianceAccumulator already skips pairs where either side is null. + self.covar.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.covar.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.covar.get_count() == 0.0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.covar.get_algo_const()))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size() + } +} + +/// `regr_r2`: the coefficient of determination. Mirrors Spark's `RegrR2` +/// (a `PearsonCorrelation`). State layout matches `CorrelationAccumulator`: +/// count, mean1, mean2, algo_const, m2(y), m2(x). +/// +/// Spark's evaluate differs from DataFusion in one degenerate case: when the +/// dependent variable `y` is constant but `x` varies, Spark returns `1.0` +/// (a horizontal line is a perfect fit) where DataFusion returns `null`. +#[derive(Debug)] +struct RegrR2Accumulator { + covar: CovarianceAccumulator, + var_y: VarianceAccumulator, + var_x: VarianceAccumulator, +} + +impl RegrR2Accumulator { + fn try_new() -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population, false)?, + var_y: VarianceAccumulator::try_new(StatsType::Population, false)?, + var_x: VarianceAccumulator::try_new(StatsType::Population, false)?, + }) + } +} + +impl Accumulator for RegrR2Accumulator { + fn state(&mut self) -> Result> { + let c = self.covar.state()?; + Ok(vec![ + c[0].clone(), + c[1].clone(), + c[2].clone(), + c[3].clone(), + ScalarValue::from(self.var_y.get_m2()), + ScalarValue::from(self.var_x.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // values[0] = y (dependent), values[1] = x (independent) + let pairs = filter_pairs(values)?; + if pairs[0].is_empty() { + return Ok(()); + } + self.covar.update_batch(&pairs)?; + self.var_y.update_batch(&pairs[0..1])?; + self.var_x.update_batch(&pairs[1..2])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // state: [count, mean1, mean2, algo_const, m2_y, m2_x] + let covar_state = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[2]), + Arc::clone(&states[3]), + ]; + let var_y_state = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[4]), + ]; + let var_x_state = [ + Arc::clone(&states[0]), + Arc::clone(&states[2]), + Arc::clone(&states[5]), + ]; + self.covar.merge_batch(&covar_state)?; + self.var_y.merge_batch(&var_y_state)?; + self.var_x.merge_batch(&var_x_state)?; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = self.covar.get_count(); + let m2_x = self.var_x.get_m2(); + let m2_y = self.var_y.get_m2(); + if count <= 1.0 || m2_x == 0.0 { + // independent variable has no spread -> undefined + Ok(ScalarValue::Float64(None)) + } else if m2_y == 0.0 { + // dependent variable is constant -> perfect horizontal fit + Ok(ScalarValue::Float64(Some(1.0))) + } else { + let ck = self.covar.get_algo_const(); + Ok(ScalarValue::Float64(Some((ck * ck) / (m2_x * m2_y)))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size() + - std::mem::size_of_val(&self.var_y) + + self.var_y.size() + - std::mem::size_of_val(&self.var_x) + + self.var_x.size() + } +} + +/// `regr_slope` / `regr_intercept`. Mirrors Spark's `RegrSlope` / `RegrIntercept` +/// declarative aggregates, whose buffer is `CovPopulation(x, y)` (4 fields) +/// followed by `VariancePop(x)` (3 fields). The covariance is fed `(x, y)` so +/// that `mean1 = mean(x)` and `mean2 = mean(y)`. +#[derive(Debug)] +struct RegrLineAccumulator { + covar: CovarianceAccumulator, + var_x: VarianceAccumulator, + intercept: bool, +} + +impl RegrLineAccumulator { + fn try_new(intercept: bool) -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population, false)?, + var_x: VarianceAccumulator::try_new(StatsType::Population, false)?, + intercept, + }) + } +} + +impl Accumulator for RegrLineAccumulator { + fn state(&mut self) -> Result> { + let mut s = self.covar.state()?; + s.extend(self.var_x.state()?); + Ok(s) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // values[0] = y (dependent), values[1] = x (independent) + let pairs = filter_pairs(values)?; + if pairs[0].is_empty() { + return Ok(()); + } + // Feed covariance as (x, y) so mean1 = mean(x), mean2 = mean(y). + let cov_input = [Arc::clone(&pairs[1]), Arc::clone(&pairs[0])]; + self.covar.update_batch(&cov_input)?; + self.var_x.update_batch(&pairs[1..2])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let covar_state = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[2]), + Arc::clone(&states[3]), + ]; + let var_x_state = [ + Arc::clone(&states[4]), + Arc::clone(&states[5]), + Arc::clone(&states[6]), + ]; + self.covar.merge_batch(&covar_state)?; + self.var_x.merge_batch(&var_x_state)?; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let m2_x = self.var_x.get_m2(); + if m2_x == 0.0 { + // independent variable has no spread (also covers count <= 1) + return Ok(ScalarValue::Float64(None)); + } + let slope = self.covar.get_algo_const() / m2_x; + if self.intercept { + // mean(y) - slope * mean(x); covar fed (x, y) => mean1 = mean(x), mean2 = mean(y) + let mean_x = self.covar.get_mean1(); + let mean_y = self.covar.get_mean2(); + Ok(ScalarValue::Float64(Some(mean_y - slope * mean_x))) + } else { + Ok(ScalarValue::Float64(Some(slope))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size() + - std::mem::size_of_val(&self.var_x) + + self.var_x.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Float64Array; + + fn acc(regr_type: RegrType) -> Box { + match regr_type { + RegrType::SXX | RegrType::SYY => Box::new(RegrMomentAccumulator::try_new().unwrap()), + RegrType::SXY => Box::new(RegrCovAccumulator::try_new().unwrap()), + RegrType::R2 => Box::new(RegrR2Accumulator::try_new().unwrap()), + RegrType::Slope => Box::new(RegrLineAccumulator::try_new(false).unwrap()), + RegrType::Intercept => Box::new(RegrLineAccumulator::try_new(true).unwrap()), + } + } + + fn cols(y: Vec>, x: Vec>) -> Vec { + vec![ + Arc::new(Float64Array::from(y)) as ArrayRef, + Arc::new(Float64Array::from(x)) as ArrayRef, + ] + } + + fn eval(regr_type: RegrType, y: Vec>, x: Vec>) -> Option { + let mut a = acc(regr_type); + a.update_batch(&cols(y, x)).unwrap(); + match a.evaluate().unwrap() { + ScalarValue::Float64(v) => v, + other => panic!("unexpected scalar {other:?}"), + } + } + + fn approx(a: Option, b: f64) { + assert!( + a.map(|v| (v - b).abs() < 1e-9).unwrap_or(false), + "expected ~{b}, got {a:?}" + ); + } + + fn perfect_line() -> (Vec>, Vec>) { + // y = 2x + 1 + let x = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let y: Vec<_> = x.iter().map(|v| 2.0 * v + 1.0).collect(); + ( + y.into_iter().map(Some).collect(), + x.into_iter().map(Some).collect(), + ) + } + + #[test] + fn slope_and_intercept_perfect_line() { + let (y, x) = perfect_line(); + approx(eval(RegrType::Slope, y.clone(), x.clone()), 2.0); + approx(eval(RegrType::Intercept, y, x), 1.0); + } + + #[test] + fn r2_perfect_line_is_one() { + let (y, x) = perfect_line(); + approx(eval(RegrType::R2, y, x), 1.0); + } + + #[test] + fn r2_constant_y_is_one() { + // Dependent variable constant, independent varies: Spark returns 1.0. + let y = vec![Some(7.0), Some(7.0), Some(7.0), Some(7.0)]; + let x = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]; + approx(eval(RegrType::R2, y, x), 1.0); + } + + #[test] + fn constant_x_yields_null() { + // Independent variable constant: slope/intercept/r2 are all NULL. + let y = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]; + let x = vec![Some(5.0), Some(5.0), Some(5.0), Some(5.0)]; + assert_eq!(eval(RegrType::Slope, y.clone(), x.clone()), None); + assert_eq!(eval(RegrType::Intercept, y.clone(), x.clone()), None); + assert_eq!(eval(RegrType::R2, y, x), None); + } + + #[test] + fn single_pair_edges() { + let y = vec![Some(3.0)]; + let x = vec![Some(5.0)]; + // slope/intercept/r2 require >= 2 points + assert_eq!(eval(RegrType::Slope, y.clone(), x.clone()), None); + assert_eq!(eval(RegrType::Intercept, y.clone(), x.clone()), None); + assert_eq!(eval(RegrType::R2, y.clone(), x.clone()), None); + // moments are 0 for a single point, not null + approx(eval(RegrType::SXX, y.clone(), x.clone()), 0.0); + approx(eval(RegrType::SYY, y.clone(), x.clone()), 0.0); + approx(eval(RegrType::SXY, y, x), 0.0); + } + + #[test] + fn empty_input_is_null() { + for t in [ + RegrType::Slope, + RegrType::Intercept, + RegrType::R2, + RegrType::SXX, + RegrType::SYY, + RegrType::SXY, + ] { + assert_eq!(eval(t, vec![], vec![]), None); + } + } + + #[test] + fn moments_and_comoment() { + // y, x with known deviations. mean_x = 2.75, mean_y = 1.75 + let y = vec![Some(1.0), Some(2.0), Some(2.0), Some(2.0)]; + let x = vec![Some(2.0), Some(2.0), Some(3.0), Some(4.0)]; + // The serde duplicates the target column into both slots: + // regr_sxx -> RegrReplacement(x), regr_syy -> RegrReplacement(y). + // m2_x = (2-2.75)^2+(2-2.75)^2+(3-2.75)^2+(4-2.75)^2 = 2.75 + approx(eval(RegrType::SXX, x.clone(), x.clone()), 2.75); + // m2_y = 0.75 + approx(eval(RegrType::SYY, y.clone(), y.clone()), 0.75); + // ck = sum (x-mx)(y-my) = 0.75 + approx(eval(RegrType::SXY, y, x), 0.75); + } + + #[test] + fn null_pairs_are_skipped() { + // Only (1,2) and (5,10) survive => y = 2x line. + let y = vec![Some(1.0), None, Some(3.0), Some(5.0)]; + let x = vec![Some(2.0), Some(99.0), None, Some(10.0)]; + approx(eval(RegrType::Slope, y.clone(), x.clone()), 0.5); + approx(eval(RegrType::R2, y, x), 1.0); + } + + #[test] + fn merge_matches_single_batch() { + let (y, x) = perfect_line(); + // Split into two batches and merge their partial states. + let mut a1 = acc(RegrType::Slope); + a1.update_batch(&cols(y[..2].to_vec(), x[..2].to_vec())) + .unwrap(); + let mut a2 = acc(RegrType::Slope); + a2.update_batch(&cols(y[2..].to_vec(), x[2..].to_vec())) + .unwrap(); + + let state2 = a2.state().unwrap(); + let state_arrays: Vec = state2 + .iter() + .map(|s| s.to_array_of_size(1).unwrap()) + .collect(); + a1.merge_batch(&state_arrays).unwrap(); + + match a1.evaluate().unwrap() { + ScalarValue::Float64(v) => approx(v, 2.0), + other => panic!("unexpected {other:?}"), + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 40d8e39dbf..eab0bc58bb 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -389,6 +389,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Last] -> CometLast, classOf[Max] -> CometMax, classOf[Min] -> CometMin, + classOf[RegrIntercept] -> CometRegrIntercept, + classOf[RegrR2] -> CometRegrR2, + classOf[RegrReplacement] -> CometRegrReplacement, + classOf[RegrSlope] -> CometRegrSlope, + classOf[RegrSXY] -> CometRegrSXY, classOf[StddevPop] -> CometStddevPop, classOf[StddevSamp] -> CometStddevSamp, classOf[Sum] -> CometSum, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 6221452c86..4d340b099b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -21,10 +21,10 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, RegrIntercept, RegrR2, RegrReplacement, RegrSlope, RegrSXY, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} @@ -622,6 +622,121 @@ object CometCorr extends CometAggregateExpressionSerde[Corr] { } } +/** + * Shared serialization for the simple linear regression aggregates. `child1` is the dependent + * variable (y) and `child2` is the independent variable (x), matching the native accumulator's + * `regr_*(y, x)` convention. + */ +trait CometRegrBase { + def convertRegr( + aggExpr: AggregateExpression, + regrType: ExprOuterClass.Regr.RegrType, + y: Expression, + x: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.AggExpr] = { + val child1Expr = exprToProto(y, inputs, binding) + val child2Expr = exprToProto(x, inputs, binding) + val dataType = serializeDataType(DoubleType) + + if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.Regr.newBuilder() + builder.setChild1(child1Expr.get) + builder.setChild2(child2Expr.get) + builder.setRegrType(regrType) + builder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setRegr(builder) + .build()) + } else { + withFallbackReason(aggExpr, y, x) + None + } + } +} + +object CometRegrSlope extends CometAggregateExpressionSerde[RegrSlope] with CometRegrBase { + override def convert( + aggExpr: AggregateExpression, + expr: RegrSlope, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = + convertRegr( + aggExpr, + ExprOuterClass.Regr.RegrType.SLOPE, + expr.left, + expr.right, + inputs, + binding) +} + +object CometRegrIntercept + extends CometAggregateExpressionSerde[RegrIntercept] + with CometRegrBase { + override def convert( + aggExpr: AggregateExpression, + expr: RegrIntercept, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = + convertRegr( + aggExpr, + ExprOuterClass.Regr.RegrType.INTERCEPT, + expr.left, + expr.right, + inputs, + binding) +} + +object CometRegrR2 extends CometAggregateExpressionSerde[RegrR2] with CometRegrBase { + override def convert( + aggExpr: AggregateExpression, + expr: RegrR2, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = + convertRegr(aggExpr, ExprOuterClass.Regr.RegrType.R2, expr.y, expr.x, inputs, binding) +} + +object CometRegrSXY extends CometAggregateExpressionSerde[RegrSXY] with CometRegrBase { + override def convert( + aggExpr: AggregateExpression, + expr: RegrSXY, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = + convertRegr(aggExpr, ExprOuterClass.Regr.RegrType.SXY, expr.y, expr.x, inputs, binding) +} + +/** + * Spark rewrites `regr_sxx(y, x)` and `regr_syy(y, x)` into `RegrReplacement(If(y IS NULL OR x IS + * NULL, null, col))`, where `col` is the independent (x) variable for `regr_sxx` and the + * dependent (y) variable for `regr_syy`. `RegrReplacement` evaluates to `m2` (the sum of squared + * deviations) of its single child. We serialize it as the `SXX` regression statistic with the + * child duplicated, since `regr_sxx(c, c) = m2(c)`. + */ +object CometRegrReplacement + extends CometAggregateExpressionSerde[RegrReplacement] + with CometRegrBase { + override def convert( + aggExpr: AggregateExpression, + expr: RegrReplacement, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = + convertRegr( + aggExpr, + ExprOuterClass.Regr.RegrType.SXX, + expr.child, + expr.child, + inputs, + binding) +} + object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] { override def supportsMixedPartialFinal: Boolean = true diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql index 04e941bd79..e6919f7d7a 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql @@ -58,53 +58,55 @@ query tolerance=1e-6 SELECT grp, regr_avgy(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_sxx: sum of squares of deviation of x from mean(x) --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_sxx(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_sxx(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_syy: sum of squares of deviation of y from mean(y) --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_syy(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_syy(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_sxy: sum of products of deviations of x and y from their means --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_sxy(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_sxy(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_slope: slope of the least-squares regression line --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_slope(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_slope(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_intercept: y-intercept of the least-squares regression line --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_intercept(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_intercept(y, x) FROM test_regr GROUP BY grp ORDER BY grp -- regr_r2: square of the correlation coefficient (coefficient of determination) --- Falls back to Spark: aggregate not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_r2(y, x) FROM test_regr -query spark_answer_only +query tolerance=1e-6 SELECT grp, regr_r2(y, x) FROM test_regr GROUP BY grp ORDER BY grp +-- literal dependent variable mixed with a column independent variable +query tolerance=1e-6 +SELECT regr_slope(2.0, x), regr_intercept(2.0, x), regr_sxy(2.0, x) FROM test_regr + +-- literal independent variable mixed with a column dependent variable +query tolerance=1e-6 +SELECT regr_slope(y, 3.0), regr_sxx(y, 3.0), regr_r2(y, 3.0) FROM test_regr + -- edge case: all-NULL input returns NULL for all functions statement CREATE TABLE test_regr_all_null(y double, x double) USING parquet @@ -115,12 +117,12 @@ INSERT INTO test_regr_all_null VALUES (NULL, NULL), (NULL, NULL) query SELECT regr_count(y, x), regr_avgx(y, x), regr_avgy(y, x) FROM test_regr_all_null --- Falls back to Spark: regr_sxx/syy/sxy not yet accelerated by Comet -query spark_answer_only +-- regr_sxx/syy/sxy return NULL when there are no non-null pairs +query tolerance=1e-6 SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_all_null --- Falls back to Spark: regr_slope/intercept/r2 not yet accelerated by Comet -query spark_answer_only +-- regr_slope/intercept/r2 return NULL when there are no non-null pairs +query tolerance=1e-6 SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x) FROM test_regr_all_null -- edge case: single non-null pair (slope/intercept/r2 require >= 2 rows) @@ -133,10 +135,38 @@ INSERT INTO test_regr_single VALUES (3.0, 5.0), (NULL, 2.0), (1.0, NULL) query SELECT regr_count(y, x), regr_avgx(y, x), regr_avgy(y, x) FROM test_regr_single --- Falls back to Spark: regr_sxx/syy/sxy not yet accelerated by Comet -query spark_answer_only +-- sxx/syy/sxy are 0 for a single pair; slope/intercept/r2 are NULL +query tolerance=1e-6 SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_single --- Falls back to Spark: regr_slope/intercept/r2 not yet accelerated by Comet -query spark_answer_only +query tolerance=1e-6 SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x) FROM test_regr_single + +-- edge case: independent variable (x) is constant. +-- var_pop(x) = 0, so slope/intercept/r2 are all NULL, and regr_sxx = 0. +statement +CREATE TABLE test_regr_const_x(y double, x double) USING parquet + +statement +INSERT INTO test_regr_const_x VALUES (1.0, 5.0), (2.0, 5.0), (3.0, 5.0), (4.0, 5.0) + +query tolerance=1e-6 +SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x) FROM test_regr_const_x + +query tolerance=1e-6 +SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_const_x + +-- edge case: dependent variable (y) is constant but x varies. +-- A horizontal line is a perfect fit, so Spark's regr_r2 returns 1.0 (not NULL). +-- The slope is 0 and the intercept equals the constant y. +statement +CREATE TABLE test_regr_const_y(y double, x double) USING parquet + +statement +INSERT INTO test_regr_const_y VALUES (7.0, 1.0), (7.0, 2.0), (7.0, 3.0), (7.0, 4.0) + +query tolerance=1e-6 +SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x) FROM test_regr_const_y + +query tolerance=1e-6 +SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_const_y From 236f3a746a43c4f5303f8dc320b16f4bf5763754 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Jul 2026 15:01:11 -0600 Subject: [PATCH 2/2] fix: match Spark's per-version regr_slope/intercept/r2 semantics The regr aggregates diverged from Spark in three ways that surfaced as CI failures across Spark versions: - regr_r2 degenerate cases were inverted. Spark 3.4/3.5/4.0 return null when the dependent variable is constant and 1.0 when the independent variable is constant; Comet had these swapped. - Spark 4.1 swapped that degenerate handling again (constant dependent -> 1.0, constant independent -> null). Route the behaviour through a new r2_constant_dependent_is_perfect_fit proto flag set from isSpark41Plus. - regr_slope/regr_intercept compute VariancePop(x) over both-non-null pairs on Spark 3.5+, but over every x-non-null row on Spark 3.4. Route this through a new filter_var_by_pair_nulls proto flag set from isSpark35Plus. Also evaluate regr_r2 as corr = ck / sqrt(m2_y * m2_x); corr * corr to mirror Spark's exact float rounding, so the golden-file postgres aggregates tests match bit-for-bit. --- native/core/src/execution/planner.rs | 7 +- native/proto/src/proto/expr.proto | 10 + native/spark-expr/src/agg_funcs/regr.rs | 174 +++++++++++++++--- .../org/apache/comet/serde/aggregates.scala | 10 +- .../sql-tests/expressions/aggregate/regr.sql | 11 +- 5 files changed, 176 insertions(+), 36 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index cb3b891760..f04debc4e5 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -2617,7 +2617,12 @@ impl PhysicalPlanner { spark_expression::regr::RegrType::Syy => (RegrType::SYY, "regr_syy"), spark_expression::regr::RegrType::Sxy => (RegrType::SXY, "regr_sxy"), }; - let func = AggregateUDF::new_from_impl(Regr::new(regr_type, name)); + let func = AggregateUDF::new_from_impl(Regr::new( + regr_type, + name, + expr.filter_var_by_pair_nulls, + expr.r2_constant_dependent_is_perfect_fit, + )); Self::create_aggr_func_expr(name, schema, vec![child1, child2], func) } AggExprStruct::Percentile(expr) => { diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 87c11652d0..ee32dfec6b 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -262,6 +262,16 @@ message Regr { Expr child2 = 2; RegrType regr_type = 3; DataType datatype = 4; + // Only consulted for SLOPE and INTERCEPT. When true (Spark 3.5+), VariancePop(x) + // is computed only over rows where both y and x are non-null. When false + // (Spark 3.4), VariancePop(x) includes every row where x is non-null even if y + // is null, matching the pre-fix Spark 3.4 semantics. + bool filter_var_by_pair_nulls = 5; + // Only consulted for R2. Spark 4.1 swapped the degenerate-case handling of + // regr_r2. When true (Spark 4.1+), a constant dependent variable (m2(y) = 0) + // returns 1.0 and a constant independent variable (m2(x) = 0) returns null. + // When false (Spark 3.4/3.5/4.0), those two cases are reversed. + bool r2_constant_dependent_is_perfect_fit = 6; } message Percentile { diff --git a/native/spark-expr/src/agg_funcs/regr.rs b/native/spark-expr/src/agg_funcs/regr.rs index 12977496a4..3495cdd8ba 100644 --- a/native/spark-expr/src/agg_funcs/regr.rs +++ b/native/spark-expr/src/agg_funcs/regr.rs @@ -72,10 +72,23 @@ pub struct Regr { name: String, signature: Signature, regr_type: RegrType, + /// Only consulted for `Slope` / `Intercept`. When `true` (Spark 3.5+), + /// `VariancePop(x)` counts only rows where both `y` and `x` are non-null. + /// When `false` (Spark 3.4), it counts every row where `x` is non-null. + filter_var_by_pair_nulls: bool, + /// Only consulted for `R2`. When `true` (Spark 4.1+), a constant dependent + /// variable evaluates to `1.0` and a constant independent variable to `null`. + /// When `false` (Spark 3.4/3.5/4.0), those two cases are reversed. + r2_constant_dependent_is_perfect_fit: bool, } impl Regr { - pub fn new(regr_type: RegrType, name: impl Into) -> Self { + pub fn new( + regr_type: RegrType, + name: impl Into, + filter_var_by_pair_nulls: bool, + r2_constant_dependent_is_perfect_fit: bool, + ) -> Self { Self { name: name.into(), signature: Signature::exact( @@ -83,6 +96,8 @@ impl Regr { Volatility::Immutable, ), regr_type, + filter_var_by_pair_nulls, + r2_constant_dependent_is_perfect_fit, } } @@ -123,9 +138,17 @@ impl AggregateUDFImpl for Regr { let acc: Box = match self.regr_type { RegrType::SXX | RegrType::SYY => Box::new(RegrMomentAccumulator::try_new()?), RegrType::SXY => Box::new(RegrCovAccumulator::try_new()?), - RegrType::R2 => Box::new(RegrR2Accumulator::try_new()?), - RegrType::Slope => Box::new(RegrLineAccumulator::try_new(false)?), - RegrType::Intercept => Box::new(RegrLineAccumulator::try_new(true)?), + RegrType::R2 => Box::new(RegrR2Accumulator::try_new( + self.r2_constant_dependent_is_perfect_fit, + )?), + RegrType::Slope => Box::new(RegrLineAccumulator::try_new( + false, + self.filter_var_by_pair_nulls, + )?), + RegrType::Intercept => Box::new(RegrLineAccumulator::try_new( + true, + self.filter_var_by_pair_nulls, + )?), }; Ok(acc) } @@ -258,22 +281,33 @@ impl Accumulator for RegrCovAccumulator { /// (a `PearsonCorrelation`). State layout matches `CorrelationAccumulator`: /// count, mean1, mean2, algo_const, m2(y), m2(x). /// -/// Spark's evaluate differs from DataFusion in one degenerate case: when the -/// dependent variable `y` is constant but `x` varies, Spark returns `1.0` -/// (a horizontal line is a perfect fit) where DataFusion returns `null`. +/// Spark's degenerate-case handling (see `RegrR2.evaluateExpression`) changed in +/// Spark 4.1. In both eras one degenerate case returns `null` and the other +/// returns `1.0` (a perfect fit), but which is which was swapped: +/// - Spark 3.4/3.5/4.0: constant dependent `y` (`m2(y) == 0`) -> `null`; +/// constant independent `x` (`m2(x) == 0`) -> `1.0`. +/// - Spark 4.1+: constant dependent `y` -> `1.0`; constant independent `x` -> +/// `null`. +/// +/// `m2(y) == 0` also covers fewer than two rows. DataFusion returns `null` in +/// both degenerate cases. #[derive(Debug)] struct RegrR2Accumulator { covar: CovarianceAccumulator, var_y: VarianceAccumulator, var_x: VarianceAccumulator, + /// When `true` (Spark 4.1+), a constant dependent variable yields `1.0` and a + /// constant independent variable yields `null`; reversed when `false`. + constant_dependent_is_perfect_fit: bool, } impl RegrR2Accumulator { - fn try_new() -> Result { + fn try_new(constant_dependent_is_perfect_fit: bool) -> Result { Ok(Self { covar: CovarianceAccumulator::try_new(StatsType::Population, false)?, var_y: VarianceAccumulator::try_new(StatsType::Population, false)?, var_x: VarianceAccumulator::try_new(StatsType::Population, false)?, + constant_dependent_is_perfect_fit, }) } } @@ -328,18 +362,29 @@ impl Accumulator for RegrR2Accumulator { } fn evaluate(&mut self) -> Result { - let count = self.covar.get_count(); let m2_x = self.var_x.get_m2(); let m2_y = self.var_y.get_m2(); - if count <= 1.0 || m2_x == 0.0 { - // independent variable has no spread -> undefined + // The two degenerate cases (constant dependent y, constant independent x) + // return null and 1.0 respectively, but Spark 4.1 swapped which is which. + let (null_case, perfect_fit_case) = if self.constant_dependent_is_perfect_fit { + // Spark 4.1+: constant x -> null, constant y -> 1.0. + (m2_x == 0.0, m2_y == 0.0) + } else { + // Spark 3.4/3.5/4.0: constant y -> null, constant x -> 1.0. + (m2_y == 0.0, m2_x == 0.0) + }; + if null_case { Ok(ScalarValue::Float64(None)) - } else if m2_y == 0.0 { - // dependent variable is constant -> perfect horizontal fit + } else if perfect_fit_case { Ok(ScalarValue::Float64(Some(1.0))) } else { + // Mirror Spark's exact evaluation order (corr = ck / sqrt(m2_y * m2_x); + // corr * corr) so the last-ULP rounding matches bit-for-bit. Writing + // it as (ck * ck) / (m2_x * m2_y) is mathematically equal but rounds + // differently. let ck = self.covar.get_algo_const(); - Ok(ScalarValue::Float64(Some((ck * ck) / (m2_x * m2_y)))) + let corr = ck / (m2_y * m2_x).sqrt(); + Ok(ScalarValue::Float64(Some(corr * corr))) } } @@ -361,14 +406,19 @@ struct RegrLineAccumulator { covar: CovarianceAccumulator, var_x: VarianceAccumulator, intercept: bool, + /// When `true` (Spark 3.5+), `var_x` counts only rows where both `y` and `x` + /// are non-null. When `false` (Spark 3.4), `var_x` counts every row where `x` + /// is non-null, even if `y` is null. + filter_var_by_pair_nulls: bool, } impl RegrLineAccumulator { - fn try_new(intercept: bool) -> Result { + fn try_new(intercept: bool, filter_var_by_pair_nulls: bool) -> Result { Ok(Self { covar: CovarianceAccumulator::try_new(StatsType::Population, false)?, var_x: VarianceAccumulator::try_new(StatsType::Population, false)?, intercept, + filter_var_by_pair_nulls, }) } } @@ -383,13 +433,21 @@ impl Accumulator for RegrLineAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // values[0] = y (dependent), values[1] = x (independent) let pairs = filter_pairs(values)?; - if pairs[0].is_empty() { - return Ok(()); + if !pairs[0].is_empty() { + // Feed covariance as (x, y) so mean1 = mean(x), mean2 = mean(y). + let cov_input = [Arc::clone(&pairs[1]), Arc::clone(&pairs[0])]; + self.covar.update_batch(&cov_input)?; + } + if self.filter_var_by_pair_nulls { + // Spark 3.5+: VariancePop(x) only over rows where both y and x are non-null. + if !pairs[0].is_empty() { + self.var_x.update_batch(&pairs[1..2])?; + } + } else { + // Spark 3.4: VariancePop(x) over every row where x is non-null (the + // VarianceAccumulator itself skips x nulls), regardless of y. + self.var_x.update_batch(&values[1..2])?; } - // Feed covariance as (x, y) so mean1 = mean(x), mean2 = mean(y). - let cov_input = [Arc::clone(&pairs[1]), Arc::clone(&pairs[0])]; - self.covar.update_batch(&cov_input)?; - self.var_x.update_batch(&pairs[1..2])?; Ok(()) } @@ -443,9 +501,11 @@ mod tests { match regr_type { RegrType::SXX | RegrType::SYY => Box::new(RegrMomentAccumulator::try_new().unwrap()), RegrType::SXY => Box::new(RegrCovAccumulator::try_new().unwrap()), - RegrType::R2 => Box::new(RegrR2Accumulator::try_new().unwrap()), - RegrType::Slope => Box::new(RegrLineAccumulator::try_new(false).unwrap()), - RegrType::Intercept => Box::new(RegrLineAccumulator::try_new(true).unwrap()), + // Default to pre-Spark-4.1 degenerate-case semantics. + RegrType::R2 => Box::new(RegrR2Accumulator::try_new(false).unwrap()), + // Existing tests exercise the Spark 3.5+ both-non-null semantics. + RegrType::Slope => Box::new(RegrLineAccumulator::try_new(false, true).unwrap()), + RegrType::Intercept => Box::new(RegrLineAccumulator::try_new(true, true).unwrap()), } } @@ -496,21 +556,51 @@ mod tests { } #[test] - fn r2_constant_y_is_one() { - // Dependent variable constant, independent varies: Spark returns 1.0. + fn r2_constant_y_is_null() { + // Dependent variable constant: Spark's regr_r2 returns NULL. let y = vec![Some(7.0), Some(7.0), Some(7.0), Some(7.0)]; let x = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]; - approx(eval(RegrType::R2, y, x), 1.0); + assert_eq!(eval(RegrType::R2, y, x), None); } #[test] - fn constant_x_yields_null() { - // Independent variable constant: slope/intercept/r2 are all NULL. + fn r2_degenerate_cases_swapped_in_spark_41() { + // Spark 4.1 swapped the two degenerate cases relative to 3.4/3.5/4.0. + let const_y = ( + vec![Some(7.0), Some(7.0), Some(7.0), Some(7.0)], + vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)], + ); + let const_x = ( + vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)], + vec![Some(5.0), Some(5.0), Some(5.0), Some(5.0)], + ); + + let r2 = |perfect_dep: bool, y: Vec>, x: Vec>| { + let mut a = RegrR2Accumulator::try_new(perfect_dep).unwrap(); + a.update_batch(&cols(y, x)).unwrap(); + match a.evaluate().unwrap() { + ScalarValue::Float64(v) => v, + other => panic!("unexpected {other:?}"), + } + }; + + // Spark 4.1+: constant dependent -> 1.0, constant independent -> null. + approx(r2(true, const_y.0.clone(), const_y.1.clone()), 1.0); + assert_eq!(r2(true, const_x.0.clone(), const_x.1.clone()), None); + // Spark 3.4/3.5/4.0: constant dependent -> null, constant independent -> 1.0. + assert_eq!(r2(false, const_y.0, const_y.1), None); + approx(r2(false, const_x.0, const_x.1), 1.0); + } + + #[test] + fn constant_x_edges() { + // Independent variable constant, dependent varies: slope/intercept are + // NULL, but Spark's regr_r2 returns 1.0 (a horizontal line is a perfect fit). let y = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]; let x = vec![Some(5.0), Some(5.0), Some(5.0), Some(5.0)]; assert_eq!(eval(RegrType::Slope, y.clone(), x.clone()), None); assert_eq!(eval(RegrType::Intercept, y.clone(), x.clone()), None); - assert_eq!(eval(RegrType::R2, y, x), None); + approx(eval(RegrType::R2, y, x), 1.0); } #[test] @@ -565,6 +655,30 @@ mod tests { approx(eval(RegrType::R2, y, x), 1.0); } + #[test] + fn slope_var_x_null_filtering_differs_by_spark_version() { + // A row where x is non-null but y is null: (1,2), (2,4), (3,6), (null,10). + // The co-moment ck = 4 is the same either way (only paired rows contribute). + // Spark 3.5+ excludes x=10 from VariancePop(x) (m2 over {2,4,6} = 8) -> 0.5. + // Spark 3.4 includes x=10 (m2 over {2,4,6,10} = 35) -> 4/35. + let y = vec![Some(1.0), Some(2.0), Some(3.0), None]; + let x = vec![Some(2.0), Some(4.0), Some(6.0), Some(10.0)]; + + let mut filtered = RegrLineAccumulator::try_new(false, true).unwrap(); + filtered.update_batch(&cols(y.clone(), x.clone())).unwrap(); + match filtered.evaluate().unwrap() { + ScalarValue::Float64(v) => approx(v, 0.5), + other => panic!("unexpected {other:?}"), + } + + let mut unfiltered = RegrLineAccumulator::try_new(false, false).unwrap(); + unfiltered.update_batch(&cols(y, x)).unwrap(); + match unfiltered.evaluate().unwrap() { + ScalarValue::Float64(v) => approx(v, 4.0 / 35.0), + other => panic!("unexpected {other:?}"), + } + } + #[test] fn merge_matches_single_batch() { let (y, x) = perfect_line(); diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 4dd7eacd4d..a74376b46f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DecimalType, DoubleType, IntegerType, LongType, NumericType, ShortType, StringType} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT -import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} +import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark41Plus, withFallbackReason} import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} import org.apache.comet.shims.CometEvalModeUtil @@ -728,6 +728,14 @@ trait CometRegrBase { builder.setChild2(child2Expr.get) builder.setRegrType(regrType) builder.setDatatype(dataType.get) + // Spark 3.5 fixed regr_slope/regr_intercept so VariancePop(x) only counts + // rows where both y and x are non-null. Spark 3.4 counts every row where x + // is non-null. The native accumulator only consults this for slope/intercept. + builder.setFilterVarByPairNulls(isSpark35Plus) + // Spark 4.1 swapped regr_r2's degenerate-case handling: a constant dependent + // variable now yields 1.0 (was null) and a constant independent variable + // yields null (was 1.0). The native accumulator only consults this for R2. + builder.setR2ConstantDependentIsPerfectFit(isSpark41Plus) Some( ExprOuterClass.AggExpr diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql index e6919f7d7a..250dffb1b8 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/regr.sql @@ -142,8 +142,10 @@ SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_single query tolerance=1e-6 SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x) FROM test_regr_single --- edge case: independent variable (x) is constant. --- var_pop(x) = 0, so slope/intercept/r2 are all NULL, and regr_sxx = 0. +-- edge case: independent variable (x) is constant but y varies. +-- var_pop(x) = 0, so slope/intercept are NULL and regr_sxx = 0. regr_r2 is a +-- degenerate case whose value depends on the Spark version (1.0 on 3.4/3.5/4.0, +-- NULL on 4.1+), which Comet matches per version. statement CREATE TABLE test_regr_const_x(y double, x double) USING parquet @@ -157,8 +159,9 @@ query tolerance=1e-6 SELECT regr_sxx(y, x), regr_syy(y, x), regr_sxy(y, x) FROM test_regr_const_x -- edge case: dependent variable (y) is constant but x varies. --- A horizontal line is a perfect fit, so Spark's regr_r2 returns 1.0 (not NULL). --- The slope is 0 and the intercept equals the constant y. +-- The slope is 0 and the intercept equals the constant y. regr_r2 is a degenerate +-- case whose value depends on the Spark version (NULL on 3.4/3.5/4.0, 1.0 on +-- 4.1+), which Comet matches per version. statement CREATE TABLE test_regr_const_y(y double, x double) USING parquet