From 768b3e90f261c7aea58bdb98dc698b90deeeae34 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Sun, 14 Dec 2025 16:24:01 +0400 Subject: [PATCH 01/10] impl map_from_entries --- native/core/src/execution/jni_api.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../scala/org/apache/comet/serde/maps.scala | 29 +++++++++++- .../comet/CometMapExpressionSuite.scala | 45 +++++++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index a24d993059..4f53cea3e6 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -46,6 +46,7 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; use datafusion_spark::function::hash::sha1::SparkSha1; use datafusion_spark::function::hash::sha2::SparkSha2; +use datafusion_spark::function::map::map_from_entries::MapFromEntries; use datafusion_spark::function::math::expm1::SparkExpm1; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; @@ -337,6 +338,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); } /// Prepares arrow arrays for output. 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 54df2f1688..a99cf3824b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -125,7 +125,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapKeys] -> CometMapKeys, classOf[MapEntries] -> CometMapEntries, classOf[MapValues] -> CometMapValues, - classOf[MapFromArrays] -> CometMapFromArrays) + classOf[MapFromArrays] -> CometMapFromArrays, + classOf[MapFromEntries] -> CometMapFromEntries) private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[CreateNamedStruct] -> CometCreateNamedStruct, diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index 2e217f6af0..498aa3594c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -19,9 +19,12 @@ package org.apache.comet.serde +import scala.annotation.tailrec + import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ArrayType, MapType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType} +import org.apache.comet.serde.CometArrayReverse.containsBinary import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} object CometMapKeys extends CometExpressionSerde[MapKeys] { @@ -89,3 +92,27 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] { optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*) } } + +object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") { + val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries" + val 
valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries" + + private def containsBinary(dataType: DataType): Boolean = { + dataType match { + case BinaryType => true + case StructType(fields) => fields.exists(field => containsBinary(field.dataType)) + case ArrayType(elementType, _) => containsBinary(elementType) + case _ => false + } + } + + override def getSupportLevel(expr: MapFromEntries): SupportLevel = { + if (containsBinary(expr.dataType.keyType)) { + return Incompatible(Some(keyUnsupportedReason)) + } + if (containsBinary(expr.dataType.valueType)) { + return Incompatible(Some(valueUnsupportedReason)) + } + Compatible(None) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 88c13391a6..01b9744ed6 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -25,7 +25,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BinaryType +import org.apache.comet.serde.CometMapFromEntries import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} class CometMapExpressionSuite extends CometTestBase { @@ -125,4 +127,47 @@ class CometMapExpressionSuite extends CometTestBase { } } + test("map_from_entries") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + val filename = path.toString + val random = new Random(42) + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val schemaGenOptions = + SchemaGenOptions( + generateArray = true, + generateStruct = true, + primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType)) + val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = 
false) + ParquetGenerator.makeParquetFile( + random, + spark, + filename, + 100, + schemaGenOptions, + dataGenOptions) + } + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (field <- df.schema.fieldNames) { + checkSparkAnswerAndOperator( + spark.sql(s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1")) + } + } + } + + test("map_from_entries - fallback for binary type") { + val table = "t2" + withTable(table) { + sql( + s"create table $table using parquet as select cast(array() as array) as c1 from range(10)") + checkSparkAnswerAndFallbackReason( + sql(s"select map_from_entries(array(struct(c1, 0))) from $table"), + CometMapFromEntries.keyUnsupportedReason) + checkSparkAnswerAndFallbackReason( + sql(s"select map_from_entries(array(struct(0, c1))) from $table"), + CometMapFromEntries.valueUnsupportedReason) + } + } + } From c68c3428676b5d991e7ba9e13464bf2ce1ec84e8 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Tue, 16 Dec 2025 16:10:43 +0400 Subject: [PATCH 02/10] Revert "impl map_from_entries" This reverts commit 768b3e90f261c7aea58bdb98dc698b90deeeae34. 
--- native/core/src/execution/jni_api.rs | 2 - .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../scala/org/apache/comet/serde/maps.scala | 29 +----------- .../comet/CometMapExpressionSuite.scala | 45 ------------------- 4 files changed, 2 insertions(+), 77 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 4f53cea3e6..a24d993059 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -46,7 +46,6 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; use datafusion_spark::function::hash::sha1::SparkSha1; use datafusion_spark::function::hash::sha2::SparkSha2; -use datafusion_spark::function::map::map_from_entries::MapFromEntries; use datafusion_spark::function::math::expm1::SparkExpm1; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; @@ -338,7 +337,6 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); - session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); } /// Prepares arrow arrays for output. 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a99cf3824b..54df2f1688 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -125,8 +125,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapKeys] -> CometMapKeys, classOf[MapEntries] -> CometMapEntries, classOf[MapValues] -> CometMapValues, - classOf[MapFromArrays] -> CometMapFromArrays, - classOf[MapFromEntries] -> CometMapFromEntries) + classOf[MapFromArrays] -> CometMapFromArrays) private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[CreateNamedStruct] -> CometCreateNamedStruct, diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index 498aa3594c..2e217f6af0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -19,12 +19,9 @@ package org.apache.comet.serde -import scala.annotation.tailrec - import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType} +import org.apache.spark.sql.types.{ArrayType, MapType} -import org.apache.comet.serde.CometArrayReverse.containsBinary import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} object CometMapKeys extends CometExpressionSerde[MapKeys] { @@ -92,27 +89,3 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] { optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*) } } - -object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") { - val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries" - val 
valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries" - - private def containsBinary(dataType: DataType): Boolean = { - dataType match { - case BinaryType => true - case StructType(fields) => fields.exists(field => containsBinary(field.dataType)) - case ArrayType(elementType, _) => containsBinary(elementType) - case _ => false - } - } - - override def getSupportLevel(expr: MapFromEntries): SupportLevel = { - if (containsBinary(expr.dataType.keyType)) { - return Incompatible(Some(keyUnsupportedReason)) - } - if (containsBinary(expr.dataType.valueType)) { - return Incompatible(Some(valueUnsupportedReason)) - } - Compatible(None) - } -} diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 01b9744ed6..88c13391a6 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -25,9 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.BinaryType -import org.apache.comet.serde.CometMapFromEntries import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} class CometMapExpressionSuite extends CometTestBase { @@ -127,47 +125,4 @@ class CometMapExpressionSuite extends CometTestBase { } } - test("map_from_entries") { - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - val filename = path.toString - val random = new Random(42) - withSQLConf(CometConf.COMET_ENABLED.key -> "false") { - val schemaGenOptions = - SchemaGenOptions( - generateArray = true, - generateStruct = true, - primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType)) - val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = 
false) - ParquetGenerator.makeParquetFile( - random, - spark, - filename, - 100, - schemaGenOptions, - dataGenOptions) - } - val df = spark.read.parquet(filename) - df.createOrReplaceTempView("t1") - for (field <- df.schema.fieldNames) { - checkSparkAnswerAndOperator( - spark.sql(s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1")) - } - } - } - - test("map_from_entries - fallback for binary type") { - val table = "t2" - withTable(table) { - sql( - s"create table $table using parquet as select cast(array() as array) as c1 from range(10)") - checkSparkAnswerAndFallbackReason( - sql(s"select map_from_entries(array(struct(c1, 0))) from $table"), - CometMapFromEntries.keyUnsupportedReason) - checkSparkAnswerAndFallbackReason( - sql(s"select map_from_entries(array(struct(0, c1))) from $table"), - CometMapFromEntries.valueUnsupportedReason) - } - } - } From 6b0d5007c213046fab5e39eb2f2d42514908b5c9 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Sun, 28 Jun 2026 22:53:10 +0400 Subject: [PATCH 03/10] Support native DataFusion lambda functions --- native/core/src/execution/planner.rs | 96 ++++++++++++++- native/proto/src/proto/expr.proto | 19 +++ .../spark-expr/src/comet_high_order_funcs.rs | 30 +++++ native/spark-expr/src/lib.rs | 2 + .../comet/serde/CometHighOrderFunction.scala | 114 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../scala/org/apache/comet/serde/arrays.scala | 11 +- .../expressions/array/array_filter_native.sql | 30 +++++ 8 files changed, 296 insertions(+), 9 deletions(-) create mode 100644 native/spark-expr/src/comet_high_order_funcs.rs create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala create mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e89f0a8cf4..150f7d23c9 100644 --- a/native/core/src/execution/planner.rs +++ 
b/native/core/src/execution/planner.rs @@ -71,9 +71,9 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, - BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + create_comet_hof_func, create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, + BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, + SparkArraysZipFunc, SparkBloomFilterVersion, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -94,9 +94,9 @@ use datafusion::logical_expr::{ AggregateUDF, ReturnFieldArgs, ScalarUDF, TypeSignature, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion::physical_expr::expressions::{Literal, StatsType}; +use datafusion::physical_expr::expressions::{LambdaExpr, LambdaVariable, Literal, StatsType}; use datafusion::physical_expr::window::WindowExpr; -use datafusion::physical_expr::LexOrdering; +use datafusion::physical_expr::{HigherOrderFunctionExpr, LexOrdering}; use crate::parquet::parquet_exec::init_datasource_exec; use arrow::array::{ @@ -112,7 +112,7 @@ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::NestedLoopJoinExec; use datafusion::physical_plan::limit::GlobalLimitExec; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; -use datafusion_comet_proto::spark_expression::ListLiteral; +use datafusion_comet_proto::spark_expression::{HigherOrderFunc, LambdaFunction, ListLiteral}; use datafusion_comet_proto::spark_operator::SparkFilePartition; use datafusion_comet_proto::{ spark_expression::{ @@ -531,6 +531,20 @@ impl PhysicalPlanner { _ => func, } } + ExprStruct::HighOrderFunc(hof) => { + self.create_high_order_function_expr(hof, input_schema) + } + 
ExprStruct::NamedLambdaVariable(nlv) => { + let idx = input_schema.index_of(&nlv.name).map_err(|_| { + GeneralError(format!( + "NamedLambdaVariable '{}' not found in enclosing lambda schema", + nlv.name + )) + })?; + let field = Arc::clone(&input_schema.fields()[idx]); + //TODO get valid idx from Spark + Ok(Arc::new(LambdaVariable::new(1, field))) + } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when .when @@ -2886,6 +2900,76 @@ impl PhysicalPlanner { } } + fn create_high_order_function_expr( + &self, + expr: &HigherOrderFunc, + input_schema: SchemaRef, + ) -> Result, ExecutionError> { + let comet_hof_func = + create_comet_hof_func(expr.func_name.as_str(), &self.session_ctx.state())?; + + let value_args = expr + .value_args + .iter() + .map(|e| self.create_expr(e, Arc::clone(&input_schema))) + .collect::, _>>()?; + + let lambdas = expr + .lambdas + .iter() + .map(|l| self.create_lambda_expr(l)) + .collect::, _>>()?; + + let mut args: Vec> = + Vec::with_capacity(value_args.len() + lambdas.len()); + args.extend(value_args); + args.extend(lambdas); + + let higher_order_function_expr = HigherOrderFunctionExpr::try_new_with_schema( + comet_hof_func, + args, + &input_schema, + Arc::new(ConfigOptions::default()), + )?; + + Ok(Arc::new(higher_order_function_expr)) + } + + fn create_lambda_expr( + &self, + lambda: &LambdaFunction, + ) -> Result, ExecutionError> { + let lambda_fields = lambda + .args + .iter() + .map(|arg| { + let data_type = arg.data_type.as_ref().ok_or_else(|| { + DataFusionError::Internal("lambda variable without data type".to_string()) + })?; + let arrow_data_type = to_arrow_datatype(data_type); + let field = Field::new(&arg.name, arrow_data_type, arg.nullable); + Ok(Arc::new(field)) + }) + .collect::>, ExecutionError>>()?; + + let body_schema = Arc::new(Schema::new(lambda_fields)); + + let lambda_body = lambda + .body + .as_ref() + .ok_or_else(|| DataFusionError::Internal("lambda has no body".to_string()))?; + let body_expr = 
self.create_expr(lambda_body, body_schema)?; + + Ok(Arc::new(LambdaExpr::try_new( + lambda + .args + .iter() + .map(|a| a.name.clone()) + .collect::>(), + body_expr, + )?)) + } + fn create_scalar_function_expr( &self, expr: &ScalarFunc, diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 90e3d87032..e8da82553b 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -91,6 +91,8 @@ message Expr { HoursTransform hours_transform = 68; ArraysZip arrays_zip = 69; JvmScalarUdf jvm_scalar_udf = 70; + HigherOrderFunc high_order_func = 71; + NamedLambdaVariable named_lambda_variable = 72; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -530,3 +532,20 @@ message JvmScalarUdf { // Whether the result column may contain nulls. bool return_nullable = 4; } + +message HigherOrderFunc { + string func_name = 1; + repeated Expr value_args = 2; + repeated LambdaFunction lambdas = 3; +} + +message NamedLambdaVariable { + string name = 1; + DataType data_type = 2; + bool nullable = 3; +} + +message LambdaFunction { + Expr body = 1; + repeated NamedLambdaVariable args = 2; +} diff --git a/native/spark-expr/src/comet_high_order_funcs.rs b/native/spark-expr/src/comet_high_order_funcs.rs new file mode 100644 index 0000000000..7277d1ee66 --- /dev/null +++ b/native/spark-expr/src/comet_high_order_funcs.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::HigherOrderUDF; +use std::sync::Arc; + +pub fn create_comet_hof_func( + func_name: &str, + registry: &dyn FunctionRegistry, +) -> Result, DataFusionError> { + registry.higher_order_function(func_name).map_err(|e| { + DataFusionError::Execution(format!("HOF {func_name} not found in the registry: {e}")) + }) +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 174a4ada9a..bdffad94a9 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -42,6 +42,7 @@ pub use predicate_funcs::{spark_isnan, RLike}; mod agg_funcs; mod array_funcs; +mod comet_high_order_funcs; mod comet_scalar_funcs; pub mod hash_funcs; @@ -70,6 +71,7 @@ pub use conditional_funcs::*; pub use conversion_funcs::*; pub use nondetermenistic_funcs::*; +pub use comet_high_order_funcs::create_comet_hof_func; pub use comet_scalar_funcs::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, register_all_comet_functions, diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala new file mode 100644 index 0000000000..9e75a54d72 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Attribute, HigherOrderFunction, LambdaFunction => SparkLambdaFunction, NamedLambdaVariable => SparkNamedLambdaVariable} + +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.CometHighOrderFunction.nlv2Proto +import org.apache.comet.serde.ExprOuterClass.{HigherOrderFunc, LambdaFunction, NamedLambdaVariable} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} + +case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) + extends CometExpressionSerde[T] { + + private val UNSUPPORTED_LAMBDA_TYPE = "lambda functions must be LambdaFunction" + private val UNSUPPORTED_LAMBDA_PARAM_TYPE = "lambda arguments must be NamedLambdaVariables" + private val UNARY_FUNCTION_EXPECTED = + "DataFusion higher-order functions support only 1 argument" + + override def getIncompatibleReasons(): Seq[String] = + Seq(UNSUPPORTED_LAMBDA_TYPE, UNSUPPORTED_LAMBDA_PARAM_TYPE) + + override def getSupportLevel(expr: T): SupportLevel = { + if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { + return Unsupported(Some(UNSUPPORTED_LAMBDA_TYPE)) + } + val 
lambdaFunctions = expr.functions.map(_.asInstanceOf[SparkLambdaFunction]) + if (!lambdaFunctions.forall(_.arguments.length == 1)) { + return Unsupported(Some(UNARY_FUNCTION_EXPECTED)) + } + if (!expr.functions + .flatMap(_.asInstanceOf[SparkLambdaFunction].arguments) + .forall(_.isInstanceOf[SparkNamedLambdaVariable])) { + return Unsupported(Some(UNSUPPORTED_LAMBDA_PARAM_TYPE)) + } + Compatible() + } + + def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + val argumentsProto = expr.arguments.map(exprToProtoInternal(_, inputs, binding)) + val functionsProto = expr.functions + .map(_.asInstanceOf[SparkLambdaFunction]) + .map { slf => + val maybeExpr = exprToProtoInternal(slf.function, inputs, binding) + maybeExpr + .map { bodyProto => + val nlvProto = slf.arguments + .map(_.asInstanceOf[SparkNamedLambdaVariable]) + .map(nlv2Proto) + LambdaFunction + .newBuilder() + .addAllArgs(nlvProto.asJava) + .setBody(bodyProto) + .build() + } + } + if (functionsProto.forall(_.isDefined) && argumentsProto.forall(_.isDefined)) { + val hof = HigherOrderFunc + .newBuilder() + .setFuncName(name) + .addAllValueArgs(argumentsProto.map(_.get).asJava) + .addAllLambdas(functionsProto.map(_.get).asJava) + .build() + Some(ExprOuterClass.Expr.newBuilder().setHighOrderFunc(hof).build()) + } else { + withFallbackReason(expr, expr.children: _*) + None + } + } +} + +object CometHighOrderFunction { + def nlv2Proto(nlv: SparkNamedLambdaVariable): NamedLambdaVariable = { + NamedLambdaVariable + .newBuilder() + .setName(nlv.name) + .setNullable(nlv.nullable) + .setDataType(serializeDataType(nlv.dataType).get) + .build() + } +} + +object CometNamedLambdaVariable extends CometExpressionSerde[SparkNamedLambdaVariable] { + def convert( + expr: SparkNamedLambdaVariable, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val nlvProto = CometHighOrderFunction.nlv2Proto(expr) + Some( + ExprOuterClass.Expr + .newBuilder() + 
.setNamedLambdaVariable(nlvProto) + .build()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 143048fb44..65e77c9f97 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -357,7 +357,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[SortOrder] -> CometSortOrder, classOf[StaticInvoke] -> CometStaticInvoke, classOf[TryEval] -> CometTryEval, - classOf[UnscaledValue] -> CometUnscaledValue) + classOf[UnscaledValue] -> CometUnscaledValue, + classOf[NamedLambdaVariable] -> CometNamedLambdaVariable) base ++ sparkVersionSpecificMiscExpressions } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index eaecd1b49a..bf75865b61 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -682,9 +682,14 @@ object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { } } -object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { +object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filter") { - override def getSupportLevel(expr: ArrayFilter): SupportLevel = Compatible() + override def getSupportLevel(expr: ArrayFilter): SupportLevel = { + if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { + return super.getSupportLevel(expr) + } + Compatible() + } override def convert( expr: ArrayFilter, @@ -695,6 +700,8 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { // Fast path: `array_compact` lowers to `filter(arr, x -> x is not null)`. Use the native // array_compact serde to avoid the per-batch JNI cost of the codegen dispatcher. 
CometArrayCompact.convert(expr, inputs, binding) + case _ if !CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get() => + super.convert(expr, inputs, binding) case _ => // General lambda: run Spark's own evaluation through the codegen dispatcher so the result // matches Spark exactly, like the other higher-order functions (`transform`, `exists`). diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql new file mode 100644 index 0000000000..2943695680 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql @@ -0,0 +1,30 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=false + +statement +CREATE TABLE test_array_filter_native(arr array) USING parquet + +statement +INSERT INTO test_array_filter_native VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1)), (array(10)), (NULL) + +query +SELECT filter(arr, x -> x > 2) FROM test_array_filter_native + +query +SELECT filter(arr, x -> x >= 0) FROM test_array_filter_native From 4281483b734b1b379bdd54e323678a500c0a6c60 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Mon, 29 Jun 2026 22:07:12 +0400 Subject: [PATCH 04/10] more tests --- native/core/src/execution/planner.rs | 39 +++++---- .../comet/serde/CometHighOrderFunction.scala | 12 +++ .../expressions/array/array_filter_native.sql | 4 + .../benchmark/CometArrayFilterBenchmark.scala | 87 +++++++++++++++++++ 4 files changed, 126 insertions(+), 16 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 150f7d23c9..dbf66ec003 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -542,8 +542,7 @@ impl PhysicalPlanner { )) })?; let field = Arc::clone(&input_schema.fields()[idx]); - //TODO get valid idx from Spark - Ok(Arc::new(LambdaVariable::new(1, field))) + Ok(Arc::new(LambdaVariable::new(idx, field))) } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when @@ -2917,7 +2916,7 @@ impl PhysicalPlanner { let lambdas = expr .lambdas .iter() - .map(|l| self.create_lambda_expr(l)) + .map(|l| self.create_lambda_expr(l, &input_schema)) .collect::, _>>()?; let mut args: Vec> = @@ -2938,21 +2937,29 @@ impl PhysicalPlanner { fn create_lambda_expr( &self, lambda: &LambdaFunction, + input_schema: &SchemaRef, ) -> Result, ExecutionError> { - let lambda_fields = lambda - .args - .iter() - .map(|arg| { - let data_type = arg.data_type.as_ref().ok_or_else(|| { - 
DataFusionError::Internal("lambda variable without data type".to_string()) - })?; - let arrow_data_type = to_arrow_datatype(data_type); - let field = Field::new(&arg.name, arrow_data_type, arg.nullable); - Ok(Arc::new(field)) - }) - .collect::>, ExecutionError>>()?; + let mut body_fields: Vec> = + input_schema.fields().iter().map(Arc::clone).collect(); + + for arg in &lambda.args { + let data_type = arg.data_type.as_ref().ok_or_else(|| { + DataFusionError::Internal("lambda variable without data type".to_string()) + })?; + let arrow_data_type = to_arrow_datatype(data_type); + body_fields.push(Arc::new(Field::new( + &arg.name, + arrow_data_type, + arg.nullable, + ))); + } - let body_schema = Arc::new(Schema::new(lambda_fields)); + let body_schema = Arc::new(Schema::new( + body_fields + .iter() + .map(|f| f.as_ref().clone()) + .collect::>(), + )); let lambda_body = lambda .body diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index 9e75a54d72..0293cb96ec 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -28,6 +28,18 @@ import org.apache.comet.serde.CometHighOrderFunction.nlv2Proto import org.apache.comet.serde.ExprOuterClass.{HigherOrderFunc, LambdaFunction, NamedLambdaVariable} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} +/** + * Serializer that converts Spark higher-order functions (e.g. `filter`, `transform`, `exists`) + * into Comet's protobuf representation so they can be executed by the native DataFusion engine. + * + * A higher-order function carries two kinds of children: + * - "value" arguments: the regular input expressions, typically the array/map being processed. + * - "function" arguments: one or more lambda expressions describing the per-element + * computation. 
+ * + * This serde only supports the subset of higher-order functions that the native engine can + * currently handle; see [[getSupportLevel]] for the exact constraints. + */ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) extends CometExpressionSerde[T] { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql index 2943695680..ee0b934ad6 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql @@ -28,3 +28,7 @@ SELECT filter(arr, x -> x > 2) FROM test_array_filter_native query SELECT filter(arr, x -> x >= 0) FROM test_array_filter_native + +query expect_fallback(DataFusion higher-order functions support only 1 argument) +SELECT filter(arr, (x, i) -> i > 0) FROM test_array_filter_native + diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala new file mode 100644 index 0000000000..a64be019a4 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +// spotless:off +/** + * Benchmark to measure performance of Comet array expressions. To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometArrayFilterBenchmark + * }}} + * Results will be written to "spark/benchmarks/CometArrayFilterBenchmark-**results.txt". + */ +// spotless:on +object CometArrayFilterBenchmark extends CometBenchmarkBase { + + def runExprBenchmark(config: ArrayFilterExprConfig, values: Int, arraySize: Int): Unit = { + val benchmark = + new Benchmark(s"${config.name} (size $arraySize)", values, output = output) + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql( + s"SELECT sequence(0, cast(rand(42) * $arraySize as int)) AS arr " + + s"FROM range($values)")) + + benchmark.addCase(s"Spark ${config.name}") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(config.query).noop() + } + } + + benchmark.addCase(s"Comet (Native) ${config.name}") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { + spark.sql(config.query).noop() + } + } + + benchmark.addCase(s"Comet (Codegen) ${config.name}") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + 
spark.sql(config.query).noop() + } + } + + benchmark.run() + } + } + } + + def runCometBenchmark(args: Array[String]): Unit = { + val values = 4 * 1024 * 1024 + + val config = + ArrayFilterExprConfig("array_filter", "SELECT filter(arr, x -> x > 2) FROM parquetV1Table") + + runExprBenchmark(config, values, 100) + } +} + +case class ArrayFilterExprConfig(name: String, query: String) From d056f64a21da9b13e2552efa8eedb4b409b3fcdd Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Mon, 29 Jun 2026 22:18:15 +0400 Subject: [PATCH 05/10] fix --- .../scala/org/apache/comet/serde/CometHighOrderFunction.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index 0293cb96ec..8179d6406e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -49,7 +49,7 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) "DataFusion higher-order functions support only 1 argument" override def getIncompatibleReasons(): Seq[String] = - Seq(UNSUPPORTED_LAMBDA_TYPE, UNSUPPORTED_LAMBDA_PARAM_TYPE) + Seq(UNSUPPORTED_LAMBDA_TYPE, UNARY_FUNCTION_EXPECTED, UNSUPPORTED_LAMBDA_PARAM_TYPE) override def getSupportLevel(expr: T): SupportLevel = { if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { From 30f06f40e39b10978aeed8bc947cf228f6903bf9 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 1 Jul 2026 22:02:31 +0400 Subject: [PATCH 06/10] Fix PR issues --- native/core/src/execution/planner.rs | 5 ++ .../scala/org/apache/comet/CometConf.scala | 10 +++ .../comet/serde/CometHighOrderFunction.scala | 86 +++++++++++++------ .../scala/org/apache/comet/serde/arrays.scala | 14 +-- .../expressions/array/array_filter.sql | 2 +- ...r_native.sql => array_filter_fallback.sql} | 16 +--- 6 files 
changed, 81 insertions(+), 52 deletions(-) rename spark/src/test/resources/sql-tests/expressions/array/{array_filter_native.sql => array_filter_fallback.sql} (70%) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index f6825626ac..c8c19c4bb5 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -3120,6 +3120,11 @@ impl PhysicalPlanner { .map(|l| self.create_lambda_expr(l, &input_schema)) .collect::, _>>()?; + // NOTE: assumes all value arguments precede all lambda arguments. + // Holds for array_filter and the current single-lambda HOFs, but would + // NOT generalize to a future HOF with interleaved value/lambda args + // (e.g. f(value, lambda, value, lambda)). Revisit this split if such a + // function is added. let mut args: Vec> = Vec::with_capacity(value_args.len() + lambdas.len()); args.extend(value_args); diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 8e47151358..366cd01e2b 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -381,6 +381,16 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.higherOrderFunction.native.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, supported higher-order functions (e.g. filter) are executed by the " + + "native DataFusion engine. 
Shapes the native path cannot handle fall back to the " + + "codegen dispatcher, and finally to Spark.") + .booleanConf + .createWithDefault(true) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index 8179d6406e..9e3845e641 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, HigherOrderFunction, LambdaFunction => SparkLambdaFunction, NamedLambdaVariable => SparkNamedLambdaVariable} +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withFallbackReason import org.apache.comet.serde.CometHighOrderFunction.nlv2Proto import org.apache.comet.serde.ExprOuterClass.{HigherOrderFunc, LambdaFunction, NamedLambdaVariable} @@ -48,41 +49,72 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) private val UNARY_FUNCTION_EXPECTED = "DataFusion higher-order functions support only 1 argument" - override def getIncompatibleReasons(): Seq[String] = + override def getUnsupportedReasons(): Seq[String] = Seq(UNSUPPORTED_LAMBDA_TYPE, UNARY_FUNCTION_EXPECTED, UNSUPPORTED_LAMBDA_PARAM_TYPE) - override def getSupportLevel(expr: T): SupportLevel = { + private def nativeUnsupportedReason(expr: T): Option[String] = { if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { - return Unsupported(Some(UNSUPPORTED_LAMBDA_TYPE)) + return Some(UNSUPPORTED_LAMBDA_TYPE) } val lambdaFunctions = expr.functions.map(_.asInstanceOf[SparkLambdaFunction]) if (!lambdaFunctions.forall(_.arguments.length == 1)) { - return 
Unsupported(Some(UNARY_FUNCTION_EXPECTED)) + return Some(UNARY_FUNCTION_EXPECTED) } if (!expr.functions .flatMap(_.asInstanceOf[SparkLambdaFunction].arguments) .forall(_.isInstanceOf[SparkNamedLambdaVariable])) { - return Unsupported(Some(UNSUPPORTED_LAMBDA_PARAM_TYPE)) + return Some(UNSUPPORTED_LAMBDA_PARAM_TYPE) + } + None + } + + override def getSupportLevel(expr: T): SupportLevel = { + val unsupportedReason = nativeUnsupportedReason(expr) + val nativeAvailable = + unsupportedReason.isEmpty && CometConf.COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED.get() + val codegenEnabled = CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get() + if (nativeAvailable || codegenEnabled) { + Compatible() + } else { + Unsupported(unsupportedReason) } - Compatible() } def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + val nativeAvailable = + nativeUnsupportedReason( + expr).isEmpty && CometConf.COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED.get() + val hofProto = hof2Proto(expr, inputs, binding) + if (nativeAvailable && hofProto.isDefined) { + hofProto + } else { + CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) + } + } + + private def hof2Proto( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { val argumentsProto = expr.arguments.map(exprToProtoInternal(_, inputs, binding)) val functionsProto = expr.functions .map(_.asInstanceOf[SparkLambdaFunction]) .map { slf => - val maybeExpr = exprToProtoInternal(slf.function, inputs, binding) - maybeExpr - .map { bodyProto => + exprToProtoInternal(slf.function, inputs, binding) + .flatMap { bodyProto => val nlvProto = slf.arguments .map(_.asInstanceOf[SparkNamedLambdaVariable]) .map(nlv2Proto) - LambdaFunction - .newBuilder() - .addAllArgs(nlvProto.asJava) - .setBody(bodyProto) - .build() + if (nlvProto.forall(_.isDefined)) { + Some( + LambdaFunction + .newBuilder() + .addAllArgs(nlvProto.map(_.get).asJava) + .setBody(bodyProto) + .build()) + } 
else { + None + } } } if (functionsProto.forall(_.isDefined) && argumentsProto.forall(_.isDefined)) { @@ -101,13 +133,19 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) } object CometHighOrderFunction { - def nlv2Proto(nlv: SparkNamedLambdaVariable): NamedLambdaVariable = { - NamedLambdaVariable - .newBuilder() - .setName(nlv.name) - .setNullable(nlv.nullable) - .setDataType(serializeDataType(nlv.dataType).get) - .build() + def nlv2Proto(nlv: SparkNamedLambdaVariable): Option[NamedLambdaVariable] = { + val dataTypeProto = serializeDataType(nlv.dataType) + if (dataTypeProto.isEmpty) { + withFallbackReason(nlv, s"Unsupported datatype: ${nlv.dataType}") + return None + } + Some( + NamedLambdaVariable + .newBuilder() + .setName(nlv.name) + .setNullable(nlv.nullable) + .setDataType(dataTypeProto.get) + .build()) } } @@ -116,11 +154,11 @@ object CometNamedLambdaVariable extends CometExpressionSerde[SparkNamedLambdaVar expr: SparkNamedLambdaVariable, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val nlvProto = CometHighOrderFunction.nlv2Proto(expr) - Some( + CometHighOrderFunction.nlv2Proto(expr).map { nlvProto => ExprOuterClass.Expr .newBuilder() .setNamedLambdaVariable(nlvProto) - .build()) + .build() + } } } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index bf75865b61..b8df7f1b29 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -684,13 +684,6 @@ object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filter") { - override def getSupportLevel(expr: ArrayFilter): SupportLevel = { - if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { - return super.getSupportLevel(expr) - } - Compatible() - } - override def convert( expr: ArrayFilter, inputs: 
Seq[Attribute], @@ -700,13 +693,8 @@ object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filte // Fast path: `array_compact` lowers to `filter(arr, x -> x is not null)`. Use the native // array_compact serde to avoid the per-batch JNI cost of the codegen dispatcher. CometArrayCompact.convert(expr, inputs, binding) - case _ if !CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get() => - super.convert(expr, inputs, binding) case _ => - // General lambda: run Spark's own evaluation through the codegen dispatcher so the result - // matches Spark exactly, like the other higher-order functions (`transform`, `exists`). - // Falls back to Spark when the dispatcher is disabled. - CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) + super.convert(expr, inputs, binding) } } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql index c4511e28e4..91af8bcdf5 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql @@ -19,7 +19,7 @@ statement CREATE TABLE test_array_filter(arr array) USING parquet statement -INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1)), (array(10)), (NULL) +INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1, NULL)), (array(NULL, NULL)), (array(10)), (NULL), (array()) query SELECT filter(arr, x -> x > 2) FROM test_array_filter diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql similarity index 70% rename from spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql rename to spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql index ee0b934ad6..7431d45305 100644 --- 
a/spark/src/test/resources/sql-tests/expressions/array/array_filter_native.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql @@ -15,20 +15,8 @@ -- specific language governing permissions and limitations -- under the License. +-- Config: spark.comet.exec.higherOrderFunction.native.enabled=false -- Config: spark.comet.exec.scalaUDF.codegen.enabled=false -statement -CREATE TABLE test_array_filter_native(arr array) USING parquet - -statement -INSERT INTO test_array_filter_native VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1)), (array(10)), (NULL) - -query -SELECT filter(arr, x -> x > 2) FROM test_array_filter_native - -query -SELECT filter(arr, x -> x >= 0) FROM test_array_filter_native - query expect_fallback(DataFusion higher-order functions support only 1 argument) -SELECT filter(arr, (x, i) -> i > 0) FROM test_array_filter_native - +SELECT filter(array(1), (x, i) -> i > 0) From 47f2726ee56a433d8b7765545dff48957c9a2ca6 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 1 Jul 2026 22:05:17 +0400 Subject: [PATCH 07/10] Fix PR issues --- .../comet/serde/CometHighOrderFunction.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index 9e3845e641..844a6b949f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -31,16 +31,27 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeData /** * Serializer that converts Spark higher-order functions (e.g. `filter`, `transform`, `exists`) - * into Comet's protobuf representation so they can be executed by the native DataFusion engine. + * into Comet's protobuf representation. 
* * A higher-order function carries two kinds of children: * - "value" arguments: the regular input expressions, typically the array/map being processed. * - "function" arguments: one or more lambda expressions describing the per-element * computation. * - * This serde only supports the subset of higher-order functions that the native engine can - * currently handle; see [[getSupportLevel]] for the exact constraints. + * Depending on the available configuration and on whether the expression satisfies the native + * constraints, [[convert]] produces one of two representations: + * - a native higher-order function proto (executed by the DataFusion engine), used when + * `COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED` is set and the expression is natively + * supported (see [[nativeUnsupportedReason]] / [[getSupportLevel]]); or + * - a JVM codegen dispatch (Scala UDF fallback via `CometScalaUDF.emitJvmCodegenDispatch`), + * used when the native path is unavailable but `COMET_SCALA_UDF_CODEGEN_ENABLED` is enabled. + * + * Native execution is limited to the subset of higher-order functions the engine can currently + * handle: every function child must be a `LambdaFunction` taking exactly one argument, and every + * lambda argument must be a `NamedLambdaVariable`. See [[getSupportLevel]] for the exact + * constraints under which the expression is reported as [[Compatible]] vs [[Unsupported]]. 
*/ + case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) extends CometExpressionSerde[T] { From 435365aabb28305a11d93517fe4b3581deba62a5 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Wed, 1 Jul 2026 22:07:31 +0400 Subject: [PATCH 08/10] Fix PR issues --- .../apache/comet/serde/CometHighOrderFunction.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index 844a6b949f..fa3f7e3543 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -33,11 +33,6 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeData * Serializer that converts Spark higher-order functions (e.g. `filter`, `transform`, `exists`) * into Comet's protobuf representation. * - * A higher-order function carries two kinds of children: - * - "value" arguments: the regular input expressions, typically the array/map being processed. - * - "function" arguments: one or more lambda expressions describing the per-element - * computation. - * * Depending on the available configuration and on whether the expression satisfies the native * constraints, [[convert]] produces one of two representations: * - a native higher-order function proto (executed by the DataFusion engine), used when @@ -45,11 +40,6 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeData * supported (see [[nativeUnsupportedReason]] / [[getSupportLevel]]); or * - a JVM codegen dispatch (Scala UDF fallback via `CometScalaUDF.emitJvmCodegenDispatch`), * used when the native path is unavailable but `COMET_SCALA_UDF_CODEGEN_ENABLED` is enabled. 
- * - * Native execution is limited to the subset of higher-order functions the engine can currently - * handle: every function child must be a `LambdaFunction` taking exactly one argument, and every - * lambda argument must be a `NamedLambdaVariable`. See [[getSupportLevel]] for the exact - * constraints under which the expression is reported as [[Compatible]] vs [[Unsupported]]. */ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) From fc132271c52d80d56ee58f6b110e77218785a11b Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Fri, 3 Jul 2026 20:19:06 +0400 Subject: [PATCH 09/10] fix PR issues --- native/core/src/execution/planner.rs | 3 ++ .../comet/serde/CometHighOrderFunction.scala | 44 +++++++++---------- .../array/array_filter_fallback.sql | 22 ---------- 3 files changed, 25 insertions(+), 44 deletions(-) delete mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 45d7a18f62..4b45ff2e0f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -3171,10 +3171,12 @@ impl PhysicalPlanner { lambda: &LambdaFunction, input_schema: &SchemaRef, ) -> Result, ExecutionError> { + println!("{}", input_schema); let mut body_fields: Vec> = input_schema.fields().iter().map(Arc::clone).collect(); for arg in &lambda.args { + println!("{}", arg.name); let data_type = arg.data_type.as_ref().ok_or_else(|| { DataFusionError::Internal("lambda variable without data type".to_string()) })?; @@ -3185,6 +3187,7 @@ impl PhysicalPlanner { arg.nullable, ))); } + println!("{}", body_fields.len()); let body_schema = Arc::new(Schema::new( body_fields diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index fa3f7e3543..b598209525 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, HigherOrderFunction import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withFallbackReason -import org.apache.comet.serde.CometHighOrderFunction.nlv2Proto +import org.apache.comet.serde.CometHighOrderFunction.namedLambdaVariable2Proto import org.apache.comet.serde.ExprOuterClass.{HigherOrderFunc, LambdaFunction, NamedLambdaVariable} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} @@ -57,10 +57,6 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { return Some(UNSUPPORTED_LAMBDA_TYPE) } - val lambdaFunctions = expr.functions.map(_.asInstanceOf[SparkLambdaFunction]) - if (!lambdaFunctions.forall(_.arguments.length == 1)) { - return Some(UNARY_FUNCTION_EXPECTED) - } if (!expr.functions .flatMap(_.asInstanceOf[SparkLambdaFunction].arguments) .forall(_.isInstanceOf[SparkNamedLambdaVariable])) { @@ -85,7 +81,7 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) val nativeAvailable = nativeUnsupportedReason( expr).isEmpty && CometConf.COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED.get() - val hofProto = hof2Proto(expr, inputs, binding) + val hofProto = highOrderFunction2Proto(expr, inputs, binding) if (nativeAvailable && hofProto.isDefined) { hofProto } else { @@ -93,24 +89,26 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) } } - private def hof2Proto( + private def highOrderFunction2Proto( expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { val argumentsProto = expr.arguments.map(exprToProtoInternal(_, inputs, binding)) val functionsProto = expr.functions - 
.map(_.asInstanceOf[SparkLambdaFunction]) - .map { slf => - exprToProtoInternal(slf.function, inputs, binding) + .map { func => + val sparkLambdaFunction = func.asInstanceOf[SparkLambdaFunction] + exprToProtoInternal(sparkLambdaFunction.function, inputs, binding) .flatMap { bodyProto => - val nlvProto = slf.arguments - .map(_.asInstanceOf[SparkNamedLambdaVariable]) - .map(nlv2Proto) - if (nlvProto.forall(_.isDefined)) { + val namedLambdaVariablesProto = sparkLambdaFunction.arguments + .map { arg => + val sparkNamedLambdaVariable = arg.asInstanceOf[SparkNamedLambdaVariable] + namedLambdaVariable2Proto(sparkNamedLambdaVariable) + } + if (namedLambdaVariablesProto.forall(_.isDefined)) { Some( LambdaFunction .newBuilder() - .addAllArgs(nlvProto.map(_.get).asJava) + .addAllArgs(namedLambdaVariablesProto.map(_.get).asJava) .setBody(bodyProto) .build()) } else { @@ -134,7 +132,7 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) } object CometHighOrderFunction { - def nlv2Proto(nlv: SparkNamedLambdaVariable): Option[NamedLambdaVariable] = { + def namedLambdaVariable2Proto(nlv: SparkNamedLambdaVariable): Option[NamedLambdaVariable] = { val dataTypeProto = serializeDataType(nlv.dataType) if (dataTypeProto.isEmpty) { withFallbackReason(nlv, s"Unsupported datatype: ${nlv.dataType}") @@ -155,11 +153,13 @@ object CometNamedLambdaVariable extends CometExpressionSerde[SparkNamedLambdaVar expr: SparkNamedLambdaVariable, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - CometHighOrderFunction.nlv2Proto(expr).map { nlvProto => - ExprOuterClass.Expr - .newBuilder() - .setNamedLambdaVariable(nlvProto) - .build() - } + CometHighOrderFunction + .namedLambdaVariable2Proto(expr) + .map { nlvProto => + ExprOuterClass.Expr + .newBuilder() + .setNamedLambdaVariable(nlvProto) + .build() + } } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql 
b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql deleted file mode 100644 index 7431d45305..0000000000 --- a/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql +++ /dev/null @@ -1,22 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. 
- --- Config: spark.comet.exec.higherOrderFunction.native.enabled=false --- Config: spark.comet.exec.scalaUDF.codegen.enabled=false - -query expect_fallback(DataFusion higher-order functions support only 1 argument) -SELECT filter(array(1), (x, i) -> i > 0) From 9102c3daf76b1757e14e114055fe0b8778a9f3fc Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Fri, 3 Jul 2026 22:32:59 +0400 Subject: [PATCH 10/10] Fix PR issues --- native/core/src/execution/planner.rs | 3 -- .../comet/serde/CometHighOrderFunction.scala | 6 ++-- .../scala/org/apache/comet/serde/arrays.scala | 29 ++++++++++++++++++- .../expressions/array/array_filter.sql | 10 +++++-- .../array/array_filter_fallback.sql | 21 ++++++++++++++ 5 files changed, 59 insertions(+), 10 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 4b45ff2e0f..45d7a18f62 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -3171,12 +3171,10 @@ impl PhysicalPlanner { lambda: &LambdaFunction, input_schema: &SchemaRef, ) -> Result, ExecutionError> { - println!("{}", input_schema); let mut body_fields: Vec> = input_schema.fields().iter().map(Arc::clone).collect(); for arg in &lambda.args { - println!("{}", arg.name); let data_type = arg.data_type.as_ref().ok_or_else(|| { DataFusionError::Internal("lambda variable without data type".to_string()) })?; @@ -3187,7 +3185,6 @@ impl PhysicalPlanner { arg.nullable, ))); } - println!("{}", body_fields.len()); let body_schema = Arc::new(Schema::new( body_fields diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala index b598209525..7558f92d5d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala +++ 
b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, HigherOrderFunction, LambdaFunction => SparkLambdaFunction, NamedLambdaVariable => SparkNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, HigherOrderFunction, LambdaFunction => SparkLambdaFunction, NamedLambdaVariable => SparkNamedLambdaVariable} import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withFallbackReason @@ -47,11 +47,9 @@ case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) private val UNSUPPORTED_LAMBDA_TYPE = "lambda functions must be LambdaFunction" private val UNSUPPORTED_LAMBDA_PARAM_TYPE = "lambda arguments must be NamedLambdaVariables" - private val UNARY_FUNCTION_EXPECTED = - "DataFusion higher-order functions support only 1 argument" override def getUnsupportedReasons(): Seq[String] = - Seq(UNSUPPORTED_LAMBDA_TYPE, UNARY_FUNCTION_EXPECTED, UNSUPPORTED_LAMBDA_PARAM_TYPE) + Seq(UNSUPPORTED_LAMBDA_TYPE, UNSUPPORTED_LAMBDA_PARAM_TYPE) private def nativeUnsupportedReason(expr: T): Option[String] = { if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 1a931ef06d..e8d7e7fd9a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, 
ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Sequence, Size, Slice, SortArray, ZipWith} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, Reverse, Sequence, Size, Slice, SortArray, ZipWith} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -724,6 +724,31 @@ object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filter") { + private val UNARY_FUNCTION_EXPECTED = + "The array_filter function in DataFusion is limited to one lambda parameter" + + override def getUnsupportedReasons(): Seq[String] = Seq(UNARY_FUNCTION_EXPECTED) + + private def isUnaryLambdaFunction(expr: ArrayFilter): Boolean = { + expr.function match { + case function: LambdaFunction => + function.arguments.length == 1 + case _ => + false + } + } + + override def getSupportLevel(expr: ArrayFilter): SupportLevel = { + if (!isUnaryLambdaFunction(expr)) { + if (CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { + return Compatible() + } else { + return Unsupported(Some(UNARY_FUNCTION_EXPECTED)) + } + } + super.getSupportLevel(expr) + } + override def convert( expr: ArrayFilter, inputs: Seq[Attribute], @@ -733,6 +758,8 @@ object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filte // Fast path: `array_compact` lowers to `filter(arr, x -> x is not null)`. 
Use the native // array_compact serde to avoid the per-batch JNI cost of the codegen dispatcher. CometArrayCompact.convert(expr, inputs, binding) + case _ if !isUnaryLambdaFunction(expr) => + CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) case _ => super.convert(expr, inputs, binding) } diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql index 91af8bcdf5..c4e24a9ae2 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql @@ -16,10 +16,10 @@ -- under the License. statement -CREATE TABLE test_array_filter(arr array<int>) USING parquet +CREATE TABLE test_array_filter(arr array<int>, threshold int) USING parquet statement -INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1, NULL)), (array(NULL, NULL)), (array(10)), (NULL), (array()) +INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5), 10), (array(-1, 0, 1, NULL), 10), (array(NULL, NULL), 10), (array(10), 10), (NULL, 10), (array(), 10) query SELECT filter(arr, x -> x > 2) FROM test_array_filter @@ -27,5 +27,11 @@ SELECT filter(arr, x -> x > 2) FROM test_array_filter query SELECT filter(arr, x -> x >= 0) FROM test_array_filter +query +SELECT filter(filter(arr, x -> x < threshold), y -> y > 0) FROM test_array_filter + +query +SELECT filter(filter(arr, x -> x < threshold), y -> y > 0) FROM test_array_filter + query SELECT filter(arr, (x, i) -> i > 0) FROM test_array_filter diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql new file mode 100644 index 0000000000..a6afa4504f --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql @@ -0,0 +1,21 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more
contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=false + +query expect_fallback(The array_filter function in DataFusion is limited to one lambda parameter) +SELECT filter(array(1, 2), (x, i) -> i > 0)