Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50589][SQL] Avoid extra expression duplication when push filter #49202

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ object With {
With(replaced(commonExprRefs), commonExprDefs)
}

/**
* Helper function to create a [[With]] statement when push down filter.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is very specific to filter pushdown, shall we put this method in the filter pushdown rule?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, wait for CI to finish running.

* @param expr original expression
* @param replaceMap Replaced attributes and common expressions
*/
def apply(expr: Expression, replaceMap: Map[Attribute, Expression]): With = {
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
val commonExprDefsMap = replaceMap.map(m => m._1 -> CommonExpressionDef(m._2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we create a map if we never look up from it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Convenient to generate commonExprRefsMap.

val commonExprRefsMap =
AttributeMap(commonExprDefsMap.map(m => m._1 -> new CommonExpressionRef(m._2)))
val replaced = expr.transform {
case a: Attribute if commonExprRefsMap.contains(a) =>
commonExprRefsMap.get(a).get
}
With(replaced, commonExprDefsMap.values.toSeq)
}

private[sql] def childContainsUnsupportedAggExpr(withExpr: With): Boolean = {
lazy val commonExprIds = withExpr.defs.map(_.id).toSet
withExpr.child.exists {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AND, FILTER, WITH_EXPRESSION}

/**
* Before rewrite with expression, merge with expression which has same common expression for
* avoid extra expression duplication.
*/
object MergeWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformUpWithSubqueriesAndPruning(_.containsPattern(FILTER)) {
case f @ Filter(cond, _) =>
val newCond = cond.transformUpWithPruning(_.containsAllPatterns(AND, WITH_EXPRESSION)) {
case And(left @ With(_, _), right @ With(_, _)) =>
mergeWith(left, right)
case And(left @ With(_, _), right) =>
With(And(left.child, right), left.defs)
case And(left, right @ With(_, _)) =>
With(And(left, right.child), right.defs)
}
f.copy(condition = newCond)
}
}

private def mergeWith(left: With, right: With): Expression = {
val newDefs = left.defs.toBuffer
val replaceMap = mutable.HashMap.empty[CommonExpressionId, CommonExpressionRef]
right.defs.foreach {rDef =>
val index = left.defs.indexWhere(lDef => rDef.child.fastEquals(lDef.child))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plan comparison can be expensive. Since it's a very special case for filter pushdown, shall we use the same CommonExpressionDef ids for different Withs in the split predicates?

if (index == -1) {
newDefs.append(rDef)
} else {
replaceMap.put(rDef.id, new CommonExpressionRef(left.defs(index)))
}
}
val newChild = if (replaceMap.nonEmpty) {
val newRightChild = right.child.transform {
case r: CommonExpressionRef if replaceMap.contains(r.id) =>
replaceMap(r.id)
}
And(left.child, newRightChild)
} else {
And(left.child, right.child)
}
With(newChild, newDefs.toSeq)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,23 @@ abstract class Optimizer(catalogManager: CatalogManager)
val operatorOptimizationBatch: Seq[Batch] = Seq(
Batch("Operator Optimization before Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
// With expression will destroy infer filters, so need rewrite it before infer filters.
Batch("Merge With expression", fixedPoint, MergeWithExpression),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Infer Filters", Once,
InferFiltersFromGenerate,
InferFiltersFromConstraints),
Batch("Operator Optimization after Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
Batch("Push extra predicate through join", fixedPoint,
PushExtraPredicateThroughJoin,
PushDownPredicates))
PushDownPredicates),
Batch("Merge With expression", fixedPoint, MergeWithExpression),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject))

val batches: Seq[Batch] = flattenBatches(Seq(
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
Expand Down Expand Up @@ -1811,7 +1820,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
case Filter(condition, project @ Project(fields, grandChild))
if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
val aliasMap = getAliasMap(project)
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
val replacedByWith = rewriteConditionByWith(condition, aliasMap)
project.copy(child = Filter(replaceAlias(replacedByWith, aliasMap), grandChild))

// We can push down deterministic predicate through Aggregate, including throwable predicate.
// If we can push down a filter through Aggregate, it means the filter only references the
Expand All @@ -1831,8 +1841,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
}

if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
val replaced = replaceAlias(pushDownPredicate, aliasMap)
val replacedByWith = rewriteConditionByWith(pushDown.reduce(And), aliasMap)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rewriteConditionByWith will split the predicate anyway, why do we combine them with And here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calculate the number of times each common expression is used through the entire condition.

val replaced = replaceAlias(replacedByWith, aliasMap)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't rewriteConditionByWith already replace the aliases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, rewriteConditionByWith only rewrite common attribute to common expression ref.

val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
Expand Down Expand Up @@ -1978,6 +1988,45 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
case _ => false
}
}

/**
* Use [[With]] to rewrite condition which contains attribute that are not cheap and be consumed
* multiple times.
*/
private def rewriteConditionByWith(
cond: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
val replaceWithMap = cond.collect { case a: Attribute => a }
.groupBy(identity)
.transform((_, v) => v.size)
.filter(m => aliasMap.contains(m._1) && m._2 > 1)
.map(m => m._1 -> trimAliases(aliasMap.getOrElse(m._1, m._1)))
.filter(m => !CollapseProject.isCheap(m._2))
splitConjunctivePredicates(cond)
.map(rewriteByWith(_, AttributeMap(replaceWithMap)))
.reduce(And)
} else cond
}

// With does not support inline subquery
private def canRewriteByWith(expr: Expression): Boolean = {
!expr.containsPattern(PLAN_EXPRESSION)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this check too strong? We only require the common expression to not contain subqueries.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both rewriting subqueries and pushing down predicates are in batch "operator optimization before inferring filters". Pushing down predicates may cause SubqueryExpression to contain common refs, and then rewriting subqueries cannot replace common refs.

}

private def rewriteByWith(
expr: Expression,
replaceMap: AttributeMap[Expression]): Expression = {
if (!canRewriteByWith(expr)) {
return expr
}
val exprAttrSet = expr.collect { case a: Attribute => a }.toSet
val newReplaceMap = replaceMap.filter(x => exprAttrSet.contains(x._1))
if (newReplaceMap.isEmpty) {
return expr
}
With(expr, newReplaceMap)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
val refsCount = child.collect { case r: CommonExpressionRef => r}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should only collect references to the current With, not the nested With.

Suggested change
val refsCount = child.collect { case r: CommonExpressionRef => r}
val refsCount = child.collect { case r: CommonExpressionRef if defs.exists(_.id == r.id) => r}

.groupBy(_.id)
.transform((_, v) => v.size)

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
}

if (CollapseProject.isCheap(child)) {
if (CollapseProject.isCheap(child) || refsCount.getOrElse(id, 0) < 2) {
refToExpr(id) = child
} else {
val childProjectionIndex = inputPlans.indexWhere(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand All @@ -45,7 +46,11 @@ class FilterPushdownSuite extends PlanTest {
CollapseProject) ::
Batch("Push extra predicate through join", FixedPoint(10),
PushExtraPredicateThroughJoin,
PushDownPredicates) :: Nil
PushDownPredicates) ::
Batch("Merge With expression", FixedPoint(10), MergeWithExpression) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) :: Nil
}

val attrA = $"a".int
Expand Down Expand Up @@ -1539,4 +1544,37 @@ class FilterPushdownSuite extends PlanTest {
.analyze
comparePlans(optimizedQueryWithoutStep, correctAnswer)
}

test("SPARK-50589: avoid extra expression duplication when push filter") {
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
// through project
val originalQuery1 = testRelation
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
val optimized1 = Optimize.execute(originalQuery1.analyze)
val correctAnswer1 = testRelation
.select($"a", $"b", $"c", $"a" + $"b" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
$"a" - $"b" > 0)
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.analyze
comparePlans(optimized1, correctAnswer1)

// through aggregate
val originalQuery2 = testRelation
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
val optimized2 = Optimize.execute(originalQuery2.analyze)
val correctAnswer2 = testRelation
.select($"a", $"b", $"c", $"a" + $"a" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
abs($"a") > 5)
.select($"a", $"b", $"c")
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.analyze
comparePlans(optimized2, correctAnswer2)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,26 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("InferAndPushDownFilters", FixedPoint(100),
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin,
PushPredicateThroughNonJoin) ::
Batch("Merge With expression", FixedPoint(10), MergeWithExpression) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) ::
Batch("InferFilters", FixedPoint(100),
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
BooleanSimplification,
PruneFilters) :: Nil
PruneFilters) ::
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin) ::
Batch("Merge With expression", FixedPoint(10), MergeWithExpression) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) :: Nil
}

val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
Expand Down Expand Up @@ -151,8 +163,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull($"a") && IsNotNull(Coalesce(Seq($"a", $"b"))) &&
$"a" === Coalesce(Seq($"a", $"b")))
.select($"a", $"b", $"c", Coalesce(Seq($"a", $"b")) as "_common_expr_0")
.where(IsNotNull($"a") && IsNotNull($"_common_expr_0") &&
$"a" === $"_common_expr_0")
.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2.where(IsNotNull($"a")), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
Expand All @@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
.select(($"_common_expr_0" +
($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col"))
.analyze
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ class SparkOptimizer(
EliminateLimits,
ConstantFolding),
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
Batch("Merge With expression", fixedPoint, MergeWithExpression),
Batch("Rewrite With expression", fixedPoint,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how many places do we need to put this batch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before Extract Python UDFs, infer filters and last of each optimizer.

RewriteWithExpression,
CollapseProject)))

override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
Seq(
Expand Down
Loading