Skip to content

Commit

Permalink
[CALCITE-5409] Implement BatchNestedLoopJoin for JDBC
Browse files Browse the repository at this point in the history
  • Loading branch information
kramerul committed Dec 5, 2023
1 parent bd7d4e8 commit 499d4b0
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.jdbc;

import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.QueryProvider;
import org.apache.calcite.schema.SchemaPlus;

import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * A {@link DataContext} that supplies the values of correlation variables
 * for batch nested loop joins and forwards everything else to a delegate.
 *
 * <p>A name of the form {@code "?N"} with
 * {@code OFFSET <= N < OFFSET + parameters.length} is answered with element
 * {@code N - OFFSET} of the parameter array. Every other name, as well as the
 * root schema, type factory and query provider, comes from the delegate.
 */
public class JdbcCorrelationDataContext implements DataContext {
  /** Number carried by the name ({@code "?" + OFFSET}) of the first correlation parameter. */
  public static final int OFFSET = 10000;

  private final DataContext delegate;
  private final Object[] parameters;

  /**
   * Creates a JdbcCorrelationDataContext.
   *
   * @param delegate context that answers every request not handled here
   * @param parameters correlation values; element {@code i} is returned for
   *     the name {@code "?" + (OFFSET + i)}
   */
  public JdbcCorrelationDataContext(DataContext delegate, Object[] parameters) {
    this.delegate = delegate;
    this.parameters = parameters;
  }

  @Override public @Nullable SchemaPlus getRootSchema() {
    return delegate.getRootSchema();
  }

  @Override public JavaTypeFactory getTypeFactory() {
    return delegate.getTypeFactory();
  }

  @Override public QueryProvider getQueryProvider() {
    return delegate.getQueryProvider();
  }

  /**
   * Returns the correlation value for a name in the correlation range, or
   * the delegate's answer for any other name.
   *
   * @param name variable name; {@code "?N"} denotes the parameter numbered N
   * @return the value bound to {@code name}, possibly null
   */
  @Override public @Nullable Object get(String name) {
    if (name.startsWith("?")) {
      final int index = parameterIndex(name);
      if (index >= OFFSET && index < OFFSET + parameters.length) {
        return parameters[index - OFFSET];
      }
    }
    return delegate.get(name);
  }

  /**
   * Parses the number that follows the leading {@code '?'}.
   *
   * @return the parsed number, or -1 if the suffix is not a valid integer
   */
  private static int parameterIndex(String name) {
    try {
      return Integer.parseInt(name.substring(1));
    } catch (NumberFormatException ignored) {
      // Not a numbered parameter (for example "?" or "?foo"). -1 is below
      // OFFSET, so the caller falls through to the delegate instead of failing.
      return -1;
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.jdbc;

import org.apache.calcite.rel.core.CorrelationId;

import java.lang.reflect.Type;

/**
* Receives the correlation-variable fields that a generated JDBC statement
* refers to through dynamic parameters, and decides which index each of
* those parameters gets.
*/
public interface JdbcCorrelationDataContextBuilder {
/**
* Registers one field of a correlation variable.
*
* @param id identifier of the correlation variable
* @param ordinal ordinal of the field within the variable's row type
* @param type Java type of the field
* @return index to use for the dynamic parameter that stands for the field
*/
int add(CorrelationId id, int ordinal, Type type);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.jdbc;

import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.core.CorrelationId;

import com.google.common.collect.ImmutableList;

import java.lang.reflect.Constructor;
import java.lang.reflect.Type;

/**
 * Implementation of {@link JdbcCorrelationDataContextBuilder} that gathers
 * the expressions reading correlation-variable fields and generates the
 * expression that creates a {@link JdbcCorrelationDataContext} from them.
 */
public class JdbcCorrelationDataContextBuilderImpl implements JdbcCorrelationDataContextBuilder {
  /** Constructor {@code JdbcCorrelationDataContext(DataContext, Object[])}. */
  private static final Constructor CONTEXT_CONSTRUCTOR =
      Types.lookupConstructor(JdbcCorrelationDataContext.class, DataContext.class, Object[].class);

  private final EnumerableRelImplementor implementor;
  private final BlockBuilder builder;
  private final Expression dataContext;

  /** Expressions for the registered fields, in registration order. */
  private final ImmutableList.Builder<Expression> fieldReaders = new ImmutableList.Builder<>();

  /** Index handed out by the next call to {@link #add}. */
  private int nextIndex = JdbcCorrelationDataContext.OFFSET;

  /**
   * Creates a builder.
   *
   * @param implementor source of the getters for correlation variables
   * @param builder block passed to the getter when a field expression is created
   * @param dataContext expression for the data context to be wrapped
   */
  public JdbcCorrelationDataContextBuilderImpl(EnumerableRelImplementor implementor,
      BlockBuilder builder, Expression dataContext) {
    this.implementor = implementor;
    this.builder = builder;
    this.dataContext = dataContext;
  }

  /**
   * Records the expression that reads field {@code ordinal} of correlation
   * variable {@code id} and returns the index assigned to it. Indexes start
   * at {@link JdbcCorrelationDataContext#OFFSET} and grow by one per call.
   */
  @Override public int add(CorrelationId id, int ordinal, Type type) {
    final Expression fieldReader =
        implementor.getCorrelVariableGetter(id.getName()).field(builder, ordinal, type);
    fieldReaders.add(fieldReader);
    final int assigned = nextIndex;
    nextIndex++;
    return assigned;
  }

  /**
   * Returns an expression that constructs a {@link JdbcCorrelationDataContext}
   * from the wrapped data context and an {@code Object} array of the field
   * expressions recorded so far.
   */
  public Expression build() {
    final Expression parameterArray =
        Expressions.newArrayInit(Object.class, 1, fieldReaders.build());
    return Expressions.new_(CONTEXT_CONSTRUCTOR, dataContext, parameterArray);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,64 @@

import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.util.Util;
import org.apache.calcite.sql.SqlDynamicParam;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;

import java.lang.reflect.Type;
import java.util.List;

/**
* State for generating a SQL statement.
*/
public class JdbcImplementor extends RelToSqlConverter {
public JdbcImplementor(SqlDialect dialect, JavaTypeFactory typeFactory) {

private final JdbcCorrelationDataContextBuilder dataContextBuilder;
private final JavaTypeFactory typeFactory;

/**
 * Creates a JdbcImplementor that reports correlation fields to a builder.
 *
 * @param dialect SQL dialect passed to the superclass
 * @param typeFactory type factory used to find the Java class of each
 *     correlation field
 * @param dataContextBuilder receives every correlation field rendered as a
 *     dynamic parameter and supplies the parameter's index
 */
public JdbcImplementor(SqlDialect dialect, JavaTypeFactory typeFactory,
    JdbcCorrelationDataContextBuilder dataContextBuilder) {
  super(dialect);
  // The factory is stored and used, so the former Util.discard(typeFactory)
  // call no longer applies and has been dropped.
  this.typeFactory = typeFactory;
  this.dataContextBuilder = dataContextBuilder;
}

/**
 * Creates a JdbcImplementor with a builder that records nothing: it only
 * numbers the requested dynamic parameters 1, 2, 3, ... in call order.
 */
public JdbcImplementor(SqlDialect dialect, JavaTypeFactory typeFactory) {
  this(dialect, typeFactory, new JdbcCorrelationDataContextBuilder() {
    /** Index returned by the next call to {@code add}. */
    private int nextIndex = 1;

    @Override public int add(CorrelationId id, int ordinal, Type type) {
      final int assigned = nextIndex;
      nextIndex++;
      return assigned;
    }
  });
}

/** Dispatches {@code node} and returns the outcome. */
public Result implement(RelNode node) {
  final Result result = dispatch(node);
  return result;
}

/**
 * Returns the context used to resolve fields of a correlation variable.
 *
 * <p>A variable that has an entry in {@code correlTableMap} uses that entry.
 * Any other variable gets a context that renders each field as a
 * {@link SqlDynamicParam}; the parameter's index comes from registering the
 * field with the data context builder.
 */
@Override protected Context getAliasContext(RexCorrelVariable variable) {
  final Context registered = correlTableMap.get(variable.id);
  if (registered != null) {
    return registered;
  }
  final List<RelDataTypeField> correlFields = variable.getType().getFieldList();
  return new Context(dialect, correlFields.size()) {
    @Override public SqlImplementor implementor() {
      return JdbcImplementor.this;
    }

    @Override public SqlNode field(int ordinal) {
      final RelDataTypeField correlField = correlFields.get(ordinal);
      final Type javaType = typeFactory.getJavaClass(correlField.getType());
      // Every request registers the field again and so yields a new index.
      final int paramIndex = dataContextBuilder.add(variable.id, ordinal, javaType);
      return new SqlDynamicParam(paramIndex, SqlParserPos.ZERO);
    }
  };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ protected JdbcToEnumerableConverter(
final JdbcConvention jdbcConvention =
(JdbcConvention) requireNonNull(child.getConvention(),
() -> "child.getConvention() is null for " + child);
SqlString sqlString = generateSql(jdbcConvention.dialect);
JdbcCorrelationDataContextBuilderImpl dataContextBuilder =
new JdbcCorrelationDataContextBuilderImpl(implementor, builder0, DataContext.ROOT);
SqlString sqlString = generateSql(jdbcConvention.dialect, dataContextBuilder);
String sql = sqlString.getSql();
if (CalciteSystemProperty.DEBUG.value()) {
System.out.println("[" + sql + "]");
Expand Down Expand Up @@ -179,7 +181,7 @@ protected JdbcToEnumerableConverter(
Expressions.call(BuiltInMethod.CREATE_ENRICHER.method,
Expressions.newArrayInit(Integer.class, 1,
toIndexesTableExpression(sqlString)),
DataContext.ROOT));
dataContextBuilder.build()));

enumerable =
builder0.append("enumerable",
Expand Down Expand Up @@ -356,10 +358,11 @@ private static String jdbcGetMethod(@Nullable Primitive primitive) {
: "get" + SqlFunctions.initcap(castNonNull(primitive.primitiveName));
}

private SqlString generateSql(SqlDialect dialect) {
private SqlString generateSql(SqlDialect dialect,
JdbcCorrelationDataContextBuilder dataContextBuilder) {
final JdbcImplementor jdbcImplementor =
new JdbcImplementor(dialect,
(JavaTypeFactory) getCluster().getTypeFactory());
(JavaTypeFactory) getCluster().getTypeFactory(), dataContextBuilder);
final JdbcImplementor.Result result =
jdbcImplementor.visitRoot(this.getInput());
return result.asStatement().toSqlString(dialect);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,12 @@ public SqlNode toSql(@Nullable RexProgram program, RexNode rex) {
final Context correlAliasContext = getAliasContext(variable);
final RexFieldAccess lastAccess = accesses.pollLast();
assert lastAccess != null;
sqlIdentifier = (SqlIdentifier) correlAliasContext
SqlNode node = correlAliasContext
.field(lastAccess.getField().getIndex());
if (node instanceof SqlDynamicParam) {
return node;
}
sqlIdentifier = (SqlIdentifier) node;
break;
case ROW:
case ITEM:
Expand Down Expand Up @@ -1478,6 +1482,12 @@ public static SqlNode toSql(RexLiteral literal) {
}
}

/**
 * Returns the context stored in {@code correlTableMap} for a correlation
 * variable. A missing entry is rejected by {@code requireNonNull} with the
 * message "variable <id> is not found".
 */
protected Context getAliasContext(RexCorrelVariable variable) {
  final Context context = correlTableMap.get(variable.id);
  return requireNonNull(context,
      () -> "variable " + variable.id + " is not found");
}

/** Simple implementation of {@link Context} that cannot handle sub-queries
* or correlations. Because it is so simple, you do not need to create a
* {@link SqlImplementor} or {@link org.apache.calcite.tools.RelBuilder}
Expand Down Expand Up @@ -1507,9 +1517,7 @@ protected abstract class BaseContext extends Context {
}

/**
 * Resolves a correlation variable through the enclosing implementor's
 * {@code getAliasContext}, so that an override of that method in a subclass
 * of {@link SqlImplementor} also takes effect here.
 */
@Override protected Context getAliasContext(RexCorrelVariable variable) {
  return SqlImplementor.this.getAliasContext(variable);
}

@Override public SqlImplementor implementor() {
Expand Down
52 changes: 52 additions & 0 deletions core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
*/
package org.apache.calcite.test;

import org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.calcite.adapter.java.ReflectiveSchema;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.config.Lex;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.test.CalciteAssert.AssertThat;
import org.apache.calcite.test.CalciteAssert.DatabaseInstance;
import org.apache.calcite.test.schemata.foodmart.FoodmartSchema;
import org.apache.calcite.test.schemata.hr.HrSchema;
import org.apache.calcite.util.Smalls;
import org.apache.calcite.util.TestUtil;

Expand All @@ -35,6 +40,7 @@
import java.util.Properties;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
Expand Down Expand Up @@ -1147,6 +1153,52 @@ private LockWrapper exclusiveCleanDb(Connection c) throws SQLException {
});
}

/**
* Checks a left outer join between the reflective schema "s" and the JDBC
* "foodmart" schema once the batch nested loop join rule has been added to
* the planner: the plan must contain a JdbcFilter over correlation
* variables, and the SQL for the JDBC side must carry the correlated values
* as "?" parameters.
*/
@Test void testBatchNestedLoopJoinPlan() {
final String sql = "SELECT *\n"
+ "FROM \"s\".\"emps\" A\n"
+ "LEFT OUTER JOIN \"foodmart\".\"store\" B ON A.\"empid\" = B.\"store_id\"";
// Fragment looked for in the plan (not a complete plan line).
final String explain = "JdbcFilter(condition=[OR(=($cor0.empid0, $0), =($cor1.empid0, $0)";
// SQL checked by planHasSql below; it must match exactly, including the
// nesting of the OR groups.
final String jdbcSql = "SELECT *\n"
+ "FROM \"foodmart\".\"store\"\n"
+ "WHERE ? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR"
+ " (? = \"store_id\" OR ? = \"store_id\")) OR (? = \"store_id\" OR (? = \"store_id\" OR ? "
+ "= \"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\"))) OR (? ="
+ " \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\")) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\") OR (? = \"store_id\" OR ? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\")))) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\")) OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\"))) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\")) OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR ? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\"))))) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\")) OR (? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\"))) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\")) OR (? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR ? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\")))) OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\")) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = "
+ "\"store_id\" OR (? = \"store_id\" OR ? = \"store_id\"))) OR (? = \"store_id\" OR (? = "
+ "\"store_id\" OR ? = \"store_id\") OR (? = \"store_id\" OR (? = \"store_id\" OR ? = "
+ "\"store_id\")) OR (? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\") OR (? = "
+ "\"store_id\" OR ? = \"store_id\" OR (? = \"store_id\" OR ? = \"store_id\"))))))";
CalciteAssert.model(FoodmartSchema.FOODMART_MODEL)
.withSchema("s", new ReflectiveSchema(new HrSchema()))
// The rule is added through the planner hook for this query.
.withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
planner.addRule(EnumerableRules.ENUMERABLE_BATCH_NESTED_LOOP_JOIN_RULE);
})
.query(sql)
.explainContains(explain)
.runs()
// NOTE(review): enable(...) appears to restrict the two checks that follow
// to HSQLDB and PostgreSQL; confirm against CalciteAssert.
.enable(CalciteAssert.DB == CalciteAssert.DatabaseInstance.HSQLDB
|| CalciteAssert.DB == DatabaseInstance.POSTGRESQL)
.planHasSql(jdbcSql)
.returnsCount(4);
}

/** Acquires a lock, and releases it when closed. */
static class LockWrapper implements AutoCloseable {
private final Lock lock;
Expand Down

0 comments on commit 499d4b0

Please sign in to comment.