Skip to content

Commit

Permalink
[BugFix][CherryPick] fix aggregation without rollup rewrite for 3.0 (#…
Browse files Browse the repository at this point in the history
…29698) (#30630)

Signed-off-by: ABingHuang <[email protected]>
  • Loading branch information
ABingHuang authored Sep 11, 2023
1 parent cff06ce commit 1b21c2a
Show file tree
Hide file tree
Showing 23 changed files with 490 additions and 201 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import java.math.BigInteger;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

import static com.starrocks.catalog.Type.TINYINT;
import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -73,6 +75,9 @@ public final class ConstantOperator extends ScalarOperator implements Comparable
public static final ConstantOperator TRUE = ConstantOperator.createBoolean(true);
public static final ConstantOperator FALSE = ConstantOperator.createBoolean(false);

private static final BigInteger MAX_LARGE_INT = new BigInteger("2").pow(127).subtract(BigInteger.ONE);
private static final BigInteger MIN_LARGE_INT = new BigInteger("2").pow(128).multiply(BigInteger.valueOf(-1));

private static void requiredValid(LocalDateTime dateTime) throws SemanticException {
if (null == dateTime || dateTime.isBefore(MIN_DATETIME) || dateTime.isAfter(MAX_DATETIME)) {
throw new SemanticException("Invalid date value: " + (dateTime == null ? "NULL" : dateTime.toString()));
Expand Down Expand Up @@ -522,4 +527,62 @@ public ConstantOperator castTo(Type desc) throws Exception {

throw UnsupportedException.unsupportedException(this + " cast to " + desc.getPrimitiveType().toString());
}

public Optional<ConstantOperator> successor() {
return computeValue(1);
}

public Optional<ConstantOperator> predecessor() {
return computeValue(-1);
}

private Optional<ConstantOperator> computeValue(int delta) {
return computeWithLimits(delta,
v -> (byte) (v + delta),
v -> (short) (v + delta),
v -> v + delta,
v -> (long) v + delta,
v -> v.add(BigInteger.valueOf(delta)),
date -> date.plus(delta, ChronoUnit.DAYS),
date -> date.plus(delta, ChronoUnit.SECONDS)
);
}

private Optional<ConstantOperator> computeWithLimits(int delta,
Function<Byte, Byte> byteFunc,
Function<Short, Short> smallFunc,
Function<Integer, Integer> intFunc,
Function<Long, Long> longFunc,
Function<BigInteger, BigInteger> bigintFunc,
Function<LocalDateTime, LocalDateTime> dateFunc,
Function<LocalDateTime, LocalDateTime> datetimeFunc) {
if (type.isTinyint()) {
return compute(delta, getTinyInt(), Byte.MAX_VALUE, Byte.MIN_VALUE, byteFunc, ConstantOperator::createTinyInt);
} else if (type.isSmallint()) {
return compute(delta, getSmallint(), Short.MAX_VALUE, Short.MIN_VALUE, smallFunc, ConstantOperator::createSmallInt);
} else if (type.isInt()) {
return compute(delta, getInt(), Integer.MAX_VALUE, Integer.MIN_VALUE, intFunc, ConstantOperator::createInt);
} else if (type.isBigint()) {
return compute(delta, getBigint(), Long.MAX_VALUE, Long.MIN_VALUE, longFunc, ConstantOperator::createBigint);
} else if (type.isLargeint()) {
return compute(delta, getLargeInt(), MAX_LARGE_INT, MIN_LARGE_INT, bigintFunc, ConstantOperator::createLargeInt);
} else if (type.isDatetime()) {
return compute(delta, (LocalDateTime) value, LocalDateTime.MAX, LocalDateTime.MIN,
datetimeFunc, ConstantOperator::createDatetime);
} else if (type.isDateType()) {
return compute(delta, (LocalDateTime) value, LocalDate.MAX.atStartOfDay(), LocalDate.MIN.atStartOfDay(),
dateFunc, ConstantOperator::createDate);
} else {
return Optional.empty();
}
}

private <T> Optional<ConstantOperator> compute(int delta,
T value, T maxValue, T minValue, Function<T, T> func, Function<T, ConstantOperator> creator) {
if ((delta > 0 && value.equals(maxValue)) || (delta < 0 && value.equals(minValue))) {
return Optional.empty();
} else {
return Optional.of(creator.apply(func.apply(value)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,13 @@ public static boolean isColumnEqualBinaryPredicate(ScalarOperator predicate) {
}
return false;
}

public static boolean isColumnEqualConstant(ScalarOperator predicate) {
if (predicate instanceof BinaryPredicateOperator) {
BinaryPredicateOperator binaryPredicate = (BinaryPredicateOperator) predicate;
return binaryPredicate.getBinaryType().isEquivalence()
&& binaryPredicate.getChild(0).isColumnRef() && binaryPredicate.getChild(1).isConstantRef();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,88 +15,26 @@

package com.starrocks.sql.optimizer.rewrite.scalar;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator.BinaryType;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil.findArithmeticFunction;

public class MvNormalizePredicateRule extends NormalizePredicateRule {
// Normalize Binary Predicate
// for integer type:
// a < 3 => a <= 2
// a > 3 => a >= 4
// a = 3 => a >= 3 and a <= 3
// a != 3 => a > 3 and a < 3 (which will be normalized further)
//
@Override
public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
ScalarOperatorRewriteContext context) {
ScalarOperator tmp = super.visitBinaryPredicate(predicate, context);
Preconditions.checkState(tmp instanceof BinaryPredicateOperator);
BinaryPredicateOperator binary = (BinaryPredicateOperator) tmp;
if (binary.getChild(0).isVariable() && binary.getChild(1).isConstantRef()) {
ConstantOperator constantOperator = (ConstantOperator) binary.getChild(1);
if (!constantOperator.getType().isIntegerType()) {
return tmp;
}
ConstantOperator one = createConstantIntegerOne(constantOperator.getType());
if (one == null) {
return tmp;
}
Type[] argsType = {constantOperator.getType(), constantOperator.getType()};
switch (binary.getBinaryType()) {
case LT:
Function substractFn = findArithmeticFunction(argsType, FunctionSet.SUBTRACT);
CallOperator sub = new CallOperator(FunctionSet.SUBTRACT,
substractFn.getReturnType(), Lists.newArrayList(constantOperator, one), substractFn);
return new BinaryPredicateOperator(BinaryType.LE, binary.getChild(0), sub);
case GT:
Function addFn = findArithmeticFunction(argsType, FunctionSet.ADD);
CallOperator add = new CallOperator(
FunctionSet.ADD, addFn.getReturnType(), Lists.newArrayList(constantOperator, one), addFn);
return new BinaryPredicateOperator(BinaryType.GE, binary.getChild(0), add);
case EQ:
BinaryPredicateOperator gePart =
new BinaryPredicateOperator(BinaryType.GE, binary.getChild(0), constantOperator);
BinaryPredicateOperator lePart =
new BinaryPredicateOperator(BinaryType.LE, binary.getChild(0), constantOperator);
return new CompoundPredicateOperator(CompoundPredicateOperator.CompoundType.AND, gePart, lePart);
case NE:
BinaryPredicateOperator gtPart =
new BinaryPredicateOperator(BinaryType.GT, binary.getChild(0), constantOperator);
BinaryPredicateOperator ltPart =
new BinaryPredicateOperator(BinaryType.LT, binary.getChild(0), constantOperator);
return new CompoundPredicateOperator(CompoundPredicateOperator.CompoundType.OR, gtPart, ltPart);
default:
break;
}
}
return tmp;
}

// should maintain sequence for case:
// a like "%hello%" and (b * c = 100 or b * c = 200)
// (b * c = 200 or b * c = 100) and a like "%hello%"
Expand Down Expand Up @@ -170,19 +108,6 @@ public ScalarOperator visitInPredicate(InPredicateOperator predicate, ScalarOper
return isIn ? Utils.compoundOr(result) : Utils.compoundAnd(result);
}

private ConstantOperator createConstantIntegerOne(Type type) {
if (Type.SMALLINT.equals(type)) {
return ConstantOperator.createSmallInt((short) 1);
} else if (Type.INT.equals(type)) {
return ConstantOperator.createInt(1);
} else if (Type.BIGINT.equals(type)) {
return ConstantOperator.createBigint(1L);
} else if (Type.LARGEINT.equals(type)) {
return ConstantOperator.createLargeInt(BigInteger.ONE);
}
return null;
}

// NOTE: View-Delta Join may produce redundant compensation predicates as below.
// eg:
// A(pk: a1) <-> B (pk: b1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,33 @@ public ScalarOperator toScalarOperator() {
List<ScalarOperator> orOperators = Lists.newArrayList();
for (Range<ConstantOperator> range : columnRanges.asRanges()) {
List<ScalarOperator> andOperators = Lists.newArrayList();
if (range.hasLowerBound() && range.hasUpperBound() && range.upperEndpoint().equals(range.lowerEndpoint())) {
andOperators.add(BinaryPredicateOperator.eq(columnRef, range.upperEndpoint()));
} else {
if (range.hasLowerBound()) {
if (range.lowerBoundType() == BoundType.CLOSED) {
andOperators.add(BinaryPredicateOperator.ge(columnRef, range.lowerEndpoint()));
} else {
andOperators.add(BinaryPredicateOperator.gt(columnRef, range.lowerEndpoint()));
}
if (range.hasLowerBound() && range.hasUpperBound()) {
if (range.lowerBoundType() == BoundType.CLOSED
&& range.upperBoundType() == BoundType.CLOSED
&& range.upperEndpoint().equals(range.lowerEndpoint())) {
orOperators.add(BinaryPredicateOperator.eq(columnRef, range.lowerEndpoint()));
continue;
} else if (range.lowerBoundType() == BoundType.CLOSED
&& range.upperBoundType() == BoundType.OPEN
&& range.lowerEndpoint().successor().isPresent()
&& range.upperEndpoint().equals(range.lowerEndpoint().successor().get())) {
orOperators.add(BinaryPredicateOperator.eq(columnRef, range.lowerEndpoint()));
continue;
}
}
if (range.hasLowerBound()) {
if (range.lowerBoundType() == BoundType.CLOSED) {
andOperators.add(BinaryPredicateOperator.ge(columnRef, range.lowerEndpoint()));
} else {
andOperators.add(BinaryPredicateOperator.gt(columnRef, range.lowerEndpoint()));
}
}

if (range.hasUpperBound()) {
if (range.upperBoundType() == BoundType.CLOSED) {
andOperators.add(BinaryPredicateOperator.le(columnRef, range.upperEndpoint()));
} else {
andOperators.add(BinaryPredicateOperator.lt(columnRef, range.upperEndpoint()));
}
if (range.hasUpperBound()) {
if (range.upperBoundType() == BoundType.CLOSED) {
andOperators.add(BinaryPredicateOperator.le(columnRef, range.upperEndpoint()));
} else {
andOperators.add(BinaryPredicateOperator.lt(columnRef, range.upperEndpoint()));
}
}
orOperators.add(Utils.compoundAnd(andOperators));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1117,9 +1117,8 @@ private OptExpression tryRewriteForRelationMapping(RewriteContext rewriteContext
}
final Operator.Builder newScanOpBuilder = OperatorBuilderFactory.build(mvScanOptExpression.getOp());
newScanOpBuilder.withOperator(mvScanOptExpression.getOp());
final ScalarOperator pruneFinalCompensationPredicate =
MvNormalizePredicateRule.pruneRedundantPredicates(finalCompensationPredicate);
newScanOpBuilder.setPredicate(pruneFinalCompensationPredicate);
final ScalarOperator normalizedPredicate = normalizePredicate(finalCompensationPredicate);
newScanOpBuilder.setPredicate(normalizedPredicate);
mvScanOptExpression = OptExpression.create(newScanOpBuilder.build());
mvScanOptExpression.setLogicalProperty(null);
deriveLogicalProperty(mvScanOptExpression);
Expand All @@ -1130,6 +1129,13 @@ private OptExpression tryRewriteForRelationMapping(RewriteContext rewriteContext
}
}

private ScalarOperator normalizePredicate(ScalarOperator predicate) {
ScalarOperator pruneFinalCompensationPredicate =
MvNormalizePredicateRule.pruneRedundantPredicates(predicate);
PredicateSplit split = PredicateSplit.splitPredicate(pruneFinalCompensationPredicate);
return Utils.compoundAnd(split.getEqualPredicates(), split.getRangePredicates(), split.getResidualPredicates());
}

private ScalarOperator getMVCompensationPredicate(RewriteContext rewriteContext,
ColumnRewriter rewriter,
Map<ColumnRefOperator, ScalarOperator> mvColumnRefToScalarOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@

package com.starrocks.sql.optimizer.rule.transformation.materialization;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;

import java.util.List;
import java.util.Map;
import java.util.Objects;

public class OrRangePredicate extends RangePredicate {
Expand Down Expand Up @@ -56,7 +62,45 @@ public ScalarOperator toScalarOperator() {
for (RangePredicate rangePredicate : childPredicates) {
children.add(rangePredicate.toScalarOperator());
}
return Utils.compoundOr(children);

ScalarOperator orPredicate = Utils.compoundOr(children);
List<ScalarOperator> childPredicates = Utils.extractDisjunctive(orPredicate);
Map<String, List<ScalarOperator>> columnPredicatesMap = Maps.newHashMap();
for (ScalarOperator rangePredicate : childPredicates) {
if (ScalarOperator.isColumnEqualConstant(rangePredicate)) {
BinaryPredicateOperator binaryEqPredicate = (BinaryPredicateOperator) rangePredicate;
ColumnRefOperator columnRef = binaryEqPredicate.getChild(0).cast();
List<ScalarOperator> columnRangePredicates = columnPredicatesMap.computeIfAbsent(
columnRef.getName(), k -> Lists.newArrayList());
columnRangePredicates.add(rangePredicate);
} else if (rangePredicate instanceof InPredicateOperator && rangePredicate.getChild(0).isColumnRef()) {
InPredicateOperator inPredicate = rangePredicate.cast();
List<ScalarOperator> columnRangePredicates = columnPredicatesMap.computeIfAbsent(
((ColumnRefOperator) inPredicate.getChild(0)).getName(), k -> Lists.newArrayList());
columnRangePredicates.add(rangePredicate);
}
}
for (List<ScalarOperator> value : columnPredicatesMap.values()) {
if (value.size() > 1) {
childPredicates.removeAll(value);
// add InPredicateOperator
List<ScalarOperator> arguments = Lists.newArrayList();
arguments.add(value.get(0).getChild(0));
for (ScalarOperator predicate : value) {
if (ScalarOperator.isColumnEqualConstant(predicate)) {
arguments.add(predicate.getChild(1));
} else {
// must be InPredicateOperator
Preconditions.checkState(predicate instanceof InPredicateOperator);
arguments.addAll(predicate.getChildren().subList(1, predicate.getChildren().size()));
}
}
InPredicateOperator inPredicateOperator = new InPredicateOperator(false, arguments);
childPredicates.add(inPredicateOperator);
}
}

return Utils.compoundOr(childPredicates);
}

// for
Expand Down
Loading

0 comments on commit 1b21c2a

Please sign in to comment.