Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching SQLServerBulkCopy object for batch insert #2435

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ private void setPreparedStatementHandle(int handle) {
*/
private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b");

/**
* For caching data related to batch insert with bulkcopy
*/
private SQLServerBulkCopy bcOperation = null;
private String bcOperationTableName = null;
private ArrayList<String> bcOperationColumnList = null;
private ArrayList<String> bcOperationValueList = null;

/** Returns the prepared statement SQL */
@Override
public String toString() {
Expand Down Expand Up @@ -392,6 +400,10 @@ final void closeInternal() {

// Clean up client-side state
batchParamValues = null;

if (null != bcOperation) {
bcOperation.close();
}
}

/**
Expand Down Expand Up @@ -2287,9 +2299,17 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
}
}

String tableName = parseUserSQLForTableNameDW(false, false, false, false);
ArrayList<String> columnList = parseUserSQLForColumnListDW();
ArrayList<String> valueList = parseUserSQLForValueListDW(false);
if (null == bcOperationTableName) {
bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false);
}

if (null == bcOperationColumnList) {
bcOperationColumnList = parseUserSQLForColumnListDW();
}

if (null == bcOperationValueList) {
bcOperationValueList = parseUserSQLForValueListDW(false);
}

checkAdditionalQuery();

Expand All @@ -2298,28 +2318,28 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
stmtColumnEncriptionSetting);
SQLServerResultSet rs = stmt
.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM "
+ Util.escapeSingleQuotes(tableName) + " '")) {
+ Util.escapeSingleQuotes(bcOperationTableName) + " '")) {
Map<Integer, Integer> columnMappings = null;
if (null != columnList && !columnList.isEmpty()) {
if (columnList.size() != valueList.size()) {
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
if (bcOperationColumnList.size() != bcOperationValueList.size()) {

MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList.size(), valueList.size()};
Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
columnMappings = new HashMap<>(columnList.size());
columnMappings = new HashMap<>(bcOperationColumnList.size());
} else {
if (rs.getColumnCount() != valueList.size()) {
if (rs.getColumnCount() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {rs.getColumnCount(), valueList.size()};
Object[] msgArgs = {rs.getColumnCount(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
}

SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(
batchParamValues, columnList, valueList, null);
batchParamValues, bcOperationColumnList, bcOperationValueList, null);

for (int i = 1; i <= rs.getColumnCount(); i++) {
Column c = rs.getColumn(i);
Expand All @@ -2335,8 +2355,8 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
} else {
jdbctype = ti.getSSType().getJDBCType().getIntValue();
}
if (null != columnList && !columnList.isEmpty()) {
int columnIndex = columnList.indexOf(c.getColumnName());
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
int columnIndex = bcOperationColumnList.indexOf(c.getColumnName());
if (columnIndex > -1) {
columnMappings.put(columnIndex + 1, i);
batchRecord.addColumnMetadata(columnIndex + 1, c.getColumnName(), jdbctype,
Expand All @@ -2348,20 +2368,23 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
}
}

SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(tableName);
if (columnMappings != null) {
for (Entry<Integer, Integer> pair : columnMappings.entrySet()) {
bcOperation.addColumnMapping(pair.getKey(), pair.getValue());
if (null == bcOperation) {
bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(bcOperationTableName);
if (columnMappings != null) {
for (Entry<Integer, Integer> pair : columnMappings.entrySet()) {
bcOperation.addColumnMapping(pair.getKey(), pair.getValue());
}
}
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
}
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);

bcOperation.writeToServer(batchRecord);
bcOperation.close();

updateCounts = new int[batchParamValues.size()];
for (int i = 0; i < batchParamValues.size(); ++i) {
updateCounts[i] = 1;
Expand Down Expand Up @@ -2471,9 +2494,17 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
}
}

String tableName = parseUserSQLForTableNameDW(false, false, false, false);
ArrayList<String> columnList = parseUserSQLForColumnListDW();
ArrayList<String> valueList = parseUserSQLForValueListDW(false);
if (null == bcOperationTableName) {
bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false);
}

if (null == bcOperationColumnList) {
bcOperationColumnList = parseUserSQLForColumnListDW();
}

if (null == bcOperationValueList) {
bcOperationValueList = parseUserSQLForValueListDW(false);
}

checkAdditionalQuery();

Expand All @@ -2482,25 +2513,25 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
stmtColumnEncriptionSetting);
SQLServerResultSet rs = stmt
.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM "
+ Util.escapeSingleQuotes(tableName) + " '")) {
if (null != columnList && !columnList.isEmpty()) {
if (columnList.size() != valueList.size()) {
+ Util.escapeSingleQuotes(bcOperationTableName) + " '")) {
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
if (bcOperationColumnList.size() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList.size(), valueList.size()};
Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
} else {
if (rs.getColumnCount() != valueList.size()) {
if (rs.getColumnCount() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList != null ? columnList.size() : 0, valueList.size()};
Object[] msgArgs = {bcOperationColumnList!= null ? bcOperationColumnList.size() : 0, bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
}

SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(
batchParamValues, columnList, valueList, null);
batchParamValues, bcOperationColumnList, bcOperationValueList, null);

for (int i = 1; i <= rs.getColumnCount(); i++) {
Column c = rs.getColumn(i);
Expand All @@ -2517,15 +2548,18 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
ti.getScale());
}

SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(tableName);
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
if (null == bcOperation) {
bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(bcOperationTableName);
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
}

bcOperation.writeToServer(batchRecord);
bcOperation.close();

updateCounts = new long[batchParamValues.size()];
for (int i = 0; i < batchParamValues.size(); ++i) {
updateCounts[i] = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package com.microsoft.sqlserver.jdbc.unit.statement;

import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -109,6 +110,36 @@ public void testBatchUpdateCountTrueOnFirstPstmtSpPrepare() throws Exception {
testBatchUpdateCountWith(5, 4, true, "prepare", expectedUpdateCount);
}

@Test
public void testSqlServerBulkCopyCachingPstmtLevel() throws Exception {
Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
long ms = 1578743412000L;

try (Connection con = DriverManager.getConnection(
connectionString + ";useBulkCopyForBatchInsert=true;sendTemporalDataTypesAsStringForBulkCopy=false;");
Statement stmt = con.createStatement();
PreparedStatement pstmt = con.prepareStatement("INSERT INTO " + timestampTable1 + " VALUES(?)")) {

TestUtils.dropTableIfExists(timestampTable1, stmt);
String createSql = "CREATE TABLE " + timestampTable1 + " (c1 DATETIME2(3))";
stmt.execute(createSql);

Field cachedBulkCopyOperationField = pstmt.getClass().getDeclaredField("bcOperation");
cachedBulkCopyOperationField.setAccessible(true);
Object cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt);
assertEquals(null, cachedBulkCopyOperation, "SqlServerBulkCopy object should not be cached yet.");

Timestamp timestamp = new Timestamp(ms);

pstmt.setTimestamp(1, timestamp, gmtCal);
pstmt.addBatch();
pstmt.executeBatch();

cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt);
assertNotNull("SqlServerBulkCopy object should be cached.", cachedBulkCopyOperation);
}
}

@Test
public void testValidTimezoneForTimestampBatchInsertWithBulkCopy() throws Exception {
Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
Expand Down