Skip to content

Commit

Permalink
support agg pullup (#8923)
Browse files Browse the repository at this point in the history
Approved by: @aunjgr
  • Loading branch information
badboynt1 authored Apr 10, 2023
1 parent c335c8f commit cce87e6
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 8 deletions.
211 changes: 203 additions & 8 deletions pkg/sql/plan/agg_pushdown_pullup.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,9 @@ func replaceCol(expr *plan.Expr, oldRelPos, oldColPos, newRelPos, newColPos int3
}

case *plan.Expr_Col:
//for now, shouldAggPushDown make sure only one column in expr,and only one expr in exprlist, so new colpos is always 0
//if multi expr in agg list and group list supported in the future, this need to be fixed
if exprImpl.Col.RelPos == oldRelPos {
if exprImpl.Col.RelPos == oldRelPos && exprImpl.Col.ColPos == oldColPos {
exprImpl.Col.RelPos = newRelPos
exprImpl.Col.ColPos = 0
exprImpl.Col.ColPos = newColPos
}
}
}
Expand All @@ -95,7 +93,7 @@ func filterTag(expr *Expr, tag int32) *Expr {
return nil
}

func createNewAggNode(agg, join, leftChild *plan.Node, builder *QueryBuilder) {
func applyAggPushdown(agg, join, leftChild *plan.Node, builder *QueryBuilder) {
leftChildTag := leftChild.BindingTags[0]
newAggList := DeepCopyExprList(agg.AggList)
//newGroupBy := DeepCopyExprList(agg.GroupBy)
Expand All @@ -117,8 +115,11 @@ func createNewAggNode(agg, join, leftChild *plan.Node, builder *QueryBuilder) {
join.Children[0] = newNodeID

//replace relpos for exprs in join and agg node
replaceCol(join.OnList[0], leftChildTag, 0, newGroupTag, 0)
replaceCol(agg.AggList[0], leftChildTag, 0, newAggTag, 0)
colGroupBy, _ := filterTag(join.OnList[0], leftChildTag).Expr.(*plan.Expr_Col)
replaceCol(join.OnList[0], leftChildTag, colGroupBy.Col.ColPos, newGroupTag, 0)

colAgg, _ := filterTag(agg.AggList[0], leftChildTag).Expr.(*plan.Expr_Col)
replaceCol(agg.AggList[0], leftChildTag, colAgg.Col.ColPos, newAggTag, 0)
}

func (builder *QueryBuilder) aggPushDown(nodeID int32) int32 {
Expand All @@ -138,6 +139,7 @@ func (builder *QueryBuilder) aggPushDown(nodeID int32) int32 {
if join.NodeType != plan.Node_JOIN || join.JoinType != plan.Node_INNER {
return nodeID
}

//make sure left child is bigger and agg pushdown to left child
builder.applySwapRuleByStats(join.NodeId, false)

Expand All @@ -148,6 +150,199 @@ func (builder *QueryBuilder) aggPushDown(nodeID int32) int32 {
return nodeID
}

createNewAggNode(node, join, leftChild, builder)
applyAggPushdown(node, join, leftChild, builder)
return nodeID
}

// getJoinCondCol extracts the two column operands of a join condition and
// returns them ordered as (column whose RelPos is leftTag, column whose
// RelPos is rightTag). The operands are swapped when the condition is
// written the other way round.
//
// It returns (nil, nil) when:
//   - cond is nil or is not a function expression,
//   - the function has fewer than two arguments,
//   - either of the first two arguments is nil or is not a bare column
//     reference,
//   - the two columns do not belong to leftTag and rightTag respectively.
//
// The function/operator itself is not inspected here; callers are expected
// to have validated the kind of condition beforehand.
func getJoinCondCol(cond *Expr, leftTag int32, rightTag int32) (*plan.Expr_Col, *plan.Expr_Col) {
	if cond == nil {
		return nil, nil
	}
	fun, ok := cond.Expr.(*plan.Expr_F)
	if !ok {
		return nil, nil
	}
	// Guard the Args[0]/Args[1] accesses below: a function expression with
	// fewer than two arguments would otherwise panic with index out of range.
	if len(fun.F.Args) < 2 || fun.F.Args[0] == nil || fun.F.Args[1] == nil {
		return nil, nil
	}
	leftCol, ok := fun.F.Args[0].Expr.(*plan.Expr_Col)
	if !ok {
		return nil, nil
	}
	rightCol, ok := fun.F.Args[1].Expr.(*plan.Expr_Col)
	if !ok {
		return nil, nil
	}
	// Normalise operand order so that leftCol refers to leftTag.
	if leftCol.Col.RelPos != leftTag {
		leftCol, rightCol = rightCol, leftCol
	}
	if leftCol.Col.RelPos != leftTag || rightCol.Col.RelPos != rightTag {
		return nil, nil
	}
	return leftCol, rightCol
}

// replaceAllColRefInExprList runs replaceCol over every expression in
// exprlist, rewriting references to the column described by from so that
// they use the RelPos/ColPos of to instead.
func replaceAllColRefInExprList(exprlist []*plan.Expr, from *plan.Expr_Col, to *plan.Expr_Col) {
	for i := 0; i < len(exprlist); i++ {
		// from/to are read on every iteration, exactly like the positions
		// are looked up per call in a plain range loop.
		src, dst := from.Col, to.Col
		replaceCol(exprlist[i], src.RelPos, src.ColPos, dst.RelPos, dst.ColPos)
	}
}

// replaceAllColRefInPlan walks the plan rooted at nodeID and rewrites every
// reference to column from into a reference to column to. The subtree rooted
// at exceptID (the join and everything below it) is left untouched.
//
// Children are processed before the node's own expression lists.
func replaceAllColRefInPlan(nodeID int32, exceptID int32, from *plan.Expr_Col, to *plan.Expr_Col, builder *QueryBuilder) {
	// stop at the excluded node; neither it nor its children are rewritten
	if nodeID == exceptID {
		return
	}

	cur := builder.qry.Nodes[nodeID]
	for _, childID := range cur.Children {
		replaceAllColRefInPlan(childID, exceptID, from, to, builder)
	}

	// same list order as a sequence of individual calls
	exprLists := [][]*plan.Expr{
		cur.OnList,
		cur.ProjectList,
		cur.FilterList,
		cur.AggList,
		cur.GroupBy,
		cur.GroupingSet,
	}
	for _, exprs := range exprLists {
		replaceAllColRefInExprList(exprs, from, to)
	}

	// order-by items wrap their expression, so they are handled one by one
	for _, item := range cur.OrderBy {
		replaceCol(item.Expr, from.Col.RelPos, from.Col.ColPos, to.Col.RelPos, to.Col.ColPos)
	}
}

// checkColRef reports whether expr is free of references to columns that
// share col's RelPos but have a different ColPos. A nil expression, and any
// expression kind other than a function call or a column reference, passes
// the check. Function calls pass only if all of their arguments pass.
func checkColRef(expr *plan.Expr, col *plan.Expr_Col) bool {
	if expr == nil {
		return true
	}

	if ref, isCol := expr.Expr.(*plan.Expr_Col); isCol {
		// acceptable: a different relation, or exactly the given column
		return ref.Col.RelPos != col.Col.RelPos || ref.Col.ColPos == col.Col.ColPos
	}

	if fn, isFunc := expr.Expr.(*plan.Expr_F); isFunc {
		for _, operand := range fn.F.Args {
			if !checkColRef(operand, col) {
				return false
			}
		}
	}
	return true
}

// checkAllColRefInExprList reports whether every expression in exprlist
// passes checkColRef for col. An empty or nil list passes.
func checkAllColRefInExprList(exprlist []*plan.Expr, col *plan.Expr_Col) bool {
	for i := 0; i < len(exprlist); i++ {
		if ok := checkColRef(exprlist[i], col); !ok {
			return false
		}
	}
	return true
}

// checkAllColRefInPlan verifies, for the plan rooted at nodeID, that every
// expression passes checkColRef for col. The subtree rooted at exceptID
// (the join and its children) is skipped and counts as passing.
//
// Children are checked first; the node's own lists are checked afterwards
// and evaluation stops at the first failure.
func checkAllColRefInPlan(nodeID int32, exceptID int32, col *plan.Expr_Col, builder *QueryBuilder) bool {
	// the excluded subtree is never inspected
	if nodeID == exceptID {
		return true
	}

	cur := builder.qry.Nodes[nodeID]
	for _, childID := range cur.Children {
		if !checkAllColRefInPlan(childID, exceptID, col, builder) {
			return false
		}
	}

	exprLists := [][]*plan.Expr{
		cur.OnList,
		cur.ProjectList,
		cur.FilterList,
		cur.AggList,
		cur.GroupBy,
		cur.GroupingSet,
	}
	for _, exprs := range exprLists {
		if !checkAllColRefInExprList(exprs, col) {
			return false
		}
	}

	for _, item := range cur.OrderBy {
		if !checkColRef(item.Expr, col) {
			return false
		}
	}
	return true
}

// applyAggPullup tries to rewrite join(agg(leftScan), rightScan) into
// agg(join(leftScan, rightScan)). It returns true when the plan was
// rewritten and false when any precondition fails; on false nothing has
// been modified.
//
// Preconditions checked, in order:
//   - agg groups by exactly one bare column,
//   - join.OnList is a single equi-join condition between the agg output
//     (agg.BindingTags[0]) and rightScan (rightScan.BindingTags[0]),
//   - the right-hand join column is the single-column primary key of the
//     right table,
//   - the stats comparison below does not veto the rewrite,
//   - outside the join subtree, the only right-table column referenced in
//     the plan under rootID is the join column.
func applyAggPullup(rootID int32, join, agg, leftScan, rightScan *plan.Node, builder *QueryBuilder) bool {
	if len(agg.GroupBy) != 1 {
		return false
	}
	groupKey, isCol := agg.GroupBy[0].Expr.(*plan.Expr_Col)
	if !isCol {
		return false
	}
	if !(IsEquiJoin(join.OnList) && len(join.OnList) == 1) {
		return false
	}

	rightTag := rightScan.BindingTags[0]
	aggSideCol, scanSideCol := getJoinCondCol(join.OnList[0], agg.BindingTags[0], rightTag)
	if aggSideCol == nil {
		return false
	}

	// the right column must be the primary key of the right table;
	// alternatively rowid could be added to the group by (future work)
	pkDef := builder.compCtx.GetPrimaryKeyDef(rightScan.ObjRef.SchemaName, rightScan.ObjRef.ObjName)
	if len(pkDef) != 1 {
		return false
	}
	rightBinding := builder.ctxByNode[rightScan.NodeId].bindingByTag[rightTag]
	if pkPos := rightBinding.FindColumn(pkDef[0].Name); pkPos != scanSideCol.Col.ColPos {
		return false
	}

	// give up when (agg out / left scan out) is smaller than
	// (join out / agg out)
	aggOverScan := agg.Stats.Outcnt / leftScan.Stats.Outcnt
	joinOverAgg := join.Stats.Outcnt / agg.Stats.Outcnt
	if aggOverScan < joinOverAgg {
		return false
	}

	// Column refs to the right table cannot be seen once agg is pulled up.
	// Because the join condition is leftcol = rightcol, refs to the right
	// column can be redirected to the left column; any other right-table
	// column must not be referenced at all.
	if !checkAllColRefInPlan(rootID, join.NodeId, scanSideCol, builder) {
		return false
	}
	replaceAllColRefInPlan(rootID, join.NodeId, scanSideCol, aggSideCol, builder)

	// swap the two nodes: join now reads what agg used to read, and agg
	// sits on top of join
	join.Children[0] = agg.Children[0]
	agg.Children[0] = join.NodeId

	// the join condition must now point at the group-by column itself
	// rather than at the agg output
	aggSideCol.Col.RelPos = groupKey.Col.RelPos
	aggSideCol.Col.ColPos = groupKey.Col.ColPos
	return true
}

// aggPullup walks the plan bottom-up from nodeID and, wherever it finds
// node -> inner join -> agg, tries to turn it into node -> agg -> inner join.
// Only the shape inner join(agg(table scan), table scan) is supported for
// now.
//
// rootID is the root of the whole plan and is handed to applyAggPullup.
// The return value is the ID of the node that now heads this subtree:
// the agg node when a pull-up was applied here, nodeID otherwise.
func (builder *QueryBuilder) aggPullup(rootID, nodeID int32) int32 {
	node := builder.qry.Nodes[nodeID]

	// leaves have nothing to rewrite
	if len(node.Children) == 0 {
		return nodeID
	}
	for i, child := range node.Children {
		node.Children[i] = builder.aggPullup(rootID, child)
	}

	if node.NodeType != plan.Node_JOIN || node.JoinType != plan.Node_INNER {
		return nodeID
	}

	// make sure left child is bigger
	builder.applySwapRuleByStats(node.NodeId, false)

	agg := builder.qry.Nodes[node.Children[0]]
	if agg.NodeType != plan.Node_AGG {
		return nodeID
	}
	leftScan := builder.qry.Nodes[agg.Children[0]]
	if leftScan.NodeType != plan.Node_TABLE_SCAN {
		return nodeID
	}
	rightScan := builder.qry.Nodes[node.Children[1]]
	if rightScan.NodeType != plan.Node_TABLE_SCAN {
		return nodeID
	}

	if !applyAggPullup(rootID, node, agg, leftScan, rightScan, builder) {
		return nodeID
	}
	// the agg node replaced the join as head of this subtree
	return agg.NodeId
}
2 changes: 2 additions & 0 deletions pkg/sql/plan/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,8 @@ func (builder *QueryBuilder) createQuery() (*Query, error) {
ReCalcNodeStats(rootID, builder, true, false)
rootID = builder.determineJoinOrder(rootID)
ReCalcNodeStats(rootID, builder, true, false)
rootID = builder.aggPullup(rootID, rootID)
ReCalcNodeStats(rootID, builder, true, false)
rootID = builder.pushdownSemiAntiJoins(rootID)
ReCalcNodeStats(rootID, builder, true, false)
rootID = builder.applySwapRuleByStats(rootID, true)
Expand Down

0 comments on commit cce87e6

Please sign in to comment.