diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..45d7a18f62 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -72,9 +72,9 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, - BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SparkBloomFilterVersion, SumInteger, ToCsv, + create_comet_hof_func, create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, + BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, + SparkArraysZipFunc, SparkBloomFilterVersion, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -95,9 +95,9 @@ use datafusion::logical_expr::{ AggregateUDF, ReturnFieldArgs, ScalarUDF, TypeSignature, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion::physical_expr::expressions::{Literal, StatsType}; +use datafusion::physical_expr::expressions::{LambdaExpr, LambdaVariable, Literal, StatsType}; use datafusion::physical_expr::window::WindowExpr; -use datafusion::physical_expr::LexOrdering; +use datafusion::physical_expr::{HigherOrderFunctionExpr, LexOrdering}; use crate::parquet::parquet_exec::init_datasource_exec; use arrow::array::{ @@ -113,7 +113,7 @@ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::NestedLoopJoinExec; use datafusion::physical_plan::limit::GlobalLimitExec; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; -use datafusion_comet_proto::spark_expression::ListLiteral; +use datafusion_comet_proto::spark_expression::{HigherOrderFunc, LambdaFunction, ListLiteral}; use datafusion_comet_proto::spark_operator::SparkFilePartition; use datafusion_comet_proto::{ spark_expression::{ @@ -532,6 +532,19 @@ impl PhysicalPlanner { _ => func, } } + ExprStruct::HighOrderFunc(hof) => { + self.create_high_order_function_expr(hof, input_schema) + } + ExprStruct::NamedLambdaVariable(nlv) => { + let idx = input_schema.index_of(&nlv.name).map_err(|_| { + GeneralError(format!( + "NamedLambdaVariable '{}' not found in enclosing lambda schema", + nlv.name + )) + })?; + let field = Arc::clone(&input_schema.fields()[idx]); + Ok(Arc::new(LambdaVariable::new(idx, field))) + } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when .when @@ -3113,6 +3126,89 @@ impl PhysicalPlanner { } } + fn create_high_order_function_expr( + &self, + expr: &HigherOrderFunc, + input_schema: SchemaRef, + ) -> Result, ExecutionError> { + let comet_hof_func = + create_comet_hof_func(expr.func_name.as_str(), &self.session_ctx.state())?; + + let value_args = expr + .value_args + .iter() + .map(|e| self.create_expr(e, Arc::clone(&input_schema))) + .collect::, _>>()?; + + let lambdas = expr + .lambdas + .iter() + .map(|l| self.create_lambda_expr(l, &input_schema)) + .collect::, _>>()?; + + // NOTE: assumes all value arguments precede all lambda arguments. + // Holds for array_filter and the current single-lambda HOFs, but would + // NOT generalize to a future HOF with interleaved value/lambda args + // (e.g. f(value, lambda, value, lambda)). Revisit this split if such a + // function is added. + let mut args: Vec> = + Vec::with_capacity(value_args.len() + lambdas.len()); + args.extend(value_args); + args.extend(lambdas); + + let higher_order_function_expr = HigherOrderFunctionExpr::try_new_with_schema( + comet_hof_func, + args, + &input_schema, + Arc::new(ConfigOptions::default()), + )?; + + Ok(Arc::new(higher_order_function_expr)) + } + + fn create_lambda_expr( + &self, + lambda: &LambdaFunction, + input_schema: &SchemaRef, + ) -> Result, ExecutionError> { + let mut body_fields: Vec> = + input_schema.fields().iter().map(Arc::clone).collect(); + + for arg in &lambda.args { + let data_type = arg.data_type.as_ref().ok_or_else(|| { + DataFusionError::Internal("lambda variable without data type".to_string()) + })?; + let arrow_data_type = to_arrow_datatype(data_type); + body_fields.push(Arc::new(Field::new( + &arg.name, + arrow_data_type, + arg.nullable, + ))); + } + + let body_schema = Arc::new(Schema::new( + body_fields + .iter() + .map(|f| f.as_ref().clone()) + .collect::>(), + )); + + let lambda_body = lambda + .body + .as_ref() + .ok_or_else(|| DataFusionError::Internal("lambda has no body".to_string()))?; + let body_expr = self.create_expr(lambda_body, body_schema)?; + + Ok(Arc::new(LambdaExpr::try_new( + lambda + .args + .iter() + .map(|a| a.name.clone()) + .collect::>(), + body_expr, + )?)) + } + fn create_scalar_function_expr( &self, expr: &ScalarFunc, diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 32adc16b72..64d136878a 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -93,6 +93,8 @@ message Expr { JvmScalarUdf jvm_scalar_udf = 70; PreciseTimestampConversion precise_timestamp_conversion = 71; Shuffle shuffle = 72; + HigherOrderFunc high_order_func = 73; + NamedLambdaVariable named_lambda_variable = 74; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -557,3 +559,20 @@ message JvmScalarUdf { // Whether the result column may contain nulls. bool return_nullable = 4; } + +message HigherOrderFunc { + string func_name = 1; + repeated Expr value_args = 2; + repeated LambdaFunction lambdas = 3; +} + +message NamedLambdaVariable { + string name = 1; + DataType data_type = 2; + bool nullable = 3; +} + +message LambdaFunction { + Expr body = 1; + repeated NamedLambdaVariable args = 2; +} diff --git a/native/spark-expr/src/comet_high_order_funcs.rs b/native/spark-expr/src/comet_high_order_funcs.rs new file mode 100644 index 0000000000..7277d1ee66 --- /dev/null +++ b/native/spark-expr/src/comet_high_order_funcs.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::HigherOrderUDF; +use std::sync::Arc; + +pub fn create_comet_hof_func( + func_name: &str, + registry: &dyn FunctionRegistry, +) -> Result, DataFusionError> { + registry.higher_order_function(func_name).map_err(|e| { + DataFusionError::Execution(format!("HOF {func_name} not found in the registry: {e}")) + }) +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 174a4ada9a..bdffad94a9 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -42,6 +42,7 @@ pub use predicate_funcs::{spark_isnan, RLike}; mod agg_funcs; mod array_funcs; +mod comet_high_order_funcs; mod comet_scalar_funcs; pub mod hash_funcs; @@ -70,6 +71,7 @@ pub use conditional_funcs::*; pub use conversion_funcs::*; pub use nondetermenistic_funcs::*; +pub use comet_high_order_funcs::create_comet_hof_func; pub use comet_scalar_funcs::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, register_all_comet_functions, diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 8e47151358..366cd01e2b 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -381,6 +381,16 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.higherOrderFunction.native.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, supported higher-order functions (e.g. filter) are executed by the " + + "native DataFusion engine. Shapes the native path cannot handle fall back to the " + + "codegen dispatcher, and finally to Spark.") + .booleanConf + .createWithDefault(true) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala new file mode 100644 index 0000000000..7558f92d5d --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometHighOrderFunction.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, HigherOrderFunction, LambdaFunction => SparkLambdaFunction, NamedLambdaVariable => SparkNamedLambdaVariable} + +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason +import org.apache.comet.serde.CometHighOrderFunction.namedLambdaVariable2Proto +import org.apache.comet.serde.ExprOuterClass.{HigherOrderFunc, LambdaFunction, NamedLambdaVariable} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} + +/** + * Serializer that converts Spark higher-order functions (e.g. `filter`, `transform`, `exists`) + * into Comet's protobuf representation. + * + * Depending on the available configuration and on whether the expression satisfies the native + * constraints, [[convert]] produces one of two representations: + * - a native higher-order function proto (executed by the DataFusion engine), used when + * `COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED` is set and the expression is natively + * supported (see [[nativeUnsupportedReason]] / [[getSupportLevel]]); or + * - a JVM codegen dispatch (Scala UDF fallback via `CometScalaUDF.emitJvmCodegenDispatch`), + * used when the native path is unavailable but `COMET_SCALA_UDF_CODEGEN_ENABLED` is enabled. + */ + +case class CometHighOrderFunction[T <: HigherOrderFunction](name: String) + extends CometExpressionSerde[T] { + + private val UNSUPPORTED_LAMBDA_TYPE = "lambda functions must be LambdaFunction" + private val UNSUPPORTED_LAMBDA_PARAM_TYPE = "lambda arguments must be NamedLambdaVariables" + + override def getUnsupportedReasons(): Seq[String] = + Seq(UNSUPPORTED_LAMBDA_TYPE, UNSUPPORTED_LAMBDA_PARAM_TYPE) + + private def nativeUnsupportedReason(expr: T): Option[String] = { + if (!expr.functions.forall(_.isInstanceOf[SparkLambdaFunction])) { + return Some(UNSUPPORTED_LAMBDA_TYPE) + } + if (!expr.functions + .flatMap(_.asInstanceOf[SparkLambdaFunction].arguments) + .forall(_.isInstanceOf[SparkNamedLambdaVariable])) { + return Some(UNSUPPORTED_LAMBDA_PARAM_TYPE) + } + None + } + + override def getSupportLevel(expr: T): SupportLevel = { + val unsupportedReason = nativeUnsupportedReason(expr) + val nativeAvailable = + unsupportedReason.isEmpty && CometConf.COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED.get() + val codegenEnabled = CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get() + if (nativeAvailable || codegenEnabled) { + Compatible() + } else { + Unsupported(unsupportedReason) + } + } + + def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + val nativeAvailable = + nativeUnsupportedReason( + expr).isEmpty && CometConf.COMET_EXEC_HIGHER_ORDER_FUNCTION_NATIVE_ENABLED.get() + val hofProto = highOrderFunction2Proto(expr, inputs, binding) + if (nativeAvailable && hofProto.isDefined) { + hofProto + } else { + CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) + } + } + + private def highOrderFunction2Proto( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val argumentsProto = expr.arguments.map(exprToProtoInternal(_, inputs, binding)) + val functionsProto = expr.functions + .map { func => + val sparkLambdaFunction = func.asInstanceOf[SparkLambdaFunction] + exprToProtoInternal(sparkLambdaFunction.function, inputs, binding) + .flatMap { bodyProto => + val namedLambdaVariablesProto = sparkLambdaFunction.arguments + .map { arg => + val sparkNamedLambdaVariable = arg.asInstanceOf[SparkNamedLambdaVariable] + namedLambdaVariable2Proto(sparkNamedLambdaVariable) + } + if (namedLambdaVariablesProto.forall(_.isDefined)) { + Some( + LambdaFunction + .newBuilder() + .addAllArgs(namedLambdaVariablesProto.map(_.get).asJava) + .setBody(bodyProto) + .build()) + } else { + None + } + } + } + if (functionsProto.forall(_.isDefined) && argumentsProto.forall(_.isDefined)) { + val hof = HigherOrderFunc + .newBuilder() + .setFuncName(name) + .addAllValueArgs(argumentsProto.map(_.get).asJava) + .addAllLambdas(functionsProto.map(_.get).asJava) + .build() + Some(ExprOuterClass.Expr.newBuilder().setHighOrderFunc(hof).build()) + } else { + withFallbackReason(expr, expr.children: _*) + None + } + } +} + +object CometHighOrderFunction { + def namedLambdaVariable2Proto(nlv: SparkNamedLambdaVariable): Option[NamedLambdaVariable] = { + val dataTypeProto = serializeDataType(nlv.dataType) + if (dataTypeProto.isEmpty) { + withFallbackReason(nlv, s"Unsupported datatype: ${nlv.dataType}") + return None + } + Some( + NamedLambdaVariable + .newBuilder() + .setName(nlv.name) + .setNullable(nlv.nullable) + .setDataType(dataTypeProto.get) + .build()) + } +} + +object CometNamedLambdaVariable extends CometExpressionSerde[SparkNamedLambdaVariable] { + def convert( + expr: SparkNamedLambdaVariable, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + CometHighOrderFunction + .namedLambdaVariable2Proto(expr) + .map { nlvProto => + ExprOuterClass.Expr + .newBuilder() + .setNamedLambdaVariable(nlvProto) + .build() + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7146eaec9b..4af55c9f2d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -368,7 +368,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[SortOrder] -> CometSortOrder, classOf[StaticInvoke] -> CometStaticInvoke, classOf[TryEval] -> CometTryEval, - classOf[UnscaledValue] -> CometUnscaledValue) + classOf[UnscaledValue] -> CometUnscaledValue, + classOf[NamedLambdaVariable] -> CometNamedLambdaVariable) base ++ sparkVersionSpecificMiscExpressions } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 8eda097ce6..e8d7e7fd9a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Sequence, Size, Slice, SortArray, ZipWith} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, Reverse, Sequence, Size, Slice, SortArray, ZipWith} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -722,9 +722,32 @@ object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { } } -object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { +object CometArrayFilter extends CometHighOrderFunction[ArrayFilter]("array_filter") { - override def getSupportLevel(expr: ArrayFilter): SupportLevel = Compatible() + private val UNARY_FUNCTION_EXPECTED = + "The array_filter function in DataFusion is limited to one lambda parameter" + + override def getUnsupportedReasons(): Seq[String] = Seq(UNARY_FUNCTION_EXPECTED) + + private def isUnaryLambdaFunction(expr: ArrayFilter): Boolean = { + expr.function match { + case function: LambdaFunction => + function.arguments.length == 1 + case _ => + false + } + } + + override def getSupportLevel(expr: ArrayFilter): SupportLevel = { + if (!isUnaryLambdaFunction(expr)) { + if (CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { + return Compatible() + } else { + return Unsupported(Some(UNARY_FUNCTION_EXPECTED)) + } + } + super.getSupportLevel(expr) + } override def convert( expr: ArrayFilter, @@ -735,11 +758,10 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { // Fast path: `array_compact` lowers to `filter(arr, x -> x is not null)`. Use the native // array_compact serde to avoid the per-batch JNI cost of the codegen dispatcher. CometArrayCompact.convert(expr, inputs, binding) - case _ => - // General lambda: run Spark's own evaluation through the codegen dispatcher so the result - // matches Spark exactly, like the other higher-order functions (`transform`, `exists`). - // Falls back to Spark when the dispatcher is disabled. + case _ if !isUnaryLambdaFunction(expr) => CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding) + case _ => + super.convert(expr, inputs, binding) } } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql index c4511e28e4..c4e24a9ae2 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter.sql @@ -16,10 +16,10 @@ -- under the License. statement -CREATE TABLE test_array_filter(arr array) USING parquet +CREATE TABLE test_array_filter(arr array, threshold int) USING parquet statement -INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5)), (array(-1, 0, 1)), (array(10)), (NULL) +INSERT INTO test_array_filter VALUES (array(1, 2, 3, 4, 5), 10), (array(-1, 0, 1, NULL), 10), (array(NULL, NULL), 10), (array(10), 10), (NULL, 10), (array(), 10) query SELECT filter(arr, x -> x > 2) FROM test_array_filter @@ -27,5 +27,11 @@ SELECT filter(arr, x -> x > 2) FROM test_array_filter query SELECT filter(arr, x -> x >= 0) FROM test_array_filter +query +SELECT filter(filter(arr, x -> x < threshold), y -> y > 0) FROM test_array_filter + +query +SELECT filter(filter(arr, x -> x < threshold), y -> y > 0) FROM test_array_filter + query SELECT filter(arr, (x, i) -> i > 0) FROM test_array_filter diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql new file mode 100644 index 0000000000..a6afa4504f --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_filter_fallback.sql @@ -0,0 +1,21 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=false + +query expect_fallback(The array_filter function in DataFusion is limited to one lambda parameter) +SELECT filter(array(1, 2), (x, i) -> i > 0) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala new file mode 100644 index 0000000000..a64be019a4 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayFilterBenchmark.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +// spotless:off +/** + * Benchmark to measure performance of Comet array expressions. To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometArrayFilterBenchmark + * }}} + * Results will be written to "spark/benchmarks/CometArrayFilterBenchmark-**results.txt". + */ +// spotless:on +object CometArrayFilterBenchmark extends CometBenchmarkBase { + + def runExprBenchmark(config: ArrayFilterExprConfig, values: Int, arraySize: Int): Unit = { + val benchmark = + new Benchmark(s"${config.name} (size $arraySize)", values, output = output) + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql( + s"SELECT sequence(0, cast(rand(42) * $arraySize as int)) AS arr " + + s"FROM range($values)")) + + benchmark.addCase(s"Spark ${config.name}") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(config.query).noop() + } + } + + benchmark.addCase(s"Comet (Native) ${config.name}") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { + spark.sql(config.query).noop() + } + } + + benchmark.addCase(s"Comet (Codegen) ${config.name}") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + spark.sql(config.query).noop() + } + } + + benchmark.run() + } + } + } + + def runCometBenchmark(args: Array[String]): Unit = { + val values = 4 * 1024 * 1024 + + val config = + ArrayFilterExprConfig("array_filter", "SELECT filter(arr, x -> x > 2) FROM parquetV1Table") + + runExprBenchmark(config, values, 100) + } +} + +case class ArrayFilterExprConfig(name: String, query: String)