diff --git a/native/core/src/execution/operators/dynamic_filter.rs b/native/core/src/execution/operators/dynamic_filter.rs new file mode 100644 index 0000000000..49983f516a --- /dev/null +++ b/native/core/src/execution/operators/dynamic_filter.rs @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Runtime dynamic filter for hash join probe sides. +//! +//! [`DynamicFilterExec`] evaluates a join's [`DynamicFilterPhysicalExpr`] against +//! probe-side batches before they reach the hash probe. The expression starts as a +//! `lit(true)` placeholder and is populated by DataFusion's `HashJoinExec` build phase +//! (min/max bounds plus `InList` or hash-table-lookup membership). Until then — or if +//! the join never populates it — batches pass through untouched, so correctness never +//! depends on population. +//! +//! [`attach_join_dynamic_filter`] rewires an eligible `HashJoinExec` so that the join +//! and a new `DynamicFilterExec` wrapping its probe child share the same filter. + +use std::fmt::Formatter; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::RecordBatch; +use arrow::compute::filter_record_batch; +use arrow::datatypes::SchemaRef; +use datafusion::common::cast::as_boolean_array; +use datafusion::common::config::ConfigOptions; +use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use futures::{Stream, StreamExt}; + +/// Stop evaluating the filter for the remainder of a partition stream when, after at +/// least [`GUARD_MIN_ROWS`] filtered rows, it keeps more than this fraction of them. +const GUARD_MAX_SELECTIVITY: f64 = 0.95; +/// Minimum number of rows to observe before the selectivity guard may disable the +/// filter, so a few unrepresentative leading batches don't make the decision. +const GUARD_MIN_ROWS: usize = 65_536; + +/// Filters probe-side batches with a join's shared [`DynamicFilterPhysicalExpr`]. +/// +/// Distinct from a generic `FilterExec` in three ways: a pass-through fast path while +/// the filter is still the constant-`true` placeholder, a selectivity guard that +/// disables evaluation on non-selective streams, and dedicated metrics +/// (`dynamic_filter_rows_pruned`). +#[derive(Debug)] +pub struct DynamicFilterExec { + input: Arc, + predicate: Arc, + metrics: ExecutionPlanMetricsSet, + cache: Arc, +} + +impl DynamicFilterExec { + pub fn new(input: Arc, predicate: Arc) -> Self { + // Filtering preserves schema, ordering, and partitioning. + let cache = Arc::clone(input.properties()); + Self { + input, + predicate, + metrics: ExecutionPlanMetricsSet::new(), + cache, + } + } + + pub fn predicate(&self) -> &Arc { + &self.predicate + } +} + +impl DisplayAs for DynamicFilterExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CometDynamicFilterExec") + } +} + +impl ExecutionPlan for DynamicFilterExec { + fn name(&self) -> &str { + "CometDynamicFilterExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> DataFusionResult> { + Ok(Arc::new(DynamicFilterExec::new( + children.swap_remove(0), + Arc::clone(&self.predicate), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DataFusionResult { + let input = self.input.execute(partition, context)?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let rows_pruned = + MetricBuilder::new(&self.metrics).counter("dynamic_filter_rows_pruned", partition); + // A dedicated timer rather than the baseline elapsed_compute: this operator is + // registered as an additional native plan on the join's SparkPlan node for + // metrics collection, and the metric merge sums same-named metrics — timing + // via elapsed_compute would inflate the join's reported compute time. + let eval_time = + MetricBuilder::new(&self.metrics).subset_time("dynamic_filter_eval_time", partition); + Ok(Box::pin(DynamicFilterStream { + schema: self.input.schema(), + input, + predicate: Arc::clone(&self.predicate), + baseline_metrics, + rows_pruned, + eval_time, + rows_evaluated: 0, + rows_kept: 0, + guard_disabled: false, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +struct DynamicFilterStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + predicate: Arc, + baseline_metrics: BaselineMetrics, + rows_pruned: Count, + eval_time: Time, + /// Rows seen since the filter became a real (non-placeholder) predicate. + rows_evaluated: usize, + rows_kept: usize, + /// Set once the selectivity guard decides the filter is not worth evaluating. + guard_disabled: bool, +} + +impl DynamicFilterStream { + fn filter_batch(&mut self, batch: RecordBatch) -> DataFusionResult> { + let _timer = self.eval_time.timer(); + match self.predicate.evaluate(&batch)? { + // Placeholder (or degenerate all-true) predicate: pass through untouched. + // Not counted toward the selectivity guard — the real filter may not have + // arrived yet. + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => Ok(Some(batch)), + // Constant false/null (e.g. empty build side): the whole batch is pruned. + ColumnarValue::Scalar(_) => { + self.rows_pruned.add(batch.num_rows()); + self.rows_evaluated += batch.num_rows(); + Ok(None) + } + ColumnarValue::Array(array) => { + let mask = as_boolean_array(&array)?; + let filtered = filter_record_batch(&batch, mask)?; + let kept = filtered.num_rows(); + self.rows_pruned.add(batch.num_rows() - kept); + self.rows_evaluated += batch.num_rows(); + self.rows_kept += kept; + if self.rows_evaluated >= GUARD_MIN_ROWS + && (self.rows_kept as f64 / self.rows_evaluated as f64) > GUARD_MAX_SELECTIVITY + { + self.guard_disabled = true; + } + if kept == 0 { + Ok(None) + } else { + Ok(Some(filtered)) + } + } + } + } +} + +impl Stream for DynamicFilterStream { + type Item = DataFusionResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if self.guard_disabled { + self.baseline_metrics.record_output(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); + } + match self.filter_batch(batch) { + Ok(Some(filtered)) => { + self.baseline_metrics.record_output(filtered.num_rows()); + return Poll::Ready(Some(Ok(filtered))); + } + // Entire batch pruned: keep polling the input. + Ok(None) => continue, + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + other => return other, + } + } + } +} + +impl RecordBatchStream for DynamicFilterStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// Attaches a runtime dynamic filter to an eligible [`HashJoinExec`], wrapping its +/// probe (right) child in a [`DynamicFilterExec`] that shares the same filter. +/// +/// Accepts either a `HashJoinExec` or a `ProjectionExec` directly above one (the shape +/// `HashJoinExec::swap_inputs` produces). Returns the (possibly rewritten) plan plus +/// the installed wrapper, so the planner can register the wrapper for metrics +/// collection (`SparkPlan::new_with_additional`); when the join is not eligible the +/// input plan is returned unchanged with `None`. +/// +/// Eligibility mirrors DataFusion's own `allow_join_dynamic_filter_pushdown` gate: +/// - the session option `optimizer.enable_join_dynamic_filter_pushdown` must be on, +/// - `optimizer.preserve_file_partitions` with `PartitionMode::Partitioned` is +/// excluded (file-group partitions are not hash-distributed by the join keys), +/// - the probe side must be preserved under the ON clause +/// (`JoinType::on_lr_is_preserved().1`), which admits Inner, Left, LeftSemi, +/// RightSemi, LeftAnti, and LeftMark joins — a probe row removed by the filter +/// could not have matched any build row, so results are unchanged. +/// +/// DataFusion re-checks the same gate at execute time (an ineligible join never +/// populates the filter, leaving the wrapper a harmless pass-through); checking here +/// as well avoids installing a wrapper that can never engage. +/// +/// Callers must not pass null-aware anti joins (Spark NOT IN semantics); that gate +/// lives at the call site where the flag is known. +pub fn attach_join_dynamic_filter( + plan: Arc, + config: &ConfigOptions, +) -> DataFusionResult { + // swap_inputs may have inserted a projection above the join to restore column order. + if plan.is::() { + let child = Arc::clone(plan.children()[0]); + let (new_child, wrapper) = attach_join_dynamic_filter(child, config)?; + return Ok((plan.with_new_children(vec![new_child])?, wrapper)); + } + + let Some(hash_join) = plan.downcast_ref::() else { + return Ok((plan, None)); + }; + if !config.optimizer.enable_join_dynamic_filter_pushdown { + return Ok((plan, None)); + } + if config.optimizer.preserve_file_partitions > 0 + && matches!(hash_join.partition_mode(), PartitionMode::Partitioned) + { + return Ok((plan, None)); + } + if !hash_join.join_type().on_lr_is_preserved().1 { + return Ok((plan, None)); + } + let probe_keys: Vec> = hash_join + .on() + .iter() + .map(|(_, right)| Arc::clone(right)) + .collect(); + if probe_keys.is_empty() { + return Ok((plan, None)); + } + + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(probe_keys, lit(true))); + let wrapped_probe: Arc = Arc::new(DynamicFilterExec::new( + Arc::clone(hash_join.right()), + Arc::clone(&dynamic_filter), + )); + let new_join = hash_join + .builder() + .with_new_children(vec![ + Arc::clone(hash_join.left()), + Arc::clone(&wrapped_probe), + ])? + .build()? + .with_dynamic_filter_expr(dynamic_filter) + .map_err(|e| { + DataFusionError::Internal(format!("failed to attach join dynamic filter: {e}")) + })?; + Ok((Arc::new(new_join), Some(wrapped_probe))) +} + +/// Result of [`attach_join_dynamic_filter`]: the (possibly rewritten) plan and the +/// [`DynamicFilterExec`] wrapper if one was installed. +pub type PlanWithDynamicFilter = (Arc, Option>); + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::JoinType; + use datafusion::common::NullEquality; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::logical_expr::Operator; + use datafusion::physical_expr::expressions::{col, BinaryExpr}; + use datafusion::physical_plan::collect; + use datafusion::prelude::SessionContext; + + fn int_batch(name: &str, values: Vec) -> (SchemaRef, RecordBatch) { + let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + ) + .unwrap(); + (schema, batch) + } + + fn memory_exec(schema: &SchemaRef, batches: Vec) -> Arc { + let config = MemorySourceConfig::try_new(&[batches], Arc::clone(schema), None).unwrap(); + Arc::new(DataSourceExec::new(Arc::new(config))) + } + + #[tokio::test] + async fn test_pass_through_then_filter_then_constant_false() { + let (schema, batch) = int_batch("a", (0..100).collect()); + let input = memory_exec(&schema, vec![batch]); + let predicate = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &schema).unwrap()], + lit(true), + )); + let exec = Arc::new(DynamicFilterExec::new(input, Arc::clone(&predicate))); + let task_ctx = SessionContext::new().task_ctx(); + + // Placeholder: everything passes through. + let batches = collect(Arc::clone(&exec) as _, Arc::clone(&task_ctx)) + .await + .unwrap(); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 100); + + // Populated: a >= 90 keeps 10 rows. + predicate + .update(Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::GtEq, + lit(90), + ))) + .unwrap(); + let batches = collect(Arc::clone(&exec) as _, Arc::clone(&task_ctx)) + .await + .unwrap(); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 10); + let pruned = exec + .metrics() + .unwrap() + .sum_by_name("dynamic_filter_rows_pruned") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert_eq!(pruned, 90); + + // Constant false (e.g. empty build side): everything is pruned. + predicate.update(lit(false)).unwrap(); + let batches = collect(Arc::clone(&exec) as _, task_ctx).await.unwrap(); + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 0); + } + + fn test_hash_join_with_mode( + join_type: JoinType, + mode: PartitionMode, + ) -> Arc { + let (build_schema, build_batch) = int_batch("a", vec![10, 20]); + let (probe_schema, probe_batch) = int_batch("b", (0..100).collect()); + let join_on = vec![( + col("a", &build_schema).unwrap(), + col("b", &probe_schema).unwrap(), + )]; + Arc::new( + HashJoinExec::try_new( + memory_exec(&build_schema, vec![build_batch]), + memory_exec(&probe_schema, vec![probe_batch]), + join_on, + None, + &join_type, + None, + mode, + NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) + } + + fn test_hash_join(join_type: JoinType) -> Arc { + test_hash_join_with_mode(join_type, PartitionMode::CollectLeft) + } + + #[tokio::test] + async fn test_attach_wraps_probe_side_and_preserves_results() { + let plain = test_hash_join(JoinType::Inner); + let (attached, installed) = + attach_join_dynamic_filter(Arc::clone(&plain) as _, &ConfigOptions::default()).unwrap(); + assert!(installed.is_some()); + + let join = attached + .downcast_ref::() + .expect("expected HashJoinExec"); + assert!(join.dynamic_filter_expr().is_some()); + let wrapper = join + .right() + .downcast_ref::() + .expect("probe side should be wrapped in CometDynamicFilterExec"); + + let task_ctx = SessionContext::new().task_ctx(); + let expected = collect(plain as _, Arc::clone(&task_ctx)).await.unwrap(); + let actual = collect(Arc::clone(&attached), Arc::clone(&task_ctx)) + .await + .unwrap(); + let expected_rows: usize = expected.iter().map(|b| b.num_rows()).sum(); + let actual_rows: usize = actual.iter().map(|b| b.num_rows()).sum(); + assert_eq!(expected_rows, 2); + assert_eq!(actual_rows, 2); + + // The build phase populated the filter and probe rows were pruned before the join. + let pruned = wrapper + .metrics() + .unwrap() + .sum_by_name("dynamic_filter_rows_pruned") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!(pruned > 0, "expected probe rows to be pruned, got 0"); + } + + #[tokio::test] + async fn test_attach_skips_non_probe_preserved_join_types() { + for join_type in [JoinType::Right, JoinType::Full, JoinType::RightAnti] { + let plain = test_hash_join(join_type); + let (attached, installed) = + attach_join_dynamic_filter(Arc::clone(&plain) as _, &ConfigOptions::default()) + .unwrap(); + assert!(installed.is_none()); + let join = attached.downcast_ref::().unwrap(); + assert!( + join.dynamic_filter_expr().is_none(), + "join type {join_type:?} must not get a dynamic filter" + ); + assert!(join.right().downcast_ref::().is_none()); + } + } + + #[test] + fn test_attach_respects_upstream_config_gate() { + // preserve_file_partitions with Partitioned mode: mirrors DataFusion's + // allow_join_dynamic_filter_pushdown exclusion. + let plain = test_hash_join_with_mode(JoinType::Inner, PartitionMode::Partitioned); + let mut config = ConfigOptions::default(); + config.optimizer.preserve_file_partitions = 1; + let (_, installed) = attach_join_dynamic_filter(Arc::clone(&plain) as _, &config).unwrap(); + assert!(installed.is_none()); + + // Session flag off: no attachment. + let plain = test_hash_join(JoinType::Inner); + let mut config = ConfigOptions::default(); + config.optimizer.enable_join_dynamic_filter_pushdown = false; + let (_, installed) = attach_join_dynamic_filter(plain as _, &config).unwrap(); + assert!(installed.is_none()); + } +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index d68252bd9b..181d9689f7 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -26,6 +26,8 @@ pub use scan::*; mod aligned_stream_reader; mod copy; +mod dynamic_filter; +pub use dynamic_filter::{attach_join_dynamic_filter, DynamicFilterExec}; mod expand; pub use expand::ExpandExec; mod iceberg_scan; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 25162332fd..401631cb5f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,6 +21,7 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use crate::execution::operators::attach_join_dynamic_filter; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::AlignedArrowStreamReader; use crate::execution::operators::IcebergScanExec; @@ -1981,20 +1982,35 @@ impl PhysicalPlanner { // (which matches DataFusion's default), and swap_inputs would turn LeftAnti // into RightAnti, which DataFusion rejects with null_aware=true. if join.build_side == BuildSide::BuildLeft as i32 || join.null_aware_anti_join { + // Null-aware anti joins are excluded from dynamic filtering: NOT IN + // semantics depend on observing build-side nulls, so probe rows must + // not be pre-filtered. + let mut additional_native_plans: Vec> = vec![]; + let hash_join = self.apply_join_dynamic_filter( + hash_join, + join.dynamic_filter_enabled && !join.null_aware_anti_join, + &mut additional_native_plans, + )?; Ok(( scans, shuffle_scans, - Arc::new(SparkPlan::new( + Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, hash_join, vec![join_params.left, join_params.right], + additional_native_plans, )), )) } else { let swapped_hash_join = hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?; + let mut additional_native_plans: Vec> = vec![]; + let swapped_hash_join = self.apply_join_dynamic_filter( + swapped_hash_join, + join.dynamic_filter_enabled, + &mut additional_native_plans, + )?; - let mut additional_native_plans = vec![]; if swapped_hash_join.is::() { // a projection was added to the hash join additional_native_plans.push(Arc::clone(swapped_hash_join.children()[0])); @@ -2159,6 +2175,25 @@ impl PhysicalPlanner { } } + /// When `enabled`, attaches a runtime dynamic filter to the hash join in `plan` + /// (which may sit under a `ProjectionExec` from `swap_inputs`) and registers the + /// probe-side wrapper in `additional_native_plans` so its metrics + /// (dynamic_filter_rows_pruned) surface on the join's SparkPlan node. + fn apply_join_dynamic_filter( + &self, + plan: Arc, + enabled: bool, + additional_native_plans: &mut Vec>, + ) -> Result, ExecutionError> { + if !enabled { + return Ok(plan); + } + let session_config = self.session_ctx.copied_config(); + let (attached, wrapper) = attach_join_dynamic_filter(plan, session_config.options())?; + additional_native_plans.extend(wrapper); + Ok(attached) + } + #[allow(clippy::too_many_arguments)] fn parse_join_parameters( &self, @@ -4602,6 +4637,7 @@ mod tests { condition: None, build_side: 0, null_aware_anti_join: false, + dynamic_filter_enabled: false, })), }; diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 2fcfe7f25b..499aa7fc82 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -391,6 +391,10 @@ message HashJoin { // True for BroadcastHashJoinExec null-aware anti-joins (NOT IN subquery semantics). // When true, any null in the build side suppresses all left rows. bool null_aware_anti_join = 6; + // When true, attach a runtime dynamic filter to the join: after the build side + // completes, build-key bounds/membership predicates are applied to probe-side + // batches before the hash probe. See spark.comet.exec.join.dynamicFilter.enabled. + bool dynamic_filter_enabled = 7; } message SortMergeJoin { diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 8e47151358..89c2c0d666 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -368,6 +368,20 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.join.dynamicFilter.enabled") + .category(CATEGORY_EXEC) + .doc( + "Experimental: when enabled, Comet native hash joins apply a runtime dynamic " + + "filter to the probe side. After the join's build side completes, min/max bounds " + + "and membership predicates derived from the build keys are used to drop probe-side " + + "rows before the hash probe, which can significantly speed up selective joins such " + + "as star-schema queries. Applies to inner, left outer, left semi, and left anti " + + "joins. A selectivity guard disables the filter on streams where it prunes " + + "few rows.") + .booleanConf + .createWithDefault(false) + val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.exec.scalaUDF.codegen.enabled") .category(CATEGORY_EXEC) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index e4d6b53770..43b28fee6d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -2070,6 +2070,7 @@ trait CometHashJoin { .setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft else OperatorOuterClass.BuildSide.BuildRight) .setNullAwareAntiJoin(isNullAwareAntiJoin) + .setDynamicFilterEnabled(CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.get(join.conf)) condition.foreach(joinBuilder.setCondition) Some(builder.setHashJoin(joinBuilder).build()) } else { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index f01d5e2109..fd76b97159 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -239,6 +239,79 @@ class CometJoinSuite extends CometTestBase { } } + test("HashJoin with dynamic filter enabled") { + // Runs the same join shapes as the plain hash join tests with the runtime dynamic + // filter enabled, covering eligible types (inner, left/semi/anti), an ineligible + // type (full outer, which must not be filtered), a selective build side, an empty + // build side, and NULL join keys. Results must be identical to Spark's. + withSQLConf( + CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.key -> "true", + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "probe") { + withParquetTable((0 until 10).map(i => (i * 100, i + 2)), "build") { + // Selective broadcast inner join: most probe rows should be pruned pre-probe. + val df1 = + sql("SELECT /*+ BROADCAST(build) */ * FROM probe JOIN build ON probe._1 = build._1") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Shuffled hash join, inner. + val df2 = sql( + "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe JOIN build ON probe._1 = build._1") + checkSparkAnswerAndOperator(df2) + + // Left outer join (probe side preserved under ON clause: eligible). Spark + // cannot broadcast the preserved side of a LEFT JOIN, so this plans as a + // shuffled hash join with BuildLeft. + val df3 = sql( + "SELECT /*+ BROADCAST(build) */ * FROM build LEFT JOIN probe ON build._1 = probe._1") + checkSparkAnswerAndOperator(df3) + + // Full outer join (ineligible: filter must not be attached). + val df4 = sql( + "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe FULL JOIN build ON probe._1 = build._1") + checkSparkAnswerAndOperator(df4) + + // Left semi and left anti joins. + val df5 = sql( + "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe LEFT SEMI JOIN build " + + "ON probe._1 = build._1") + checkSparkAnswerAndOperator(df5) + val df6 = sql( + "SELECT /*+ SHUFFLE_HASH(build) */ * FROM probe LEFT ANTI JOIN build " + + "ON probe._1 = build._1") + checkSparkAnswerAndOperator(df6) + + // Empty build side: the dynamic filter may become constant-false. + val df7 = sql( + "SELECT /*+ BROADCAST(build) */ * FROM probe JOIN build " + + "ON probe._1 = build._1 WHERE build._2 < 0") + checkSparkAnswer(df7) + } + } + + // NULL join keys: a NULL key never matches, and the dynamic filter must not + // change that (NULL evaluates to not-kept, matching join semantics). + withParquetTable( + (0 until 100).map(i => (if (i % 3 == 0) None else Some(i), i.toString)), + "probe_nulls") { + withParquetTable((0 until 10).map(i => (Some(i * 9), i.toString)), "build_nulls") { + val df = sql( + "SELECT /*+ BROADCAST(build_nulls) */ * FROM probe_nulls JOIN build_nulls " + + "ON probe_nulls._1 = build_nulls._1") + checkSparkAnswerAndOperator( + df, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + test("Broadcast HashJoin with join filter") { withSQLConf( CometConf.COMET_BATCH_SIZE.key -> "100", diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinDynamicFilterBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinDynamicFilterBenchmark.scala new file mode 100644 index 0000000000..320eba915f --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinDynamicFilterBenchmark.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, CometSparkSessionExtensions} + +/** + * Benchmark to measure the effect of runtime dynamic filter pushdown + * (spark.comet.exec.join.dynamicFilter.enabled) on broadcast hash joins with three build-side + * shapes: sparse selective keys (membership predicate does the pruning), clustered selective keys + * (min/max bounds do the pruning), and a non-selective build side covering the full probe domain + * (the selectivity guard should disable the filter, bounding overhead). To run: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometJoinDynamicFilterBenchmark + * }}} + */ +object CometJoinDynamicFilterBenchmark extends CometBenchmarkBase { + + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometJoinDynamicFilterBenchmark") + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + sparkSession.conf.set("spark.sql.shuffle.partitions", "2") + sparkSession + } + + // Force BroadcastHashJoin: builds below the 10MB threshold. + private val cometConfigs: Map[String, String] = Map( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") + + private def benchmarkQuery(name: String, cardinality: Long, query: String): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase("Spark") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (dynamic filter off)") { _ => + val configs = cometConfigs + (CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.key -> + "false") + withSQLConf(configs.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (dynamic filter on)") { _ => + val configs = cometConfigs + (CometConf.COMET_EXEC_JOIN_DYNAMIC_FILTER_ENABLED.key -> + "true") + withSQLConf(configs.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.run() + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val probeRows = 16 * 1024 * 1024 + val sparseBuildRows = 1024 + val clusteredBuildRows = 64 * 1024 + val fullDomain = 64 * 1024 + + withTempPath { dir => + withTempTable("probe", "probe_mod", "build_sparse", "build_clustered", "build_full") { + // Probe with unique keys across [0, probeRows). + spark + .range(probeRows) + .selectExpr("id AS k", "id % 1000 AS v") + .write + .parquet(s"${dir.getAbsolutePath}/probe") + // Probe whose keys all fall in [0, fullDomain) so build_full matches every row. + spark + .range(probeRows) + .selectExpr(s"id % $fullDomain AS k", "id % 1000 AS v") + .write + .parquet(s"${dir.getAbsolutePath}/probe_mod") + // Sparse selective build: keys spread across the whole probe domain, so min/max + // bounds prune nothing and the membership predicate does the work (~0.006% + // of probe rows match). + spark + .range(sparseBuildRows) + .selectExpr(s"id * ${probeRows / sparseBuildRows} AS k", "id AS w") + .write + .parquet(s"${dir.getAbsolutePath}/build_sparse") + // Clustered selective build: contiguous keys in the middle of the probe domain, + // so the min/max bounds prune ~99.6% of probe rows cheaply. + spark + .range(clusteredBuildRows) + .selectExpr(s"id + ${probeRows / 2} AS k", "id AS w") + .write + .parquet(s"${dir.getAbsolutePath}/build_clustered") + // Non-selective build: covers every probe_mod key, so the filter keeps 100% of + // rows and the selectivity guard should disable evaluation. + spark + .range(fullDomain) + .selectExpr("id AS k", "id AS w") + .write + .parquet(s"${dir.getAbsolutePath}/build_full") + + Seq("probe", "probe_mod", "build_sparse", "build_clustered", "build_full").foreach { t => + spark.read.parquet(s"${dir.getAbsolutePath}/$t").createOrReplaceTempView(t) + } + + runBenchmark("BroadcastHashJoin dynamic filter - sparse selective build") { + benchmarkQuery( + "sparse selective build (membership pruning)", + probeRows, + "SELECT /*+ BROADCAST(b) */ count(*), sum(p.v) FROM probe p " + + "JOIN build_sparse b ON p.k = b.k") + } + + runBenchmark("BroadcastHashJoin dynamic filter - clustered selective build") { + benchmarkQuery( + "clustered selective build (bounds pruning)", + probeRows, + "SELECT /*+ BROADCAST(b) */ count(*), sum(p.v) FROM probe p " + + "JOIN build_clustered b ON p.k = b.k") + } + + runBenchmark("BroadcastHashJoin dynamic filter - non-selective build") { + benchmarkQuery( + "non-selective build (guard disables filter)", + probeRows, + "SELECT /*+ BROADCAST(b) */ count(*), sum(p.v) FROM probe_mod p " + + "JOIN build_full b ON p.k = b.k") + } + } + } + } +}