Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ The tables below list every Spark built-in expression with its current status.
| `regr_avgx` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) |
| `regr_avgy` | ✅ | Native: Spark rewrites to `Average` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) |
| `regr_count` | ✅ | Native: Spark rewrites to `Count` (tests in [#4551](https://github.com/apache/datafusion-comet/issues/4551)) |
| `regr_intercept` | 🔜 | Falls back; can reuse `covar_pop`/`var_pop` accumulators ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_r2` | 🔜 | Falls back; can reuse the `corr` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_slope` | 🔜 | Falls back; can reuse `covar_pop`/`var_pop` accumulators ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_sxx` | 🔜 | Falls back; can reuse `var_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_sxy` | 🔜 | Falls back; can reuse `covar_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_syy` | 🔜 | Falls back; can reuse `var_pop` accumulator ([#4552](https://github.com/apache/datafusion-comet/issues/4552)) |
| `regr_intercept` | | |
| `regr_r2` | | |
| `regr_slope` | | |
| `regr_sxx` | | |
| `regr_sxy` | | |
| `regr_syy` | | |
| `skewness` | 🔜 | tracking [#4098](https://github.com/apache/datafusion-comet/issues/4098) |
| `some` | ✅ | |
| `std` | ✅ | |
Expand Down
27 changes: 25 additions & 2 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ use datafusion_comet_proto::{
use datafusion_comet_spark_expr::{
jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation,
Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields,
GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal,
ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, Regr, RegrType, SparkCastOptions,
Stddev, SumDecimal, ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
use jni::objects::{Global, JObject};
Expand Down Expand Up @@ -2602,6 +2602,29 @@ impl PhysicalPlanner {
));
Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
}
AggExprStruct::Regr(expr) => {
let child1 =
self.create_expr(expr.child1.as_ref().unwrap(), Arc::clone(&schema))?;
let child2 =
self.create_expr(expr.child2.as_ref().unwrap(), Arc::clone(&schema))?;
let (regr_type, name) = match expr.regr_type() {
spark_expression::regr::RegrType::Slope => (RegrType::Slope, "regr_slope"),
spark_expression::regr::RegrType::Intercept => {
(RegrType::Intercept, "regr_intercept")
}
spark_expression::regr::RegrType::R2 => (RegrType::R2, "regr_r2"),
spark_expression::regr::RegrType::Sxx => (RegrType::SXX, "regr_sxx"),
spark_expression::regr::RegrType::Syy => (RegrType::SYY, "regr_syy"),
spark_expression::regr::RegrType::Sxy => (RegrType::SXY, "regr_sxy"),
};
let func = AggregateUDF::new_from_impl(Regr::new(
regr_type,
name,
expr.filter_var_by_pair_nulls,
expr.r2_constant_dependent_is_perfect_fit,
));
Self::create_aggr_func_expr(name, schema, vec![child1, child2], func)
}
AggExprStruct::Percentile(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let percentile =
Expand Down
29 changes: 29 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ message AggExpr {
BloomFilterAgg bloomFilterAgg = 16;
CollectSet collectSet = 17;
Percentile percentile = 18;
Regr regr = 19;
}

// Optional filter expression for SQL FILTER (WHERE ...) clause.
Expand Down Expand Up @@ -245,6 +246,34 @@ message Correlation {
DataType datatype = 4;
}

// Simple linear regression aggregates (regr_slope, regr_intercept, regr_r2,
// regr_sxx, regr_syy, regr_sxy). child1 is the dependent variable (y) and
// child2 is the independent variable (x).
message Regr {
enum RegrType {
SLOPE = 0;
INTERCEPT = 1;
R2 = 2;
SXX = 3;
SYY = 4;
SXY = 5;
}
Expr child1 = 1;
Expr child2 = 2;
RegrType regr_type = 3;
DataType datatype = 4;
// Only consulted for SLOPE and INTERCEPT. When true (Spark 3.5+), VariancePop(x)
// is computed only over rows where both y and x are non-null. When false
// (Spark 3.4), VariancePop(x) includes every row where x is non-null even if y
// is null, matching the pre-fix Spark 3.4 semantics.
bool filter_var_by_pair_nulls = 5;
// Only consulted for R2. Spark 4.1 swapped the degenerate-case handling of
// regr_r2. When true (Spark 4.1+), a constant dependent variable (m2(y) = 0)
// returns 1.0 and a constant independent variable (m2(x) = 0) returns null.
// When false (Spark 3.4/3.5/4.0), those two cases are reversed.
bool r2_constant_dependent_is_perfect_fit = 6;
}

message Percentile {
Expr child = 1;
// Single percentile in [0.0, 1.0] as a literal double expression.
Expand Down
2 changes: 2 additions & 0 deletions native/spark-expr/src/agg_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod avg;
mod avg_decimal;
mod correlation;
mod covariance;
mod regr;
mod stddev;
mod sum_decimal;
mod sum_int;
Expand All @@ -29,6 +30,7 @@ pub use avg::Avg;
pub use avg_decimal::AvgDecimal;
pub use correlation::Correlation;
pub use covariance::Covariance;
pub use regr::{Regr, RegrType};
pub use stddev::Stddev;
pub use sum_decimal::SumDecimal;
pub use sum_int::SumInteger;
Expand Down
Loading
Loading