Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor encrypt token generator and encrypt sql rewrite it #33662

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ExpressionProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ParameterMarkerProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.SubqueryProjection;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
Expand Down Expand Up @@ -78,8 +79,12 @@ private static boolean isLiteralOrParameterMarker(final Projection projection) {
}

private static ColumnSegmentBoundInfo getColumnSegmentBoundInfo(final Projection projection) {
return projection instanceof ColumnProjection
? new ColumnSegmentBoundInfo(null, null, ((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) projection).getOriginalColumn())
: new ColumnSegmentBoundInfo(new IdentifierValue(projection.getColumnLabel()));
if (projection instanceof ColumnProjection) {
return ((ColumnProjection) projection).getColumnBoundInfo();
}
if (projection instanceof SubqueryProjection) {
return getColumnSegmentBoundInfo(((SubqueryProjection) projection).getProjection());
}
return new ColumnSegmentBoundInfo(new IdentifierValue(projection.getColumnLabel()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,5 @@ private void checkInsertSelectEncryptor(final SelectStatementContext selectState
ShardingSpherePreconditions.checkState(insertColumns.size() == projections.size(), () -> new UnsupportedSQLOperationException("Column count doesn't match value count."));
ShardingSpherePreconditions.checkState(InsertSelectColumnsEncryptorComparator.isSame(insertColumns, projections, encryptRule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in insert select columns"));
selectStatementContext.getSubqueryContexts().values().forEach(each -> checkInsertSelectEncryptor(each, insertColumns));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ private Collection<Projection> generateProjectionsInInsertSelectSubquery(final E
ParenthesesSegment leftParentheses = columnProjection.getLeftParentheses().orElse(null);
ParenthesesSegment rightParentheses = columnProjection.getRightParentheses().orElse(null);
result.add(new ColumnProjection(columnProjection.getOwner().orElse(null), columnName, null, databaseType, leftParentheses, rightParentheses));
IdentifierValue assistedColumOwner = columnProjection.getOwner().orElse(null);
encryptColumn.getAssistedQuery().ifPresent(
optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses)));
encryptColumn.getLikeQuery().ifPresent(
optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses)));
IdentifierValue columOwner = columnProjection.getOwner().orElse(null);
encryptColumn.getAssistedQuery()
.ifPresent(optional -> result.add(new ColumnProjection(columOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses)));
encryptColumn.getLikeQuery()
.ifPresent(optional -> result.add(new ColumnProjection(columOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses)));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -138,16 +139,29 @@ private Optional<InsertSelectContext> getInsertSelectContext(final ShardingSpher
SubquerySegment insertSelectSegment = getSqlStatement().getInsertSelect().get();
SelectStatementContext selectStatementContext = new SelectStatementContext(metaData, params, insertSelectSegment.getSelect(), currentDatabaseName, Collections.emptyList());
selectStatementContext.setSubqueryType(SubqueryType.INSERT_SELECT);
setCombineSelectSubqueryType(selectStatementContext);
setProjectionSelectSubqueryType(selectStatementContext);
InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get());
paramsOffset.addAndGet(insertSelectContext.getParameterCount());
return Optional.of(insertSelectContext);
}

private void setCombineSelectSubqueryType(final SelectStatementContext selectStatementContext) {
if (selectStatementContext.getSqlStatement().getCombine().isPresent()) {
CombineSegment combineSegment = selectStatementContext.getSqlStatement().getCombine().get();
Optional.ofNullable(selectStatementContext.getSubqueryContexts().get(combineSegment.getLeft().getStartIndex()))
.ifPresent(optional -> optional.setSubqueryType(SubqueryType.INSERT_SELECT));
Optional.ofNullable(selectStatementContext.getSubqueryContexts().get(combineSegment.getRight().getStartIndex()))
.ifPresent(optional -> optional.setSubqueryType(SubqueryType.INSERT_SELECT));
}
InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get());
paramsOffset.addAndGet(insertSelectContext.getParameterCount());
return Optional.of(insertSelectContext);
}

private void setProjectionSelectSubqueryType(final SelectStatementContext selectStatementContext) {
for (Entry<Integer, SelectStatementContext> entry : selectStatementContext.getSubqueryContexts().entrySet()) {
if (entry.getKey() >= selectStatementContext.getProjectionsContext().getStartIndex() && entry.getKey() <= selectStatementContext.getProjectionsContext().getStopIndex()) {
entry.getValue().setSubqueryType(SubqueryType.INSERT_SELECT);
}
}
}

private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(final List<Object> params, final AtomicInteger parametersOffset) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -55,7 +56,8 @@ public final class StandardParameterBuilder implements ParameterBuilder {
public void addAddedParameters(final int index, final Collection<Object> params) {
addedParameterCount += params.size();
maxAddedParameterIndex = Math.max(maxAddedParameterIndex, index);
addedIndexAndParameters.put(index, params);
Collection<Object> existedAddedIndexAndParameters = addedIndexAndParameters.computeIfAbsent(index, unused -> new LinkedList<>());
existedAddedIndexAndParameters.addAll(params);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,25 @@ private boolean isContainsAttachableToken(final SQLToken sqlToken, final SQLToke

private void appendRewriteSQL(final SQLToken sqlToken, final StringBuilder builder) {
builder.append(getSQLTokenText(sqlToken));
builder.append(getConjunctionText(sqlToken));
builder.append(getConjunctionText(sqlToken, sqlTokens, sql.length()));
}

protected abstract String getSQLTokenText(SQLToken sqlToken);

private String getConjunctionText(final SQLToken sqlToken) {
int startIndex = getStartIndex(sqlToken);
return sql.substring(startIndex, getStopIndex(sqlToken, startIndex));
private String getConjunctionText(final SQLToken sqlToken, final List<SQLToken> sqlTokens, final int sqlLength) {
int startIndex = getStartIndex(sqlToken, sqlLength);
int stopIndex = getStopIndex(sqlToken, sqlTokens, sqlLength, startIndex);
return sql.substring(startIndex, stopIndex);
}

private int getStartIndex(final SQLToken sqlToken) {
private int getStartIndex(final SQLToken sqlToken, final int sqlLength) {
int startIndex = sqlToken instanceof Substitutable ? ((Substitutable) sqlToken).getStopIndex() + 1 : sqlToken.getStartIndex();
return Math.min(startIndex, sql.length());
return Math.min(startIndex, sqlLength);
}

private int getStopIndex(final SQLToken sqlToken, final int startIndex) {
private int getStopIndex(final SQLToken sqlToken, final List<SQLToken> sqlTokens, final int sqlLength, final int startIndex) {
int currentSQLTokenIndex = sqlTokens.indexOf(sqlToken);
int stopIndex = sqlTokens.size() - 1 == currentSQLTokenIndex ? sql.length() : sqlTokens.get(currentSQLTokenIndex + 1).getStartIndex();
return startIndex <= stopIndex ? stopIndex : getStopIndex(sqlTokens.get(currentSQLTokenIndex + 1), startIndex);
int stopIndex = sqlTokens.size() - 1 == currentSQLTokenIndex ? sqlLength : sqlTokens.get(currentSQLTokenIndex + 1).getStartIndex();
return startIndex <= stopIndex ? stopIndex : getStopIndex(sqlTokens.get(currentSQLTokenIndex + 1), sqlTokens, sqlLength, startIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ public final class SubstitutableColumnNameToken extends SQLToken implements Subs
@Getter
private final int stopIndex;

@Getter
private final Collection<Projection> projections;

private final QuoteCharacter quoteCharacter;

@Getter
private final DatabaseType databaseType;

public SubstitutableColumnNameToken(final int startIndex, final int stopIndex, final Collection<Projection> projections, final DatabaseType databaseType) {
super(startIndex);
this.stopIndex = stopIndex;
quoteCharacter = new DatabaseTypeRegistry(databaseType).getDialectDatabaseMetaData().getQuoteCharacter();
this.projections = projections;
this.databaseType = databaseType;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ void assertRewrite(final SQLRewriteEngineTestParameters testParams) throws IOExc
assertThat(actual.size(), is(testParams.getOutputSQLs().size()));
int count = 0;
for (SQLRewriteUnit each : actual) {
assertThat(each.getSql(), is(testParams.getOutputSQLs().get(count)));
assertThat(each.getSql().replace("\t", " "), is(testParams.getOutputSQLs().get(count).replace("\t", " ")));
assertThat(each.getParameters().size(), is(testParams.getOutputGroupedParameters().get(count).size()));
for (int i = 0; i < each.getParameters().size(); i++) {
assertThat(String.valueOf(each.getParameters().get(i)), is(String.valueOf(testParams.getOutputGroupedParameters().get(count).get(i))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.test.it.rewrite.engine.type.SQLExecuteType;

import java.util.List;

Expand Down Expand Up @@ -47,8 +48,10 @@ public final class SQLRewriteEngineTestParameters {

private final String databaseType;

private final SQLExecuteType sqlExecuteType;

@Override
public String toString() {
return String.format("{%s}: {%s} ({%s}) -> {%s}", type, name, databaseType, fileName);
return String.format("{%s}: {%s} ({%s}) -> {%s}", type + "/" + sqlExecuteType, name, databaseType, fileName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.shardingsphere.test.it.rewrite.engine.type.SQLExecuteType;
import org.apache.shardingsphere.test.it.rewrite.entity.RewriteAssertionEntity;
import org.apache.shardingsphere.test.it.rewrite.entity.RewriteAssertionsRootEntity;
import org.apache.shardingsphere.test.it.rewrite.entity.RewriteOutputEntity;
Expand Down Expand Up @@ -97,8 +98,10 @@ private static Collection<SQLRewriteEngineTestParameters> createTestParameters(f
Collection<SQLRewriteEngineTestParameters> result = new LinkedList<>();
for (RewriteAssertionEntity each : rootAssertions.getAssertions()) {
for (String databaseType : getDatabaseTypes(each.getDatabaseTypes())) {
// TODO support appendLiteralCases for exits cases and remove duplicate cases
SQLExecuteType sqlExecuteType = null == each.getInput().getParameters() || each.getInput().getParameters().isEmpty() ? SQLExecuteType.LITERAL : SQLExecuteType.PLACEHOLDER;
result.add(new SQLRewriteEngineTestParameters(type, each.getId(), fileName, rootAssertions.getYamlRule(), each.getInput().getSql(),
createParameters(each.getInput().getParameters()), createOutputSQLs(each.getOutputs()), createOutputGroupedParameters(each.getOutputs()), databaseType));
createParameters(each.getInput().getParameters()), createOutputSQLs(each.getOutputs()), createOutputGroupedParameters(each.getOutputs()), databaseType, sqlExecuteType));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.test.it.rewrite.engine.type;

/**
* SQL execute type.
*/
public enum SQLExecuteType {

LITERAL, PLACEHOLDER
}
Loading