diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 85a2a0fa6a..678ce4efcd 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -348,6 +348,7 @@ jobs: org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometScanSchemeFallbackSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index abde1554f6..880ebf031d 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -164,6 +164,7 @@ jobs: org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometScanSchemeFallbackSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 8e47151358..926fc6bc95 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -450,6 +450,32 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_EXEC_TRANSITION_REVERT_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, Comet reverts a query stage to Spark row-based execution if the number " + + "of columnar-to-row (C2R) transitions in the stage exceeds the configured threshold. 
" + + "This avoids the overhead of repeated format conversions in stages where many " + + "operators fall back to row-based execution.") + .booleanConf + .createWithDefault(false) + + val COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.maxTransitions") + .category(CATEGORY_EXEC) + .doc( + "The maximum number of columnar-to-row (C2R) transitions allowed in a single query " + + "stage before Comet reverts the entire stage to Spark row-based execution. When " + + "columnar shuffle is enabled, each such C2R typically implies a corresponding " + + "row-to-columnar conversion to feed back into the columnar shuffle, so each counted " + + "C2R is a useful proxy for the conversion overhead in the stage. Set to 0 to revert " + + "any stage with transitions. " + + "Only effective when spark.comet.exec.transitionRevert.enabled is true.") + .intConf + .checkValue(_ >= 0, "Must be >= 0.") + .createWithDefault(2) + val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 22c3c9c93e..f117372b3c 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf._ -import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions} +import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions, 
RevertNativeForTransitionHeavyStages} import org.apache.comet.shims.ShimCometSparkSessionExtensions /** @@ -51,7 +51,8 @@ import org.apache.comet.shims.ShimCometSparkSessionExtensions * - CometExecRule.convertSubqueryBroadcasts converts SubqueryBroadcastExec to * CometSubqueryBroadcastExec for exchange reuse with Comet broadcasts * b. insertTransitions: ColumnarToRow/RowToColumnar added - * c. postColumnarTransitions: EliminateRedundantTransitions + * c. postColumnarTransitions: RevertNativeForTransitionHeavyStages, + * EliminateRedundantTransitions * 5. ReuseExchangeAndSubquery -- Spark deduplicates subqueries (sees Comet nodes) * }}} * @@ -74,7 +75,8 @@ import org.apache.comet.shims.ShimCometSparkSessionExtensions * 2. postStageCreationRules -> ApplyColumnarRulesAndInsertTransitions: * a. preColumnarTransitions: CometScanRule, CometExecRule (no-ops, already converted) * b. insertTransitions - * c. postColumnarTransitions: EliminateRedundantTransitions + * c. postColumnarTransitions: RevertNativeForTransitionHeavyStages, + * EliminateRedundantTransitions * }}} * * On Spark 3.4, injectQueryStageOptimizerRule is unavailable. 
CometExecRule does not wrap SABs, @@ -106,8 +108,11 @@ class CometSparkSessionExtensions case class CometExecColumnar(session: SparkSession) extends ColumnarRule { override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session) - override def postColumnarTransitions: Rule[SparkPlan] = - EliminateRedundantTransitions(session) + override def postColumnarTransitions: Rule[SparkPlan] = { + val rules = + Seq(RevertNativeForTransitionHeavyStages(session), EliminateRedundantTransitions(session)) + plan => rules.foldLeft(plan) { case (p, rule) => rule(p) } + } } } diff --git a/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala b/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala new file mode 100644 index 0000000000..1e0cfc79e0 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.comet.rules
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometExec, CometNativeColumnarToRowExec, CometSparkToColumnarExec}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
+
+import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.withFallbackReason
+
+/**
+ * Reverts a query stage to Spark row-based execution when it has too many columnar-to-row (C2R)
+ * transitions. Each C2R indicates Comet could not keep execution columnar and had to fall back.
+ * With columnar shuffle enabled, each C2R implies a corresponding R2C round-trip.
+ */
+case class RevertNativeForTransitionHeavyStages(session: SparkSession)
+    extends Rule[SparkPlan]
+    with Logging {
+
+  private def enabled = CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.get()
+  private def maxTransitions = CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.get()
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!enabled) return plan
+
+    if (session.sessionState.conf.adaptiveExecutionEnabled) {
+      applyForAQE(plan)
+    } else {
+      applyForNonAQE(plan)
+    }
+  }
+
+  private def applyForAQE(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case _: BroadcastExchangeLike => plan
+      case exchange: ShuffleExchangeLike =>
+        revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
+          .map(reverted => exchange.withNewChildren(Seq(reverted)))
+          .getOrElse(plan)
+      case _ =>
+        // Result stage: its output is collected as rows, so no consumer requires columnar input
+        // and the reverted stage needs no trailing R2C.
+        revertStageIfNeeded(plan, outputColumnar = false).getOrElse(plan)
+    }
+  }
+
+  private def applyForNonAQE(plan: SparkPlan): SparkPlan = {
+    val withRevertedStages = plan.transformUp { case exchange: ShuffleExchangeLike =>
+      revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
+        .map(reverted => exchange.withNewChildren(Seq(reverted)))
+        .getOrElse(exchange)
+    }
+    revertStageIfNeeded(withRevertedStages, outputColumnar = false)
+      .getOrElse(withRevertedStages)
+  }
+
+  /**
+   * Reverts the stage if its C2R count exceeds the configured threshold.
+   *
+   * The reverted plan is bridged to the format its consumer expects: it is wrapped in R2C when
+   * the consumer needs columnar input but the reverted root is row-based, and in C2R when the
+   * stage root produced rows before the revert but the reverted root is still columnar (which
+   * happens when the stripped C2R sat on a columnar node that is not rewritten to a row node).
+   *
+   * @param stagePlan
+   *   root of the stage to inspect
+   * @param outputColumnar
+   *   whether the consumer of this stage requires columnar input
+   * @return
+   *   the reverted plan, or None when the stage is within the threshold
+   */
+  private def revertStageIfNeeded(
+      stagePlan: SparkPlan,
+      outputColumnar: Boolean): Option[SparkPlan] = {
+    // Read the threshold once so the comparison and the reported reason always agree.
+    val threshold = maxTransitions
+    val transitionCount = countTransitions(stagePlan)
+    if (transitionCount <= threshold) {
+      None
+    } else {
+      val reason =
+        s"Stage reverted: $transitionCount C2R transitions exceed threshold $threshold"
+      val reverted = revertToSpark(stagePlan)
+      val tagged = withFallbackReason(reverted, reason)
+      val bridged =
+        if (outputColumnar && !reverted.supportsColumnar) {
+          RowToColumnarExec(tagged)
+        } else if (!outputColumnar && !stagePlan.supportsColumnar && reverted.supportsColumnar) {
+          // The revert turned a row-producing root into a columnar one; restore the row output.
+          ColumnarToRowExec(tagged)
+        } else {
+          tagged
+        }
+      Some(bridged)
+    }
+  }
+
+  /**
+   * A node that marks the boundary between this stage and an adjacent one.
+   */
+  private def isStageBoundary(plan: SparkPlan): Boolean = plan match {
+    case _: QueryStageExec | _: ShuffleExchangeLike | _: BroadcastExchangeLike => true
+    case _ => false
+  }
+
+  /**
+   * Like `transformDown`, never descends stage-boundary children.
+   *
+   * Traversal also stops when the rule itself yields a stage boundary (e.g. a stripped C2R
+   * that sat directly on top of an exchange), so the adjacent stage is left untouched.
+   */
+  private def transformStageDown(plan: SparkPlan)(
+      rule: PartialFunction[SparkPlan, SparkPlan]): SparkPlan = {
+    val transformed = rule.applyOrElse(plan, identity[SparkPlan])
+    if (isStageBoundary(transformed)) {
+      // The boundary node and everything below it belong to another stage.
+      transformed
+    } else {
+      val newChildren = transformed.children.map { child =>
+        if (isStageBoundary(child)) child else transformStageDown(child)(rule)
+      }
+      if (newChildren == transformed.children) transformed
+      else transformed.withNewChildren(newChildren)
+    }
+  }
+
+  /**
+   * Like `transformUp`, never descends stage-boundary children. A stage-boundary root is
+   * returned untouched, since it and everything below it belong to another stage.
+   */
+  private def transformStageUp(plan: SparkPlan)(
+      rule: PartialFunction[SparkPlan, SparkPlan]): SparkPlan = {
+    if (isStageBoundary(plan)) {
+      plan
+    } else {
+      val newChildren = plan.children.map { child =>
+        if (isStageBoundary(child)) child else transformStageUp(child)(rule)
+      }
+      val withNewChildren =
+        if (newChildren == plan.children) plan else plan.withNewChildren(newChildren)
+      rule.applyOrElse(withNewChildren, identity[SparkPlan])
+    }
+  }
+
+  /** Counts C2R transitions within this stage, stopping at stage boundaries. */
+  private[rules] def countTransitions(plan: SparkPlan): Int = {
+    var count = 0
+    def visit(node: SparkPlan): Unit = node match {
+      case _ if isStageBoundary(node) => ()
+      case _: ColumnarToRowTransition =>
+        count += 1
+        node.children.foreach(visit)
+      case _ =>
+        node.children.foreach(visit)
+    }
+    visit(plan)
+    count
+  }
+
+  /**
+   * Strips the transitions inside this stage, replaces each CometExec with its original Spark
+   * plan and re-inserts a C2R wherever a row-based node is left with a columnar child.
+   */
+  private[rules] def revertToSpark(plan: SparkPlan): SparkPlan = {
+    val stripped = transformStageDown(plan) {
+      case CometNativeColumnarToRowExec(child) => child
+      case CometColumnarToRowExec(child) => child
+      case ColumnarToRowExec(child) => child
+      case sparkToColumnar: CometSparkToColumnarExec => sparkToColumnar.child
+      case RowToColumnarExec(child) => child
+    }
+    val reverted = transformStageUp(stripped) { case cometExec: CometExec =>
+      if (cometExec.originalPlan.children.size == cometExec.children.size) {
+        cometExec.originalPlan.withNewChildren(cometExec.children)
+      } else {
+        logWarning(
+          "Comet plan and original have different child count for " +
+            s"${cometExec.getClass.getSimpleName}, using originalPlan as-is.")
+        cometExec.originalPlan
+      }
+    }
+    insertTransitions(reverted)
+  }
+
+  private def insertTransitions(plan: SparkPlan): SparkPlan = {
+    // transformStageUp never descends into stage-boundary nodes (QueryStageExec, exchanges), so
+    // this only needs to bridge row nodes that still have a columnar child within the stage.
+    transformStageUp(plan) {
+      case node if !node.supportsColumnar =>
+        val newChildren = node.children.map { child =>
+          if (child.supportsColumnar) ColumnarToRowExec(child) else child
+        }
+        if (newChildren != node.children) node.withNewChildren(newChildren) else node
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala b/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala new file mode 100644 index 0000000000..fb66ee8193 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.rules + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet._ +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution._ + +import org.apache.comet.CometConf + +class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase { + + private def createSparkPlan(sql: String): SparkPlan = { + var plan: SparkPlan = null + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + plan = spark.sql(sql).queryExecution.executedPlan + } + stripAQEPlan(plan) + } + + private def applyCometExecRule(plan: SparkPlan): SparkPlan = { + CometExecRule(spark).apply(plan) + } + + private def applyFullColumnarPipeline(plan: SparkPlan): SparkPlan = { + val cometPlan = CometScanRule(spark).apply(plan) + val execPlan = CometExecRule(spark).apply(cometPlan) + val withTransitions = ApplyColumnarRulesAndInsertTransitions(Seq.empty, false).apply(execPlan) + EliminateRedundantTransitions(spark).apply(withTransitions) + } + + private def countCometExecs(plan: SparkPlan): Int = { + plan.collect { case _: CometExec => true }.size + } + + private def countC2RNodes(plan: SparkPlan): Int = { + plan.collect { case _: ColumnarToRowTransition => true }.size + } + + /** + * Returns every node that produces a columnar output but consumes a row-based child without a + * RowToColumnar transition. Such a node is an invalid columnar/row boundary: a columnar parent + * (e.g. a native CometShuffleExchangeExec) requires columnar input. RowToColumnarExec and + * CometSparkToColumnarExec are the legitimate row->columnar bridges and are excluded. 
+ */ + private def invalidColumnarBoundaries(plan: SparkPlan): Seq[SparkPlan] = { + plan.collect { + case n + if n.supportsColumnar && !n.isInstanceOf[RowToColumnarTransition] && + n.children.exists(c => !c.supportsColumnar) => + n + } + } + + test("rule is a no-op when disabled") { + withSQLConf(CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false") { + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = createSparkPlan("SELECT id, id * 2 FROM test_data WHERE id > 5") + val cometPlan = applyCometExecRule(sparkPlan) + assert(countCometExecs(cometPlan) > 0, "Plan should have CometExec nodes") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Rule should be a no-op when disabled") + } + } + } + + test("rule does not revert plan below threshold") { + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "true", + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "10", + "spark.comet.exec.project.enabled" -> "false") { + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyFullColumnarPipeline(sparkPlan) + + val rule = RevertNativeForTransitionHeavyStages(spark) + val transitions = rule.countTransitions(cometPlan) + assert(transitions > 0, s"Plan should have transitions, got $transitions") + assert(transitions <= 10, "Transitions should be below threshold") + + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Plan should be unchanged when below threshold") + } + } + } + + test("revertToSpark preserves plan structure") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + 
createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyCometExecRule(sparkPlan) + val rule = RevertNativeForTransitionHeavyStages(spark) + val reverted = rule.revertToSpark(cometPlan) + + // Reverted plan should have same output schema + assert( + reverted.output.map(_.name) == cometPlan.output.map(_.name), + "Output schema should be preserved after revert") + } + } + } + + test("revertToSpark removes all Comet operators from a plan with transitions") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyFullColumnarPipeline(sparkPlan) + assert(countCometExecs(cometPlan) > 0, "Should have CometExec nodes before revert") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.revertToSpark(cometPlan) + assert( + countCometExecs(result) == 0, + s"All CometExec should be reverted. 
Plan:\n${result.treeString}") + } + } + } + + test("non-AQE path applies rule per-stage via transformUp") { + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "true", + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "10", + "spark.sql.adaptive.enabled" -> "false") { + + withTempView("test_data") { + spark + .range(10) + .selectExpr("id", "id % 3 as grp") + .createOrReplaceTempView("test_data") + val sparkPlan = createSparkPlan("SELECT grp, count(*) FROM test_data GROUP BY grp") + val cometPlan = applyCometExecRule(sparkPlan) + + // With high threshold, the non-AQE path should not revert anything + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Non-AQE path should not revert when below threshold") + } + } + } + + test("revert fires with unsupported UDF producing transitions") { + withParquetTable((0 until 100).map(i => (i, i % 10, s"val_$i")), "tbl") { + spark.udf.register("identity_udf", (x: Int) => x) + val query = "SELECT _2, identity_udf(_1), count(*) FROM tbl GROUP BY _2, identity_udf(_1)" + + // Without revert, plan should have transitions due to the UDF + withSQLConf(CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false") { + val df = sql(query) + df.collect() + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(countC2RNodes(plan) > 0, "UDF should cause C2R transitions") + } + + // With threshold 0, stage should be reverted + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "true", + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "0") { + val (_, cometPlan) = checkSparkAnswer(query) + val executedPlan = stripAQEPlan(cometPlan) + assert( + countCometExecs(executedPlan) == 0, + s"Revert should have removed all CometExec nodes:\n${executedPlan.treeString}") + } + } + } + + test("revert fires and produces correct results when transitions exceed threshold") { + withParquetTable((0 until 100).map(i => 
(i, i % 10, s"val_$i")), "tbl") { + val query = "SELECT _2, count(*), sum(_1) FROM tbl GROUP BY _2" + + // Without revert, plan should have CometExec nodes with transitions + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false", + "spark.comet.exec.project.enabled" -> "false") { + val df = sql(query) + df.collect() + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(countCometExecs(plan) > 0, "Plan without revert should have CometExec nodes") + assert(countC2RNodes(plan) > 0, "Plan without revert should have C2R transitions") + } + + // With revert enabled at threshold 0, all CometExec should be removed + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "true", + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "0", + "spark.comet.exec.project.enabled" -> "false") { + val (_, cometPlan) = checkSparkAnswer(query) + val executedPlan = stripAQEPlan(cometPlan) + assert( + countCometExecs(executedPlan) == 0, + s"Revert should have removed all CometExec nodes:\n${executedPlan.treeString}") + } + } + } + + test("revertToSpark must not revert native operators across a shuffle stage boundary") { + withSQLConf("spark.sql.adaptive.enabled" -> "false") { + withParquetTable((0 until 100).map(i => (i, i % 10)), "tbl") { + // A GROUP BY produces partial-agg -> native shuffle -> final-agg, i.e. two stages. 
+ val df = sql("SELECT _2, count(*) FROM tbl GROUP BY _2") + df.collect() + val cometPlan = stripAQEPlan(df.queryExecution.executedPlan) + + val shuffles = cometPlan.collect { case s: CometShuffleExchangeExec => s } + assume(shuffles.nonEmpty, "test requires a native CometShuffleExchangeExec") + assert( + shuffles.map(s => countCometExecs(s.child)).sum > 0, + "expected native CometExec operators below the shuffle") + assert( + invalidColumnarBoundaries(cometPlan).isEmpty, + s"precondition: original plan should be valid:\n${cometPlan.treeString}") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val reverted = rule.revertToSpark(cometPlan) + + val invalid = invalidColumnarBoundaries(reverted) + assert( + invalid.isEmpty, + "revertToSpark produced invalid columnar/row boundaries " + + s"(${invalid.map(_.nodeName).mkString(", ")}):\n${reverted.treeString}") + } + } + } + + test("non-AQE apply must not produce an invalid plan when the result stage reverts") { + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "true", + // Threshold 0 forces the result stage (above the topmost shuffle) to revert. 
+ CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "0", + "spark.sql.adaptive.enabled" -> "false") { + withParquetTable((0 until 100).map(i => (i, i % 10)), "tbl") { + var cometPlan: SparkPlan = null + withSQLConf(CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false") { + val df = sql("SELECT _2, count(*) FROM tbl GROUP BY _2") + df.collect() + cometPlan = stripAQEPlan(df.queryExecution.executedPlan) + } + assume( + cometPlan.collect { case s: CometShuffleExchangeExec => s }.nonEmpty, + "test requires a native CometShuffleExchangeExec") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.apply(cometPlan) + + val invalid = invalidColumnarBoundaries(result) + assert( + invalid.isEmpty, + "rule.apply produced invalid columnar/row boundaries " + + s"(${invalid.map(_.nodeName).mkString(", ")}):\n${result.treeString}") + } + } + } +}