Skip to content

Commit

Permalink
Refactor EncryptPredicateColumnTokenGenerator (#32376)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Aug 2, 2024
1 parent d42197b commit cf29e95
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import lombok.Setter;
import org.apache.shardingsphere.encrypt.exception.syntax.UnsupportedEncryptSQLException;
import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseTypeAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
import org.apache.shardingsphere.encrypt.rewrite.token.comparator.JoinConditionsEncryptorComparator;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
Expand Down Expand Up @@ -61,61 +60,56 @@
*/
@HighFrequencyInvocation
@Setter
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware, EncryptRuleAware, DatabaseTypeAware {
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware, EncryptRuleAware {

private String databaseName;

private Map<String, ShardingSphereSchema> schemas;

private EncryptRule encryptRule;

private DatabaseType databaseType;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
}

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Collection<ColumnSegment> columnSegments = Collections.emptyList();
Collection<WhereSegment> whereSegments = Collections.emptyList();
Collection<BinaryOperationExpression> joinConditions = Collections.emptyList();
if (sqlStatementContext instanceof WhereAvailable) {
columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
joinConditions = ((WhereAvailable) sqlStatementContext).getJoinConditions();
}
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, encryptRule),
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(((WhereAvailable) sqlStatementContext).getJoinConditions(), encryptRule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
String defaultSchema = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> schemas.get(defaultSchema));
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType());
}

private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Map<String, String> columnExpressionTableNames, final Collection<WhereSegment> whereSegments) {
private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments,
final Map<String, String> columnExpressionTableNames, final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
Collection<SQLToken> result = new LinkedHashSet<>(columnSegments.size(), 1F);
for (ColumnSegment each : columnSegments) {
String tableName = Optional.ofNullable(columnExpressionTableNames.get(each.getExpression())).orElse("");
Optional<EncryptTable> encryptTable = encryptRule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
result.add(buildSubstitutableColumnNameToken(encryptTable.get().getEncryptColumn(each.getIdentifier().getValue()), each, whereSegments));
result.add(buildSubstitutableColumnNameToken(encryptTable.get().getEncryptColumn(each.getIdentifier().getValue()), each, whereSegments, databaseType));
}
}
return result;
}

private SubstitutableColumnNameToken buildSubstitutableColumnNameToken(final EncryptColumn encryptColumn, final ColumnSegment columnSegment, final Collection<WhereSegment> whereSegments) {
private SubstitutableColumnNameToken buildSubstitutableColumnNameToken(final EncryptColumn encryptColumn,
final ColumnSegment columnSegment, final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
int startIndex = columnSegment.getOwner().isPresent() ? columnSegment.getOwner().get().getStopIndex() + 2 : columnSegment.getStartIndex();
int stopIndex = columnSegment.getStopIndex();
if (includesLike(whereSegments, columnSegment)) {
LikeQueryColumnItem likeQueryColumnItem = encryptColumn.getLikeQuery().orElseThrow(() -> new UnsupportedEncryptSQLException("LIKE"));
return new SubstitutableColumnNameToken(startIndex, stopIndex, createColumnProjections(likeQueryColumnItem.getName(), columnSegment.getIdentifier().getQuoteCharacter()), databaseType);
return new SubstitutableColumnNameToken(
startIndex, stopIndex, createColumnProjections(likeQueryColumnItem.getName(), columnSegment.getIdentifier().getQuoteCharacter(), databaseType), databaseType);
}
Collection<Projection> columnProjections =
encryptColumn.getAssistedQuery().map(optional -> createColumnProjections(optional.getName(), columnSegment.getIdentifier().getQuoteCharacter()))
.orElseGet(() -> createColumnProjections(encryptColumn.getCipher().getName(), columnSegment.getIdentifier().getQuoteCharacter()));
encryptColumn.getAssistedQuery().map(optional -> createColumnProjections(optional.getName(), columnSegment.getIdentifier().getQuoteCharacter(), databaseType))
.orElseGet(() -> createColumnProjections(encryptColumn.getCipher().getName(), columnSegment.getIdentifier().getQuoteCharacter(), databaseType));
return new SubstitutableColumnNameToken(startIndex, stopIndex, columnProjections, databaseType);
}

Expand Down Expand Up @@ -145,7 +139,7 @@ private boolean isSameColumnSegment(final ExpressionSegment columnSegment, final
return columnSegment instanceof ColumnSegment && columnSegment.getStartIndex() == targetColumnSegment.getStartIndex() && columnSegment.getStopIndex() == targetColumnSegment.getStopIndex();
}

private Collection<Projection> createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter) {
private Collection<Projection> createColumnProjections(final String columnName, final QuoteCharacter quoteCharacter, final DatabaseType databaseType) {
return Collections.singleton(new ColumnProjection(null, new IdentifierValue(columnName, quoteCharacter), null, databaseType));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import org.apache.shardingsphere.encrypt.config.rule.EncryptTableRuleConfiguration;
import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptInsertValuesToken;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.algorithm.core.config.AlgorithmConfiguration;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.UpdateStatementContext;
import org.apache.shardingsphere.infra.algorithm.core.config.AlgorithmConfiguration;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
Expand All @@ -56,6 +56,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLInsertStatement;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLSelectStatement;
Expand Down Expand Up @@ -164,7 +165,7 @@ private static InsertStatement createInsertSelectStatement(final boolean contain
* @return created update statement context
*/
public static UpdateStatementContext createUpdateStatementContext() {
MySQLUpdateStatement updateStatement = new MySQLUpdateStatement();
UpdateStatement updateStatement = new MySQLUpdateStatement();
updateStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_user"))));
updateStatement.setWhere(createWhereSegment());
updateStatement.setSetAssignment(createSetAssignmentSegment());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@

import org.apache.shardingsphere.encrypt.rewrite.token.generator.fixture.EncryptGeneratorFixtureBuilder;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -55,7 +53,6 @@ void assertIsGenerateSQLToken() {
void assertGenerateSQLTokenFromGenerateNewSQLToken() {
generator.setDatabaseName(DefaultDatabase.LOGIC_NAME);
generator.setSchemas(Collections.emptyMap());
generator.setDatabaseType(TypedSPILoader.getService(DatabaseType.class, "FIXTURE"));
Collection<SQLToken> substitutableColumnNameTokens = generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createUpdateStatementContext());
assertThat(substitutableColumnNameTokens.size(), is(1));
assertThat(((SubstitutableColumnNameToken) substitutableColumnNameTokens.iterator().next()).toString(null), is("pwd_assist"));
Expand Down

0 comments on commit cf29e95

Please sign in to comment.