Skip to content

Commit

Permalink
Add more test cases on ShardingIndexTokenGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Nov 16, 2024
1 parent 1488b80 commit 8f4d15e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
@RequiredArgsConstructor
public final class ShardingCursorTokenGenerator implements OptionalSQLTokenGenerator<SQLStatementContext> {

private final ShardingRule shardingRule;
private final ShardingRule rule;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
Expand All @@ -43,6 +43,6 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)
@Override
public SQLToken generateSQLToken(final SQLStatementContext sqlStatementContext) {
CursorNameSegment cursorName = ((CursorAvailable) sqlStatementContext).getCursorName().orElseThrow(CursorNameNotFoundException::new);
return new CursorToken(cursorName.getStartIndex(), cursorName.getStopIndex(), cursorName.getIdentifier(), sqlStatementContext, shardingRule);
return new CursorToken(cursorName.getStartIndex(), cursorName.getStopIndex(), cursorName.getIdentifier(), sqlStatementContext, rule);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
@Setter
public final class ShardingIndexTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware {

private final ShardingRule shardingRule;
private final ShardingRule rule;

private Map<String, ShardingSphereSchema> schemas;

Expand All @@ -65,7 +65,7 @@ public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlState
if (sqlStatementContext instanceof IndexAvailable) {
for (IndexSegment each : ((IndexAvailable) sqlStatementContext).getIndexes()) {
ShardingSphereSchema schema = each.getOwner().isPresent() ? schemas.get(each.getOwner().get().getIdentifier().getValue()) : defaultSchema;
result.add(new IndexToken(each.getIndexName().getStartIndex(), each.getStopIndex(), each.getIndexName().getIdentifier(), sqlStatementContext, shardingRule, schema));
result.add(new IndexToken(each.getIndexName().getStartIndex(), each.getStopIndex(), each.getIndexName().getIdentifier(), sqlStatementContext, rule, schema));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
@Setter
public final class ShardingTableTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, RouteContextAware {

private final ShardingRule shardingRule;
private final ShardingRule rule;

private RouteContext routeContext;

Expand All @@ -53,9 +53,9 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

private boolean isAllBindingTables(final SQLStatementContext sqlStatementContext) {
Collection<String> shardingLogicTableNames = sqlStatementContext instanceof TableAvailable
? shardingRule.getShardingLogicTableNames(((TableAvailable) sqlStatementContext).getTablesContext().getTableNames())
? rule.getShardingLogicTableNames(((TableAvailable) sqlStatementContext).getTablesContext().getTableNames())
: Collections.emptyList();
return shardingLogicTableNames.size() > 1 && shardingRule.isAllBindingTables(shardingLogicTableNames);
return shardingLogicTableNames.size() > 1 && rule.isAllBindingTables(shardingLogicTableNames);
}

@Override
Expand All @@ -67,8 +67,8 @@ private Collection<SQLToken> generateSQLTokens(final TableAvailable sqlStatement
Collection<SQLToken> result = new LinkedList<>();
for (SimpleTableSegment each : sqlStatementContext.getTablesContext().getSimpleTables()) {
TableNameSegment tableName = each.getTableName();
if (shardingRule.findShardingTable(tableName.getIdentifier().getValue()).isPresent()) {
result.add(new ShardingTableToken(tableName.getStartIndex(), tableName.getStopIndex(), tableName.getIdentifier(), (SQLStatementContext) sqlStatementContext, shardingRule));
if (rule.findShardingTable(tableName.getIdentifier().getValue()).isPresent()) {
result.add(new ShardingTableToken(tableName.getStartIndex(), tableName.getStopIndex(), tableName.getIdentifier(), (SQLStatementContext) sqlStatementContext, rule));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public final class IndexToken extends SQLToken implements Substitutable, RouteUn

private final SQLStatementContext sqlStatementContext;

private final ShardingRule shardingRule;
private final ShardingRule rule;

private final ShardingSphereSchema schema;

Expand All @@ -54,7 +54,7 @@ public IndexToken(final int startIndex, final int stopIndex, final IdentifierVal
this.stopIndex = stopIndex;
this.identifier = identifier;
this.sqlStatementContext = sqlStatementContext;
this.shardingRule = shardingRule;
this.rule = shardingRule;
this.schema = schema;
}

Expand All @@ -70,10 +70,10 @@ private boolean isGeneratedIndex() {

private String getIndexValue(final RouteUnit routeUnit) {
Optional<String> logicTableName = findLogicTableNameFromMetaData(identifier.getValue());
if (logicTableName.isPresent() && !shardingRule.isShardingTable(logicTableName.get())) {
if (logicTableName.isPresent() && !rule.isShardingTable(logicTableName.get())) {
return identifier.getValue();
}
Map<String, String> logicAndActualTables = ShardingTokenUtils.getLogicAndActualTableMap(routeUnit, sqlStatementContext, shardingRule);
Map<String, String> logicAndActualTables = ShardingTokenUtils.getLogicAndActualTableMap(routeUnit, sqlStatementContext, rule);
String actualTableName = logicTableName.map(logicAndActualTables::get).orElseGet(() -> logicAndActualTables.isEmpty() ? null : logicAndActualTables.values().iterator().next());
return IndexMetaDataUtils.getActualIndexName(identifier.getValue(), actualTableName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,24 @@

package org.apache.shardingsphere.sharding.rewrite.token.generator.impl;

import org.apache.shardingsphere.infra.binder.context.statement.UnknownSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.ddl.AlterIndexStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sharding.rewrite.token.pojo.IndexToken;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
import org.mockito.internal.configuration.plugins.Plugins;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Optional;

import static org.hamcrest.CoreMatchers.is;
Expand All @@ -45,33 +48,71 @@
class ShardingIndexTokenGeneratorTest {

@Test
void assertIsGenerateSQLToken() {
UnknownSQLStatementContext sqlStatementContext = mock(UnknownSQLStatementContext.class);
void assertIsNotGenerateSQLTokenWithNotIndexAvailable() {
SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class);
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
assertFalse(generator.isGenerateSQLToken(sqlStatementContext));
}

@Test
void assertIsNotGenerateSQLTokenWithEmptyIndex() {
AlterIndexStatementContext sqlStatementContext = mock(AlterIndexStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getIndexes().isEmpty()).thenReturn(true);
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
assertFalse(generator.isGenerateSQLToken(sqlStatementContext));
AlterIndexStatementContext alterIndexStatementContext = mock(AlterIndexStatementContext.class);
Collection<IndexSegment> indexSegments = new LinkedList<>();
when(alterIndexStatementContext.getIndexes()).thenReturn(indexSegments);
assertFalse(generator.isGenerateSQLToken(alterIndexStatementContext));
indexSegments.add(mock(IndexSegment.class));
when(alterIndexStatementContext.getDatabaseType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"));
assertTrue(generator.isGenerateSQLToken(alterIndexStatementContext));
}

@Test
void assertGenerateSQLTokens() {
IndexSegment indexSegment = mock(IndexSegment.class, RETURNS_DEEP_STUBS);
when(indexSegment.getStartIndex()).thenReturn(1);
when(indexSegment.getStopIndex()).thenReturn(3);
when(indexSegment.getIndexName()).thenReturn(new IndexNameSegment(1, 3, mock(IdentifierValue.class)));
AlterIndexStatementContext alterIndexStatementContext = mock(AlterIndexStatementContext.class, RETURNS_DEEP_STUBS);
when(alterIndexStatementContext.getIndexes()).thenReturn(Collections.singleton(indexSegment));
when(alterIndexStatementContext.getDatabaseType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "FIXTURE"));
when(alterIndexStatementContext.getTablesContext().getSchemaName()).thenReturn(Optional.empty());
void assertIsGenerateSQLToken() {
AlterIndexStatementContext sqlStatementContext = mock(AlterIndexStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getDatabaseType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"));
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
assertTrue(generator.isGenerateSQLToken(sqlStatementContext));
}

@Test
void assertGenerateSQLTokensWithNotIndexAvailable() {
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
generator.setSchemas(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
Collection<SQLToken> actual = generator.generateSQLTokens(alterIndexStatementContext);
Collection<SQLToken> actual = generator.generateSQLTokens(mock(SQLStatementContext.class));
assertTrue(actual.isEmpty());
}

@Test
void assertGenerateSQLTokensWithSchemaOwner() throws ReflectiveOperationException {
IndexSegment indexSegment = new IndexSegment(1, 3, new IndexNameSegment(1, 3, mock(IdentifierValue.class)));
indexSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("foo_schema")));
AlterIndexStatementContext sqlStatementContext = mockAlterIndexStatementContext(indexSegment);
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
generator.setSchemas(Collections.singletonMap("foo_schema", schema));
Collection<SQLToken> actual = generator.generateSQLTokens(sqlStatementContext);
assertTokens(actual, schema);
}

@Test
void assertGenerateSQLTokensWithoutSchemaOwner() throws ReflectiveOperationException {
IndexSegment indexSegment = new IndexSegment(1, 3, new IndexNameSegment(1, 3, mock(IdentifierValue.class)));
AlterIndexStatementContext sqlStatementContext = mockAlterIndexStatementContext(indexSegment);
ShardingIndexTokenGenerator generator = new ShardingIndexTokenGenerator(mock(ShardingRule.class));
ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
generator.setDefaultSchema(schema);
Collection<SQLToken> actual = generator.generateSQLTokens(sqlStatementContext);
assertTokens(actual, schema);
}

private AlterIndexStatementContext mockAlterIndexStatementContext(final IndexSegment indexSegment) {
AlterIndexStatementContext result = mock(AlterIndexStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getIndexes()).thenReturn(Collections.singleton(indexSegment));
when(result.getDatabaseType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "FIXTURE"));
when(result.getTablesContext().getSchemaName()).thenReturn(Optional.empty());
return result;
}

private void assertTokens(final Collection<SQLToken> actual, final ShardingSphereSchema schema) throws ReflectiveOperationException {
assertThat(actual.size(), is(1));
assertThat((new LinkedList<>(actual)).get(0).getStartIndex(), is(1));
IndexToken actualToken = (IndexToken) new ArrayList<>(actual).get(0);
assertThat(actualToken.getStartIndex(), is(1));
assertThat(actualToken.getStopIndex(), is(3));
assertThat(schema, is((ShardingSphereSchema) Plugins.getMemberAccessor().get(IndexToken.class.getDeclaredField("schema"), actualToken)));
}
}

0 comments on commit 8f4d15e

Please sign in to comment.