[SPARK-50589][SQL] Avoid extra expression duplication when pushing filters #49202
@@ -112,6 +112,22 @@ object With {

```diff
     With(replaced(commonExprRefs), commonExprDefs)
   }

+  /**
+   * Helper function to create a [[With]] statement when pushing down filters.
+   * @param expr the original expression
+   * @param replaceMap replaced attributes mapped to their common expressions
+   */
+  def apply(expr: Expression, replaceMap: Map[Attribute, Expression]): With = {
+    val commonExprDefsMap = replaceMap.map(m => m._1 -> CommonExpressionDef(m._2))
+    val commonExprRefsMap =
+      AttributeMap(commonExprDefsMap.map(m => m._1 -> new CommonExpressionRef(m._2)))
+    val replaced = expr.transform {
+      case a: Attribute if commonExprRefsMap.contains(a) =>
+        commonExprRefsMap.get(a).get
+    }
+    With(replaced, commonExprDefsMap.values.toSeq)
+  }
+
   private[sql] def childContainsUnsupportedAggExpr(withExpr: With): Boolean = {
     lazy val commonExprIds = withExpr.defs.map(_.id).toSet
     withExpr.child.exists {
```

Reviewer (on `commonExprDefsMap`): why do we create a map if we never look up from it?

Author: It is convenient for generating `commonExprRefsMap`.
@@ -0,0 +1,69 @@ (new file adding the `MergeWithExpression` rule)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AND, FILTER, WITH_EXPRESSION}

/**
 * Before rewriting [[With]] expressions, merge [[With]] expressions that share the same
 * common expression, to avoid extra expression duplication.
 */
object MergeWithExpression extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(FILTER)) {
      case f @ Filter(cond, _) =>
        val newCond = cond.transformUpWithPruning(_.containsAllPatterns(AND, WITH_EXPRESSION)) {
          case And(left @ With(_, _), right @ With(_, _)) =>
            mergeWith(left, right)
          case And(left @ With(_, _), right) =>
            With(And(left.child, right), left.defs)
          case And(left, right @ With(_, _)) =>
            With(And(left, right.child), right.defs)
        }
        f.copy(condition = newCond)
    }
  }

  private def mergeWith(left: With, right: With): Expression = {
    val newDefs = left.defs.toBuffer
    val replaceMap = mutable.HashMap.empty[CommonExpressionId, CommonExpressionRef]
    right.defs.foreach { rDef =>
      val index = left.defs.indexWhere(lDef => rDef.child.fastEquals(lDef.child))
      if (index == -1) {
        newDefs.append(rDef)
      } else {
        replaceMap.put(rDef.id, new CommonExpressionRef(left.defs(index)))
      }
    }
    val newChild = if (replaceMap.nonEmpty) {
      val newRightChild = right.child.transform {
        case r: CommonExpressionRef if replaceMap.contains(r.id) =>
          replaceMap(r.id)
      }
      And(left.child, newRightChild)
    } else {
      And(left.child, right.child)
    }
    With(newChild, newDefs.toSeq)
  }
}
```

Reviewer (on the `fastEquals` lookup): plan comparison can be expensive. Since it's a very special case for filter pushdown, shall we use the same …
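A hedged sketch of the merge in isolation (it assumes a spark-catalyst classpath with this PR applied; every name below is illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.MergeWithExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}
import org.apache.spark.sql.types.IntegerType

val x = AttributeReference("x", IntegerType)()
val expensive = Multiply(x, x)

// Two With expressions whose definitions have fastEquals-identical children, as filter
// pushdown can produce when it splits a conjunctive condition.
val d1 = CommonExpressionDef(expensive)
val d2 = CommonExpressionDef(expensive)
val w1 = With(GreaterThan(new CommonExpressionRef(d1), Literal(1)), Seq(d1))
val w2 = With(LessThan(new CommonExpressionRef(d2), Literal(10)), Seq(d2))

// The rule folds the conjunction into one With with a single shared definition:
//   Filter(With((ref > 1) AND (ref < 10), [d1]), relation)
val merged = MergeWithExpression(Filter(And(w1, w2), LocalRelation(x)))
```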
@@ -161,14 +161,23 @@ abstract class Optimizer(catalogManager: CatalogManager)

```diff
   val operatorOptimizationBatch: Seq[Batch] = Seq(
     Batch("Operator Optimization before Inferring Filters", fixedPoint,
       operatorOptimizationRuleSet: _*),
+    // With expressions would break filter inference, so rewrite them before inferring filters.
+    Batch("Merge With expression", fixedPoint, MergeWithExpression),
+    Batch("Rewrite With expression", fixedPoint,
+      RewriteWithExpression,
+      CollapseProject),
     Batch("Infer Filters", Once,
       InferFiltersFromGenerate,
       InferFiltersFromConstraints),
     Batch("Operator Optimization after Inferring Filters", fixedPoint,
       operatorOptimizationRuleSet: _*),
     Batch("Push extra predicate through join", fixedPoint,
       PushExtraPredicateThroughJoin,
-      PushDownPredicates))
+      PushDownPredicates),
+    Batch("Merge With expression", fixedPoint, MergeWithExpression),
+    Batch("Rewrite With expression", fixedPoint,
+      RewriteWithExpression,
+      CollapseProject))

   val batches: Seq[Batch] = flattenBatches(Seq(
     Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
```
@@ -1811,7 +1820,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper

```diff
     case Filter(condition, project @ Project(fields, grandChild))
       if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
       val aliasMap = getAliasMap(project)
-      project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
+      val replacedByWith = rewriteConditionByWith(condition, aliasMap)
+      project.copy(child = Filter(replaceAlias(replacedByWith, aliasMap), grandChild))

     // We can push down deterministic predicate through Aggregate, including throwable predicate.
     // If we can push down a filter through Aggregate, it means the filter only references the
```
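To see the intended effect end to end, here is a hedged spark-shell sketch (the data and thresholds are made up, and the plan comments describe expected shapes rather than captured output):

```scala
// Run in spark-shell. `b` aliases a stand-in non-cheap expression, and the filter
// below consumes `b` twice, which is exactly the case this PR targets.
val df = spark.range(10).toDF("x")
  .selectExpr("x * x * x AS b")
  .where("b > 1 AND b < 100")

// Before this PR, pushdown inlines the alias into both predicates:
//   Filter (((x * x) * x) > 1 AND ((x * x) * x) < 100)
// With rewriteConditionByWith, the pushed condition is wrapped in a With first, so
// RewriteWithExpression can arrange for the expensive expression to be evaluated once.
df.explain(true)
```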
@@ -1831,8 +1841,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe | |
} | ||
|
||
if (pushDown.nonEmpty) { | ||
val pushDownPredicate = pushDown.reduce(And) | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
val replaced = replaceAlias(pushDownPredicate, aliasMap) | ||
val replacedByWith = rewriteConditionByWith(pushDown.reduce(And), aliasMap) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Calculate the number of times each common expression is used through the entire condition. |
||
val replaced = replaceAlias(replacedByWith, aliasMap) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doesn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, |
||
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) | ||
// If there is no more filter to stay up, just eliminate the filter. | ||
// Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". | ||
|
@@ -1978,6 +1988,45 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper

```diff
       case _ => false
     }
   }
+
+  /**
+   * Use [[With]] to rewrite a condition containing attributes that are not cheap and are
+   * consumed multiple times.
+   */
+  private def rewriteConditionByWith(
+      cond: Expression,
+      aliasMap: AttributeMap[Alias]): Expression = {
+    if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
+      val replaceWithMap = cond.collect { case a: Attribute => a }
+        .groupBy(identity)
+        .transform((_, v) => v.size)
+        .filter(m => aliasMap.contains(m._1) && m._2 > 1)
+        .map(m => m._1 -> trimAliases(aliasMap.getOrElse(m._1, m._1)))
+        .filter(m => !CollapseProject.isCheap(m._2))
+      splitConjunctivePredicates(cond)
+        .map(rewriteByWith(_, AttributeMap(replaceWithMap)))
+        .reduce(And)
+    } else cond
+  }
+
+  // With does not support inlining subqueries.
+  private def canRewriteByWith(expr: Expression): Boolean = {
+    !expr.containsPattern(PLAN_EXPRESSION)
+  }
+
+  private def rewriteByWith(
+      expr: Expression,
+      replaceMap: AttributeMap[Expression]): Expression = {
+    if (!canRewriteByWith(expr)) {
+      return expr
+    }
+    val exprAttrSet = expr.collect { case a: Attribute => a }.toSet
+    val newReplaceMap = replaceMap.filter(x => exprAttrSet.contains(x._1))
+    if (newReplaceMap.isEmpty) {
+      return expr
+    }
+    With(expr, newReplaceMap)
+  }
 }

 /**
```

Reviewer (on `canRewriteByWith`): is this check too strong? We only require the common expression to not contain subqueries.

Author: It seems not feasible, see https://github.com/zml1206/spark/actions/runs/12491712984/job/34858074548.

Author: Both rewriting subqueries and pushing down predicates are in the batch "Operator Optimization before Inferring Filters". Pushing down predicates may cause a SubqueryExpression to contain common refs, and then rewriting subqueries cannot replace those common refs.
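To make the selection logic concrete, here is a hedged, dependency-free mirror of the counting step in `rewriteConditionByWith` (plain Scala, with strings standing in for attributes and expressions; `isCheap` is a stand-in for `CollapseProject.isCheap`):

```scala
// Attributes collected from the condition; `b` appears twice, `c` once.
val attrsInCond = Seq("b", "b", "c")
val useCounts = attrsInCond.groupBy(identity).transform((_, v) => v.size)

val aliasMap = Map("b" -> "x * x", "c" -> "y")  // alias -> aliased expression
val isCheap = Set("y")                          // cheap expressions stay inline

// Keep only aliases used more than once whose definition is not cheap.
val replaceWithMap = useCounts
  .filter { case (a, n) => aliasMap.contains(a) && n > 1 }
  .map { case (a, _) => a -> aliasMap(a) }
  .filter { case (_, e) => !isCheap(e) }
// replaceWithMap == Map("b" -> "x * x"): only `b` is worth a common-expression def.
```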
@@ -96,14 +96,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

```diff
     val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith = true))
     val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
     val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
+    val refsCount = child.collect { case r: CommonExpressionRef => r }
+      .groupBy(_.id)
+      .transform((_, v) => v.size)

     defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
       if (id.canonicalized) {
         throw SparkException.internalError(
           "Cannot rewrite canonicalized Common expression definitions")
       }

-      if (CollapseProject.isCheap(child)) {
+      if (CollapseProject.isCheap(child) || refsCount.getOrElse(id, 0) < 2) {
         refToExpr(id) = child
       } else {
         val childProjectionIndex = inputPlans.indexWhere(
```

Reviewer (on `refsCount`, with a suggested change): we should only collect references to the current …
@@ -95,7 +95,11 @@ class SparkOptimizer(

```diff
         EliminateLimits,
         ConstantFolding),
       Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
-      Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
+      Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
+      Batch("Merge With expression", fixedPoint, MergeWithExpression),
+      Batch("Rewrite With expression", fixedPoint,
+        RewriteWithExpression,
+        CollapseProject)))

   override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
     Seq(
```

Reviewer: how many places do we need to put this batch?

Author: Before "Extract Python UDFs", before inferring filters, and last in each optimizer.
Reviewer: Since this is very specific to filter pushdown, shall we put this method in the filter pushdown rule?

Author: OK, will wait for CI to finish running.