From 895b34f0b539fcc46151ef95cc1387d8581cf9fc Mon Sep 17 00:00:00 2001 From: duanzhengqiang Date: Thu, 14 Nov 2024 17:58:41 +0800 Subject: [PATCH] Refactor encrypt token generator and encrypt sql rewrite it --- ...nsertSelectColumnsEncryptorComparator.java | 11 +++++--- ...EncryptInsertCipherNameTokenGenerator.java | 1 - .../EncryptProjectionTokenGenerator.java | 10 +++---- .../statement/dml/InsertStatementContext.java | 20 +++++++++++--- .../impl/StandardParameterBuilder.java | 4 ++- .../rewrite/sql/impl/AbstractSQLBuilder.java | 19 +++++++------- .../generic/SubstitutableColumnNameToken.java | 5 ++++ .../test/it/rewrite/engine/SQLRewriterIT.java | 2 +- .../SQLRewriteEngineTestParameters.java | 5 +++- ...SQLRewriteEngineTestParametersBuilder.java | 5 +++- .../rewrite/engine/type/SQLExecuteType.java | 26 +++++++++++++++++++ 11 files changed, 83 insertions(+), 25 deletions(-) create mode 100644 test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/type/SQLExecuteType.java diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/comparator/InsertSelectColumnsEncryptorComparator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/comparator/InsertSelectColumnsEncryptorComparator.java index 20ea0dbf69673..b53605e9e02f9 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/comparator/InsertSelectColumnsEncryptorComparator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/comparator/InsertSelectColumnsEncryptorComparator.java @@ -25,6 +25,7 @@ import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection; import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ExpressionProjection; import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ParameterMarkerProjection; +import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.SubqueryProjection; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment; @@ -78,8 +79,12 @@ private static boolean isLiteralOrParameterMarker(final Projection projection) { } private static ColumnSegmentBoundInfo getColumnSegmentBoundInfo(final Projection projection) { - return projection instanceof ColumnProjection - ? new ColumnSegmentBoundInfo(null, null, ((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) projection).getOriginalColumn()) - : new ColumnSegmentBoundInfo(new IdentifierValue(projection.getColumnLabel())); + if (projection instanceof ColumnProjection) { + return ((ColumnProjection) projection).getColumnBoundInfo(); + } + if (projection instanceof SubqueryProjection) { + return getColumnSegmentBoundInfo(((SubqueryProjection) projection).getProjection()); + } + return new ColumnSegmentBoundInfo(new IdentifierValue(projection.getColumnLabel())); } } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java index 989421390da16..f6e6d7e86013c 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertCipherNameTokenGenerator.java @@ -87,6 +87,5 @@ private void checkInsertSelectEncryptor(final SelectStatementContext selectState ShardingSpherePreconditions.checkState(insertColumns.size() == projections.size(), () -> new UnsupportedSQLOperationException("Column count doesn't match value count.")); ShardingSpherePreconditions.checkState(InsertSelectColumnsEncryptorComparator.isSame(insertColumns, projections, encryptRule), () -> new UnsupportedSQLOperationException("Can not use different encryptor in insert select columns")); - selectStatementContext.getSubqueryContexts().values().forEach(each -> checkInsertSelectEncryptor(each, insertColumns)); } } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java index 035b6fa259aa5..facfea346055b 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java @@ -203,11 +203,11 @@ private Collection generateProjectionsInInsertSelectSubquery(final E ParenthesesSegment leftParentheses = columnProjection.getLeftParentheses().orElse(null); ParenthesesSegment rightParentheses = columnProjection.getRightParentheses().orElse(null); result.add(new ColumnProjection(columnProjection.getOwner().orElse(null), columnName, null, databaseType, leftParentheses, rightParentheses)); - IdentifierValue assistedColumOwner = columnProjection.getOwner().orElse(null); - encryptColumn.getAssistedQuery().ifPresent( - optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses))); - encryptColumn.getLikeQuery().ifPresent( - optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses))); + IdentifierValue columOwner = columnProjection.getOwner().orElse(null); + encryptColumn.getAssistedQuery() + .ifPresent(optional -> result.add(new ColumnProjection(columOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses))); + encryptColumn.getLikeQuery() + .ifPresent(optional -> result.add(new ColumnProjection(columOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses))); return result; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java index 5bea4c943ee0e..264c03ac33799 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java @@ -55,6 +55,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; @@ -138,6 +139,14 @@ private Optional getInsertSelectContext(final ShardingSpher SubquerySegment insertSelectSegment = getSqlStatement().getInsertSelect().get(); SelectStatementContext selectStatementContext = new SelectStatementContext(metaData, params, insertSelectSegment.getSelect(), currentDatabaseName, Collections.emptyList()); selectStatementContext.setSubqueryType(SubqueryType.INSERT_SELECT); + setCombineSelectSubqueryType(selectStatementContext); + setProjectionSelectSubqueryType(selectStatementContext); + InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get()); + paramsOffset.addAndGet(insertSelectContext.getParameterCount()); + return Optional.of(insertSelectContext); + } + + private void setCombineSelectSubqueryType(final SelectStatementContext selectStatementContext) { if (selectStatementContext.getSqlStatement().getCombine().isPresent()) { CombineSegment combineSegment = selectStatementContext.getSqlStatement().getCombine().get(); Optional.ofNullable(selectStatementContext.getSubqueryContexts().get(combineSegment.getLeft().getStartIndex())) @@ -145,9 +154,14 @@ private Optional getInsertSelectContext(final ShardingSpher Optional.ofNullable(selectStatementContext.getSubqueryContexts().get(combineSegment.getRight().getStartIndex())) .ifPresent(optional -> optional.setSubqueryType(SubqueryType.INSERT_SELECT)); } - InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get()); - paramsOffset.addAndGet(insertSelectContext.getParameterCount()); - return Optional.of(insertSelectContext); + } + + private void setProjectionSelectSubqueryType(final SelectStatementContext selectStatementContext) { + for (Entry entry : selectStatementContext.getSubqueryContexts().entrySet()) { + if (entry.getKey() >= selectStatementContext.getProjectionsContext().getStartIndex() && entry.getKey() <= selectStatementContext.getProjectionsContext().getStopIndex()) { + entry.getValue().setSubqueryType(SubqueryType.INSERT_SELECT); + } + } } private Optional getOnDuplicateKeyUpdateValueContext(final List params, final AtomicInteger parametersOffset) { diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/parameter/builder/impl/StandardParameterBuilder.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/parameter/builder/impl/StandardParameterBuilder.java index 2c1d5fbf11381..abd15210a1c5e 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/parameter/builder/impl/StandardParameterBuilder.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/parameter/builder/impl/StandardParameterBuilder.java @@ -25,6 +25,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -55,7 +56,8 @@ public final class StandardParameterBuilder implements ParameterBuilder { public void addAddedParameters(final int index, final Collection params) { addedParameterCount += params.size(); maxAddedParameterIndex = Math.max(maxAddedParameterIndex, index); - addedIndexAndParameters.put(index, params); + Collection existedAddedIndexAndParameters = addedIndexAndParameters.computeIfAbsent(index, unused -> new LinkedList<>()); + existedAddedIndexAndParameters.addAll(params); } /** diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/impl/AbstractSQLBuilder.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/impl/AbstractSQLBuilder.java index c6f93979d2c00..ec94f012d4d13 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/impl/AbstractSQLBuilder.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/impl/AbstractSQLBuilder.java @@ -62,24 +62,25 @@ private boolean isContainsAttachableToken(final SQLToken sqlToken, final SQLToke private void appendRewriteSQL(final SQLToken sqlToken, final StringBuilder builder) { builder.append(getSQLTokenText(sqlToken)); - builder.append(getConjunctionText(sqlToken)); + builder.append(getConjunctionText(sqlToken, sqlTokens, sql.length())); } protected abstract String getSQLTokenText(SQLToken sqlToken); - private String getConjunctionText(final SQLToken sqlToken) { - int startIndex = getStartIndex(sqlToken); - return sql.substring(startIndex, getStopIndex(sqlToken, startIndex)); + private String getConjunctionText(final SQLToken sqlToken, final List sqlTokens, final int sqlLength) { + int startIndex = getStartIndex(sqlToken, sqlLength); + int stopIndex = getStopIndex(sqlToken, sqlTokens, sqlLength, startIndex); + return sql.substring(startIndex, stopIndex); } - private int getStartIndex(final SQLToken sqlToken) { + private int getStartIndex(final SQLToken sqlToken, final int sqlLength) { int startIndex = sqlToken instanceof Substitutable ? ((Substitutable) sqlToken).getStopIndex() + 1 : sqlToken.getStartIndex(); - return Math.min(startIndex, sql.length()); + return Math.min(startIndex, sqlLength); } - private int getStopIndex(final SQLToken sqlToken, final int startIndex) { + private int getStopIndex(final SQLToken sqlToken, final List sqlTokens, final int sqlLength, final int startIndex) { int currentSQLTokenIndex = sqlTokens.indexOf(sqlToken); - int stopIndex = sqlTokens.size() - 1 == currentSQLTokenIndex ? sql.length() : sqlTokens.get(currentSQLTokenIndex + 1).getStartIndex(); - return startIndex <= stopIndex ? stopIndex : getStopIndex(sqlTokens.get(currentSQLTokenIndex + 1), startIndex); + int stopIndex = sqlTokens.size() - 1 == currentSQLTokenIndex ? sqlLength : sqlTokens.get(currentSQLTokenIndex + 1).getStartIndex(); + return startIndex <= stopIndex ? stopIndex : getStopIndex(sqlTokens.get(currentSQLTokenIndex + 1), sqlTokens, sqlLength, startIndex); } } diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/token/common/pojo/generic/SubstitutableColumnNameToken.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/token/common/pojo/generic/SubstitutableColumnNameToken.java index af2836cbeb75f..a28afbe6e16d6 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/token/common/pojo/generic/SubstitutableColumnNameToken.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/token/common/pojo/generic/SubstitutableColumnNameToken.java @@ -45,15 +45,20 @@ public final class SubstitutableColumnNameToken extends SQLToken implements Subs @Getter private final int stopIndex; + @Getter private final Collection projections; private final QuoteCharacter quoteCharacter; + @Getter + private final DatabaseType databaseType; + public SubstitutableColumnNameToken(final int startIndex, final int stopIndex, final Collection projections, final DatabaseType databaseType) { super(startIndex); this.stopIndex = stopIndex; quoteCharacter = new DatabaseTypeRegistry(databaseType).getDialectDatabaseMetaData().getQuoteCharacter(); this.projections = projections; + this.databaseType = databaseType; } @Override diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java index d6cfd5f5bba2a..4e7146180fa3a 100644 --- a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java +++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/SQLRewriterIT.java @@ -105,7 +105,7 @@ void assertRewrite(final SQLRewriteEngineTestParameters testParams) throws IOExc assertThat(actual.size(), is(testParams.getOutputSQLs().size())); int count = 0; for (SQLRewriteUnit each : actual) { - assertThat(each.getSql(), is(testParams.getOutputSQLs().get(count))); + assertThat(each.getSql().replace("\t", " "), is(testParams.getOutputSQLs().get(count).replace("\t", " "))); assertThat(each.getParameters().size(), is(testParams.getOutputGroupedParameters().get(count).size())); for (int i = 0; i < each.getParameters().size(); i++) { assertThat(String.valueOf(each.getParameters().get(i)), is(String.valueOf(testParams.getOutputGroupedParameters().get(count).get(i)))); diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParameters.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParameters.java index 41e09eff4780e..2e6f507661770 100644 --- a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParameters.java +++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParameters.java @@ -19,6 +19,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; +import org.apache.shardingsphere.test.it.rewrite.engine.type.SQLExecuteType; import java.util.List; @@ -47,8 +48,10 @@ public final class SQLRewriteEngineTestParameters { private final String databaseType; + private final SQLExecuteType sqlExecuteType; + @Override public String toString() { - return String.format("{%s}: {%s} ({%s}) -> {%s}", type, name, databaseType, fileName); + return String.format("{%s}: {%s} ({%s}) -> {%s}", type + "/" + sqlExecuteType, name, databaseType, fileName); } } diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParametersBuilder.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParametersBuilder.java index 17eb328007a6a..5a6e3d7ca33bb 100644 --- a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParametersBuilder.java +++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/parameter/SQLRewriteEngineTestParametersBuilder.java @@ -23,6 +23,7 @@ import lombok.NoArgsConstructor; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; +import org.apache.shardingsphere.test.it.rewrite.engine.type.SQLExecuteType; import org.apache.shardingsphere.test.it.rewrite.entity.RewriteAssertionEntity; import org.apache.shardingsphere.test.it.rewrite.entity.RewriteAssertionsRootEntity; import org.apache.shardingsphere.test.it.rewrite.entity.RewriteOutputEntity; @@ -97,8 +98,10 @@ private static Collection createTestParameters(f Collection result = new LinkedList<>(); for (RewriteAssertionEntity each : rootAssertions.getAssertions()) { for (String databaseType : getDatabaseTypes(each.getDatabaseTypes())) { + // TODO support appendLiteralCases for exits cases and remove duplicate cases + SQLExecuteType sqlExecuteType = null == each.getInput().getParameters() || each.getInput().getParameters().isEmpty() ? SQLExecuteType.LITERAL : SQLExecuteType.PLACEHOLDER; result.add(new SQLRewriteEngineTestParameters(type, each.getId(), fileName, rootAssertions.getYamlRule(), each.getInput().getSql(), - createParameters(each.getInput().getParameters()), createOutputSQLs(each.getOutputs()), createOutputGroupedParameters(each.getOutputs()), databaseType)); + createParameters(each.getInput().getParameters()), createOutputSQLs(each.getOutputs()), createOutputGroupedParameters(each.getOutputs()), databaseType, sqlExecuteType)); } } return result; diff --git a/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/type/SQLExecuteType.java b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/type/SQLExecuteType.java new file mode 100644 index 0000000000000..f52275b999865 --- /dev/null +++ b/test/it/rewriter/src/test/java/org/apache/shardingsphere/test/it/rewrite/engine/type/SQLExecuteType.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.test.it.rewrite.engine.type; + +/** + * SQL execute type. + */ +public enum SQLExecuteType { + + LITERAL, PLACEHOLDER +}