diff --git a/features/broadcast/core/src/main/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactory.java b/features/broadcast/core/src/main/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactory.java index 3c4559a33517c..6237e1e0b4cd9 100644 --- a/features/broadcast/core/src/main/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactory.java +++ b/features/broadcast/core/src/main/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactory.java @@ -40,7 +40,6 @@ import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.DDLStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.tcl.TCLStatement; -import org.apache.shardingsphere.sql.parser.statement.mysql.dal.MySQLUseStatement; import java.util.Collection; @@ -73,13 +72,17 @@ public static BroadcastRouteEngine newInstance(final BroadcastRule rule, final S if (!(sqlStatementContext instanceof TableAvailable)) { return new BroadcastIgnoreRouteEngine(); } + Collection tableNames = ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames(); + if (tableNames.isEmpty()) { + return new BroadcastIgnoreRouteEngine(); + } if (sqlStatement instanceof DALStatement) { - return getDALRouteEngine(rule, sqlStatementContext); + return getDALRouteEngine(rule, tableNames); } if (sqlStatement instanceof DCLStatement) { - return getDCLRouteEngine(rule, sqlStatementContext); + return getDCLRouteEngine(rule, tableNames); } - return getDMLRouteEngine(rule, sqlStatementContext, queryContext.getConnectionContext()); + return getDMLRouteEngine(rule, sqlStatementContext, queryContext.getConnectionContext(), tableNames); } private static BroadcastRouteEngine getCursorRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext, final ConnectionContext connectionContext) { @@ -98,23 +101,17 @@ private static BroadcastRouteEngine getDDLRouteEngine(final BroadcastRule rule, return rule.isAllBroadcastTables(tableNames) ? new BroadcastTableBroadcastRouteEngine(tableNames) : new BroadcastIgnoreRouteEngine(); } - private static BroadcastRouteEngine getDALRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext) { - SQLStatement sqlStatement = sqlStatementContext.getSqlStatement(); - if (sqlStatement instanceof MySQLUseStatement) { - return new BroadcastIgnoreRouteEngine(); - } - Collection tableNames = ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames(); - return tableNames.isEmpty() ? new BroadcastIgnoreRouteEngine() : new BroadcastTableBroadcastRouteEngine(rule.filterBroadcastTableNames(tableNames)); + private static BroadcastRouteEngine getDALRouteEngine(final BroadcastRule rule, final Collection tableNames) { + return new BroadcastTableBroadcastRouteEngine(rule.filterBroadcastTableNames(tableNames)); } - private static BroadcastRouteEngine getDCLRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext) { - Collection tableNames = ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames(); + private static BroadcastRouteEngine getDCLRouteEngine(final BroadcastRule rule, final Collection tableNames) { Collection broadcastTableNames = rule.filterBroadcastTableNames(tableNames); return broadcastTableNames.isEmpty() ? new BroadcastIgnoreRouteEngine() : new BroadcastTableBroadcastRouteEngine(broadcastTableNames); } - private static BroadcastRouteEngine getDMLRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext, final ConnectionContext connectionContext) { - Collection tableNames = ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames(); + private static BroadcastRouteEngine getDMLRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext, + final ConnectionContext connectionContext, final Collection tableNames) { if (rule.isAllBroadcastTables(tableNames)) { return sqlStatementContext.getSqlStatement() instanceof SelectStatement ? new BroadcastUnicastRouteEngine(sqlStatementContext, tableNames, connectionContext) diff --git a/features/broadcast/core/src/test/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactoryTest.java b/features/broadcast/core/src/test/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactoryTest.java index 9f125ba33915e..4a302307909ce 100644 --- a/features/broadcast/core/src/test/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactoryTest.java +++ b/features/broadcast/core/src/test/java/org/apache/shardingsphere/broadcast/route/engine/BroadcastRouteEngineFactoryTest.java @@ -34,6 +34,7 @@ import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment; +import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.dal.DALStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.dcl.DCLStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.DDLStatement; @@ -139,22 +140,22 @@ void assertNewInstanceWithDDLStatementAndIsNotAllBroadcastTables() { } @Test - void assertNewInstanceWithDALStatementAndNotTableAvailable() { - when(queryContext.getSqlStatementContext().getSqlStatement()).thenReturn(mock(DALStatement.class)); + void assertNewInstanceWithoutTableAvailableStatement() { + when(queryContext.getSqlStatementContext().getSqlStatement()).thenReturn(mock(SQLStatement.class)); assertThat(BroadcastRouteEngineFactory.newInstance(rule, database, queryContext), instanceOf(BroadcastIgnoreRouteEngine.class)); } @Test - void assertNewInstanceWithDALStatementAndEmptyTables() { + void assertNewInstanceWithEmptyTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(new TablesContext(Collections.emptyList(), databaseType, null)); - when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DALStatement.class)); + when(sqlStatementContext.getSqlStatement()).thenReturn(mock(SQLStatement.class)); when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); assertThat(BroadcastRouteEngineFactory.newInstance(rule, database, queryContext), instanceOf(BroadcastIgnoreRouteEngine.class)); } @Test - void assertNewInstanceWithDALStatementAndTables() { + void assertNewInstanceWithDALStatement() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), databaseType, null); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext); @@ -164,16 +165,17 @@ void assertNewInstanceWithDALStatementAndTables() { } @Test - void assertNewInstanceWithDCLStatementAndEmptyTables() { + void assertNewInstanceWithDCLStatementWithoutBroadcastTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); - when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(new TablesContext(Collections.emptyList(), databaseType, null)); + TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("bar_tbl")))), databaseType, null); + when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext); when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DCLStatement.class)); when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); assertThat(BroadcastRouteEngineFactory.newInstance(rule, database, queryContext), instanceOf(BroadcastIgnoreRouteEngine.class)); } @Test - void assertNewInstanceWithDCLStatementAndTables() { + void assertNewInstanceWithDCLStatementWithBroadcastTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), databaseType, null); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext); @@ -185,7 +187,7 @@ void assertNewInstanceWithDCLStatementAndTables() { @Test void assertNewInstanceWithDMLStatementAndIsNotAllBroadcastTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); - TablesContext tablesContext = new TablesContext(Collections.emptyList(), databaseType, null); + TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("bar_tbl")))), databaseType, null); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext); when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DMLStatement.class)); when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext); @@ -193,7 +195,7 @@ void assertNewInstanceWithDMLStatementAndIsNotAllBroadcastTables() { } @Test - void assertNewInstanceWithDMLSelectStatementAndIsAllBroadcastTables() { + void assertNewInstanceWithSelectStatementAndIsAllBroadcastTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), databaseType, null); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext); @@ -203,7 +205,7 @@ void assertNewInstanceWithDMLSelectStatementAndIsAllBroadcastTables() { } @Test - void assertNewInstanceWithDMLUpdateStatementAndIsAllBroadcastTables() { + void assertNewInstanceWithUpdateStatementAndIsAllBroadcastTables() { SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, withSettings().extraInterfaces(TableAvailable.class)); TablesContext tablesContext = new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), databaseType, null); when(((TableAvailable) sqlStatementContext).getTablesContext()).thenReturn(tablesContext);