Skip to content

Commit

Permalink
Support encrypt insert select sql rewrite (#28425)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Sep 13, 2023
1 parent 6c04da1 commit 05d5a9b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.encrypt.rewrite.token.generator;

import com.google.common.base.Preconditions;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseTypeAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
Expand All @@ -30,6 +29,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ShorthandProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.DialectDatabaseMetaData;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
Expand Down Expand Up @@ -69,21 +69,28 @@ public final class EncryptProjectionTokenGenerator implements CollectionSQLToken

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext && !((SelectStatementContext) sqlStatementContext).getAllTables().isEmpty();
return sqlStatementContext instanceof SelectStatementContext && !((SelectStatementContext) sqlStatementContext).getAllTables().isEmpty()
|| sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext();
}

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Preconditions.checkState(sqlStatementContext instanceof SelectStatementContext);
Collection<SQLToken> result = new LinkedHashSet<>();
SelectStatementContext selectStatementContext = (SelectStatementContext) sqlStatementContext;
addGenerateSQLTokens(result, selectStatementContext);
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
addGenerateSQLTokens(result, each);
if (sqlStatementContext instanceof SelectStatementContext) {
generateSQLTokens((SelectStatementContext) sqlStatementContext, result);
} else if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
generateSQLTokens(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext(), result);
}
return result;
}

private void generateSQLTokens(final SelectStatementContext selectStatementContext, final Collection<SQLToken> sqlTokens) {
addGenerateSQLTokens(sqlTokens, selectStatementContext);
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
addGenerateSQLTokens(sqlTokens, each);
}
}

private void addGenerateSQLTokens(final Collection<SQLToken> sqlTokens, final SelectStatementContext selectStatementContext) {
for (ProjectionSegment each : selectStatementContext.getSqlStatement().getProjections().getProjections()) {
SubqueryType subqueryType = selectStatementContext.getSubqueryType();
Expand Down Expand Up @@ -153,8 +160,11 @@ private Collection<Projection> generateProjections(final EncryptColumn encryptCo
return generateProjectionsInTableSegmentSubquery(encryptColumn, columnProjection, shorthandProjection, subqueryType);
} else if (SubqueryType.PREDICATE_SUBQUERY == subqueryType) {
return Collections.singleton(generateProjectionInPredicateSubquery(encryptColumn, columnProjection, shorthandProjection));
} else if (SubqueryType.INSERT_SELECT_SUBQUERY == subqueryType) {
return generateProjectionsInInsertSelectSubquery(encryptColumn, columnProjection, shorthandProjection);
}
throw new UnsupportedSQLOperationException("Projections not in simple select, table subquery, join subquery and predicate subquery are not supported in encrypt feature.");
throw new UnsupportedSQLOperationException(
"Projections not in simple select, table subquery, join subquery, predicate subquery and insert select subquery are not supported in encrypt feature.");
}

private ColumnProjection generateProjection(final EncryptColumn encryptColumn, final ColumnProjection columnProjection, final boolean shorthandProjection) {
Expand Down Expand Up @@ -186,6 +196,18 @@ private ColumnProjection generateProjectionInPredicateSubquery(final EncryptColu
databaseType));
}

private Collection<Projection> generateProjectionsInInsertSelectSubquery(final EncryptColumn encryptColumn, final ColumnProjection columnProjection, final boolean shorthandProjection) {
QuoteCharacter quoteCharacter = columnProjection.getName().getQuoteCharacter();
IdentifierValue columnName = new IdentifierValue(encryptColumn.getCipher().getName(), quoteCharacter);
Collection<Projection> result = new LinkedList<>();
IdentifierValue encryptColumnOwner = shorthandProjection ? columnProjection.getOwner().orElse(null) : null;
result.add(new ColumnProjection(encryptColumnOwner, columnName, null, databaseType));
IdentifierValue assistedColumOwner = columnProjection.getOwner().orElse(null);
encryptColumn.getAssistedQuery().ifPresent(optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType)));
encryptColumn.getLikeQuery().ifPresent(optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType)));
return result;
}

private ShorthandProjection getShorthandProjection(final ShorthandProjectionSegment segment, final ProjectionsContext projectionsContext) {
Optional<String> owner = segment.getOwner().isPresent() ? Optional.of(segment.getOwner().get().getIdentifier().getValue()) : Optional.empty();
for (Projection each : projectionsContext.getProjections()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.shardingsphere.infra.binder.context.statement.dml;

import lombok.Getter;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
import org.apache.shardingsphere.infra.binder.context.aware.ParameterAware;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.engine.GeneratedKeyContextEngine;
Expand All @@ -30,10 +28,13 @@
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.sql.parser.sql.common.enums.SubqueryType;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
Expand Down Expand Up @@ -129,6 +130,7 @@ private Optional<InsertSelectContext> getInsertSelectContext(final ShardingSpher
}
SubquerySegment insertSelectSegment = getSqlStatement().getInsertSelect().get();
SelectStatementContext selectStatementContext = new SelectStatementContext(metaData, params, insertSelectSegment.getSelect(), defaultDatabaseName);
selectStatementContext.setSubqueryType(SubqueryType.INSERT_SELECT_SUBQUERY);
InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get());
paramsOffset.addAndGet(insertSelectContext.getParameterCount());
return Optional.of(insertSelectContext);
Expand Down

0 comments on commit 05d5a9b

Please sign in to comment.