Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
495 changes: 495 additions & 0 deletions native/core/src/execution/operators/dynamic_filter.rs

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions native/core/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub use scan::*;

mod aligned_stream_reader;
mod copy;
mod dynamic_filter;
pub use dynamic_filter::{attach_join_dynamic_filter, DynamicFilterExec};
mod expand;
pub use expand::ExpandExec;
mod iceberg_scan;
Expand Down
40 changes: 38 additions & 2 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod expression_registry;
pub mod macros;
pub mod operator_registry;

use crate::execution::operators::attach_join_dynamic_filter;
use crate::execution::operators::init_csv_datasource_exec;
use crate::execution::operators::AlignedArrowStreamReader;
use crate::execution::operators::IcebergScanExec;
Expand Down Expand Up @@ -1981,20 +1982,35 @@ impl PhysicalPlanner {
// (which matches DataFusion's default), and swap_inputs would turn LeftAnti
// into RightAnti, which DataFusion rejects with null_aware=true.
if join.build_side == BuildSide::BuildLeft as i32 || join.null_aware_anti_join {
// Null-aware anti joins are excluded from dynamic filtering: NOT IN
// semantics depend on observing build-side nulls, so probe rows must
// not be pre-filtered.
let mut additional_native_plans: Vec<Arc<dyn ExecutionPlan>> = vec![];
let hash_join = self.apply_join_dynamic_filter(
hash_join,
join.dynamic_filter_enabled && !join.null_aware_anti_join,
&mut additional_native_plans,
)?;
Ok((
scans,
shuffle_scans,
Arc::new(SparkPlan::new(
Arc::new(SparkPlan::new_with_additional(
spark_plan.plan_id,
hash_join,
vec![join_params.left, join_params.right],
additional_native_plans,
)),
))
} else {
let swapped_hash_join =
hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?;
let mut additional_native_plans: Vec<Arc<dyn ExecutionPlan>> = vec![];
let swapped_hash_join = self.apply_join_dynamic_filter(
swapped_hash_join,
join.dynamic_filter_enabled,
&mut additional_native_plans,
)?;

let mut additional_native_plans = vec![];
if swapped_hash_join.is::<ProjectionExec>() {
// a projection was added to the hash join
additional_native_plans.push(Arc::clone(swapped_hash_join.children()[0]));
Expand Down Expand Up @@ -2159,6 +2175,25 @@ impl PhysicalPlanner {
}
}

/// Optionally wires a runtime dynamic filter into the hash join contained in `plan`.
///
/// `plan` is either the hash join itself or the `ProjectionExec` that `swap_inputs`
/// places on top of it. When `enabled` is false the plan is handed back untouched and
/// `additional_native_plans` is left as-is.
///
/// When `enabled` is true, the join is rewritten by `attach_join_dynamic_filter` using
/// the options of a copy of the session config. Any probe-side wrapper it reports is
/// appended to `additional_native_plans`, so that the wrapper's metrics
/// (dynamic_filter_rows_pruned) surface on the join's SparkPlan node.
///
/// Errors from `attach_join_dynamic_filter` are propagated to the caller.
fn apply_join_dynamic_filter(
    &self,
    plan: Arc<dyn ExecutionPlan>,
    enabled: bool,
    additional_native_plans: &mut Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>, ExecutionError> {
    if enabled {
        let config = self.session_ctx.copied_config();
        let (filtered_plan, probe_wrapper) =
            attach_join_dynamic_filter(plan, config.options())?;
        // Register the wrapper (if any was produced) alongside the join node.
        additional_native_plans.extend(probe_wrapper);
        Ok(filtered_plan)
    } else {
        // Feature switched off for this join: nothing to attach.
        Ok(plan)
    }
}

#[allow(clippy::too_many_arguments)]
fn parse_join_parameters(
&self,
Expand Down Expand Up @@ -4602,6 +4637,7 @@ mod tests {
condition: None,
build_side: 0,
null_aware_anti_join: false,
dynamic_filter_enabled: false,
})),
};

Expand Down
4 changes: 4 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ message HashJoin {
// True for BroadcastHashJoinExec null-aware anti-joins (NOT IN subquery semantics).
// When true, any null in the build side suppresses all left rows.
bool null_aware_anti_join = 6;
// When true, attach a runtime dynamic filter to the join: after the build side
// completes, build-key bounds/membership predicates are applied to probe-side
// batches before the hash probe. See spark.comet.exec.join.dynamicFilter.enabled.
bool dynamic_filter_enabled = 7;
}

message SortMergeJoin {
Expand Down
14 changes: 14 additions & 0 deletions spark/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,20 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

// Switch for the runtime dynamic filter on Comet native hash joins; disabled by default.
// The hash join serializer reads this entry from `join.conf` and forwards the value to the
// native planner through the `dynamic_filter_enabled` field of the HashJoin protobuf message.
// Note: the native planner never attaches the filter to null-aware anti joins, even when
// this entry is set to true.
val COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.join.dynamicFilter.enabled")
.category(CATEGORY_EXEC)
.doc(
"Experimental: when enabled, Comet native hash joins apply a runtime dynamic " +
"filter to the probe side. After the join's build side completes, min/max bounds " +
"and membership predicates derived from the build keys are used to drop probe-side " +
"rows before the hash probe, which can significantly speed up selective joins such " +
"as star-schema queries. Applies to inner, left outer, left semi, and left anti " +
"joins. A selectivity guard disables the filter on streams where it prunes " +
"few rows.")
.booleanConf
.createWithDefault(false)

val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.exec.scalaUDF.codegen.enabled")
.category(CATEGORY_EXEC)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,7 @@ trait CometHashJoin {
.setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft
else OperatorOuterClass.BuildSide.BuildRight)
.setNullAwareAntiJoin(isNullAwareAntiJoin)
.setDynamicFilterEnabled(CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.get(join.conf))
condition.foreach(joinBuilder.setCondition)
Some(builder.setHashJoin(joinBuilder).build())
} else {
Expand Down
73 changes: 73 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,79 @@ class CometJoinSuite extends CometTestBase {
}
}

test("HashJoin with dynamic filter enabled") {
  // Runs the same join shapes as the plain hash join tests with the runtime dynamic
  // filter enabled, covering eligible types (inner, left/semi/anti), an ineligible
  // type (full outer, which must not be filtered), a selective build side, an empty
  // build side, and NULL join keys. Results must be identical to Spark's.
  withSQLConf(
    CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.key -> "true",
    CometConf.COMET_BATCH_SIZE.key -> "100",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    withParquetTable((0 until 1000).map(i => (i, i % 5)), "probe") {
      withParquetTable((0 until 10).map(i => (i * 100, i + 2)), "build") {
        // Selective broadcast inner join: most probe rows should be pruned pre-probe.
        val df1 =
          sql("SELECT /*+ BROADCAST(build) */ * FROM probe JOIN build ON probe._1 = build._1")
        checkSparkAnswerAndOperator(
          df1,
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))

        // Shuffled hash join, inner.
        val df2 = sql(
          "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe JOIN build ON probe._1 = build._1")
        checkSparkAnswerAndOperator(df2)

        // Left outer join (probe side preserved under ON clause: eligible). Spark
        // cannot broadcast the preserved side of a LEFT JOIN, so this plans as a
        // shuffled hash join with BuildLeft.
        val df3 = sql(
          "SELECT /*+ BROADCAST(build) */ * FROM build LEFT JOIN probe ON build._1 = probe._1")
        checkSparkAnswerAndOperator(df3)

        // Full outer join (ineligible: filter must not be attached).
        val df4 = sql(
          "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe FULL JOIN build ON probe._1 = build._1")
        checkSparkAnswerAndOperator(df4)

        // Left semi and left anti joins.
        val df5 = sql(
          "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe LEFT SEMI JOIN build " +
            "ON probe._1 = build._1")
        checkSparkAnswerAndOperator(df5)
        val df6 = sql(
          "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe LEFT ANTI JOIN build " +
            "ON probe._1 = build._1")
        checkSparkAnswerAndOperator(df6)

        // Empty build side: the dynamic filter may become constant-false.
        val df7 = sql(
          "SELECT /*+ BROADCAST(build) */ * FROM probe JOIN build " +
            "ON probe._1 = build._1 WHERE build._2 < 0")
        checkSparkAnswer(df7)
      }
    }

    // NULL join keys: a NULL key never matches, and the dynamic filter must not
    // change that (NULL evaluates to not-kept, matching join semantics).
    //
    // Probe keys that are multiples of 3 are replaced by NULL. Build keys are multiples
    // of 10, so some of them (10, 20, 40, 50, 70, 80) match real probe keys while the
    // others (0, 30, 60, 90) line up only with NULL-ed probe rows. Build keys that are
    // all multiples of 3 (e.g. multiples of 9) would fall exclusively on NULL-ed probe
    // rows, making the expected result empty and unable to detect a filter that wrongly
    // drops matching rows.
    withParquetTable(
      (0 until 100).map(i => (if (i % 3 == 0) None else Some(i), i.toString)),
      "probe_nulls") {
      withParquetTable((0 until 10).map(i => (Some(i * 10), i.toString)), "build_nulls") {
        val df = sql(
          "SELECT /*+ BROADCAST(build_nulls) */ * FROM probe_nulls JOIN build_nulls " +
            "ON probe_nulls._1 = build_nulls._1")
        checkSparkAnswerAndOperator(
          df,
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
      }
    }
  }
}

test("Broadcast HashJoin with join filter") {
withSQLConf(
CometConf.COMET_BATCH_SIZE.key -> "100",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.benchmark

import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.{CometConf, CometSparkSessionExtensions}

/**
* Benchmark to measure the effect of runtime dynamic filter pushdown
* (spark.comet.exec.join.dynamicFilter.enabled) on broadcast hash joins with three build-side
* shapes: sparse selective keys (membership predicate does the pruning), clustered selective keys
* (min/max bounds do the pruning), and a non-selective build side covering the full probe domain
* (the selectivity guard should disable the filter, bounding overhead). To run:
* {{{
* SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometJoinDynamicFilterBenchmark
* }}}
*/
object CometJoinDynamicFilterBenchmark extends CometBenchmarkBase {

  /**
   * Creates the session used by every benchmark case: a local master with 5 threads, the
   * Comet session extensions, the Comet shuffle manager, and 2 shuffle partitions. Driver and
   * executor memory are set to 3g only when not already configured.
   */
  override def getSparkSession: SparkSession = {
    val sparkConf = new SparkConf()
      .setAppName("CometJoinDynamicFilterBenchmark")
      .set("spark.master", "local[5]")
      .setIfMissing("spark.driver.memory", "3g")
      .setIfMissing("spark.executor.memory", "3g")
      .set(
        "spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")

    val session = SparkSession.builder
      .config(sparkConf)
      .withExtensions(new CometSparkSessionExtensions)
      .getOrCreate()
    session.conf.set("spark.sql.shuffle.partitions", "2")
    session
  }

  // Settings shared by both Comet cases. Force BroadcastHashJoin: builds below the 10MB
  // threshold.
  private val cometConfigs: Map[String, String] = Map(
    CometConf.COMET_ENABLED.key -> "true",
    CometConf.COMET_EXEC_ENABLED.key -> "true",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB")

  /**
   * Runs `query` under three configurations and reports them as one benchmark: plain Spark
   * (Comet disabled), Comet with the dynamic filter off, and Comet with the dynamic filter on.
   *
   * @param name
   *   benchmark name passed to [[Benchmark]]
   * @param cardinality
   *   number of values processed, passed to [[Benchmark]]
   * @param query
   *   SQL text executed by every case
   */
  private def benchmarkQuery(name: String, cardinality: Long, query: String): Unit = {
    val bench = new Benchmark(name, cardinality, output = output)

    // Registers a case that runs `query` with the shared Comet settings plus the given value
    // for the dynamic-filter switch. The settings map is assembled inside the case body, as
    // part of each timed iteration.
    def addCometCase(caseName: String, dynamicFilterFlag: String): Unit =
      bench.addCase(caseName) { _ =>
        val settings = cometConfigs.updated(
          CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.key,
          dynamicFilterFlag)
        withSQLConf(settings.toSeq: _*) {
          spark.sql(query).noop()
        }
      }

    bench.addCase("Spark") { _ =>
      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
        spark.sql(query).noop()
      }
    }
    addCometCase("Comet (dynamic filter off)", "false")
    addCometCase("Comet (dynamic filter on)", "true")

    bench.run()
  }

  /**
   * Generates the probe/build Parquet tables under a temporary directory, exposes them as
   * temp views, and runs one benchmark per build-side shape.
   */
  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
    val probeRows = 16 * 1024 * 1024
    val sparseBuildRows = 1024
    val clusteredBuildRows = 64 * 1024
    val fullDomain = 64 * 1024

    withTempPath { dir =>
      withTempTable("probe", "probe_mod", "build_sparse", "build_clustered", "build_full") {
        val root = dir.getAbsolutePath

        // Writes `rowCount` consecutive ids, projected through `exprs`, to `<root>/<table>`.
        def writeRange(table: String, rowCount: Long, exprs: String*): Unit =
          spark.range(rowCount).selectExpr(exprs: _*).write.parquet(s"$root/$table")

        // Builds the broadcast inner-join aggregate used by every benchmark.
        def broadcastJoinQuery(probeTable: String, buildTable: String): String =
          s"SELECT /*+ BROADCAST(b) */ count(*), sum(p.v) FROM $probeTable p " +
            s"JOIN $buildTable b ON p.k = b.k"

        // Probe with unique keys across [0, probeRows).
        writeRange("probe", probeRows, "id AS k", "id % 1000 AS v")
        // Probe whose keys all fall in [0, fullDomain) so build_full matches every row.
        writeRange("probe_mod", probeRows, s"id % $fullDomain AS k", "id % 1000 AS v")
        // Sparse selective build: keys spread across the whole probe domain, so min/max
        // bounds prune nothing and the membership predicate does the work (~0.006%
        // of probe rows match).
        writeRange(
          "build_sparse",
          sparseBuildRows,
          s"id * ${probeRows / sparseBuildRows} AS k",
          "id AS w")
        // Clustered selective build: contiguous keys in the middle of the probe domain,
        // so the min/max bounds prune ~99.6% of probe rows cheaply.
        writeRange("build_clustered", clusteredBuildRows, s"id + ${probeRows / 2} AS k", "id AS w")
        // Non-selective build: covers every probe_mod key, so the filter keeps 100% of
        // rows and the selectivity guard should disable evaluation.
        writeRange("build_full", fullDomain, "id AS k", "id AS w")

        for (table <- Seq("probe", "probe_mod", "build_sparse", "build_clustered", "build_full")) {
          spark.read.parquet(s"$root/$table").createOrReplaceTempView(table)
        }

        // (benchmark title, case-group name, query), executed in this order.
        val scenarios = Seq(
          (
            "BroadcastHashJoin dynamic filter - sparse selective build",
            "sparse selective build (membership pruning)",
            broadcastJoinQuery("probe", "build_sparse")),
          (
            "BroadcastHashJoin dynamic filter - clustered selective build",
            "clustered selective build (bounds pruning)",
            broadcastJoinQuery("probe", "build_clustered")),
          (
            "BroadcastHashJoin dynamic filter - non-selective build",
            "non-selective build (guard disables filter)",
            broadcastJoinQuery("probe_mod", "build_full")))

        scenarios.foreach { case (title, groupName, query) =>
          runBenchmark(title) {
            benchmarkQuery(groupName, probeRows, query)
          }
        }
      }
    }
  }
}