Skip to content

Commit

Permalink
Caching SQLServerBulkCopy object for batch insert (#2435)
Browse files Browse the repository at this point in the history
* Caching SQLServerBulkCopy object for batch insert

* Changed to 'null == ...'

* Changed comment

* Added test
  • Loading branch information
tkyc authored Jun 17, 2024
1 parent 2a33f9a commit aae94a6
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ private void setPreparedStatementHandle(int handle) {
*/
private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b");

/**
* For caching data related to batch insert with bulkcopy
*/
private SQLServerBulkCopy bcOperation = null;
private String bcOperationTableName = null;
private ArrayList<String> bcOperationColumnList = null;
private ArrayList<String> bcOperationValueList = null;

/** Returns the prepared statement SQL */
@Override
public String toString() {
Expand Down Expand Up @@ -392,6 +400,10 @@ final void closeInternal() {

// Clean up client-side state
batchParamValues = null;

if (null != bcOperation) {
bcOperation.close();
}
}

/**
Expand Down Expand Up @@ -2287,9 +2299,17 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
}
}

String tableName = parseUserSQLForTableNameDW(false, false, false, false);
ArrayList<String> columnList = parseUserSQLForColumnListDW();
ArrayList<String> valueList = parseUserSQLForValueListDW(false);
if (null == bcOperationTableName) {
bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false);
}

if (null == bcOperationColumnList) {
bcOperationColumnList = parseUserSQLForColumnListDW();
}

if (null == bcOperationValueList) {
bcOperationValueList = parseUserSQLForValueListDW(false);
}

checkAdditionalQuery();

Expand All @@ -2298,28 +2318,28 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
stmtColumnEncriptionSetting);
SQLServerResultSet rs = stmt
.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM "
+ Util.escapeSingleQuotes(tableName) + " '")) {
+ Util.escapeSingleQuotes(bcOperationTableName) + " '")) {
Map<Integer, Integer> columnMappings = null;
if (null != columnList && !columnList.isEmpty()) {
if (columnList.size() != valueList.size()) {
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
if (bcOperationColumnList.size() != bcOperationValueList.size()) {

MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList.size(), valueList.size()};
Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
columnMappings = new HashMap<>(columnList.size());
columnMappings = new HashMap<>(bcOperationColumnList.size());
} else {
if (rs.getColumnCount() != valueList.size()) {
if (rs.getColumnCount() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {rs.getColumnCount(), valueList.size()};
Object[] msgArgs = {rs.getColumnCount(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
}

SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(
batchParamValues, columnList, valueList, null);
batchParamValues, bcOperationColumnList, bcOperationValueList, null);

for (int i = 1; i <= rs.getColumnCount(); i++) {
Column c = rs.getColumn(i);
Expand All @@ -2335,8 +2355,8 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
} else {
jdbctype = ti.getSSType().getJDBCType().getIntValue();
}
if (null != columnList && !columnList.isEmpty()) {
int columnIndex = columnList.indexOf(c.getColumnName());
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
int columnIndex = bcOperationColumnList.indexOf(c.getColumnName());
if (columnIndex > -1) {
columnMappings.put(columnIndex + 1, i);
batchRecord.addColumnMetadata(columnIndex + 1, c.getColumnName(), jdbctype,
Expand All @@ -2348,20 +2368,23 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL
}
}

SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(tableName);
if (columnMappings != null) {
for (Entry<Integer, Integer> pair : columnMappings.entrySet()) {
bcOperation.addColumnMapping(pair.getKey(), pair.getValue());
if (null == bcOperation) {
bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(bcOperationTableName);
if (columnMappings != null) {
for (Entry<Integer, Integer> pair : columnMappings.entrySet()) {
bcOperation.addColumnMapping(pair.getKey(), pair.getValue());
}
}
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
}
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);

bcOperation.writeToServer(batchRecord);
bcOperation.close();

updateCounts = new int[batchParamValues.size()];
for (int i = 0; i < batchParamValues.size(); ++i) {
updateCounts[i] = 1;
Expand Down Expand Up @@ -2471,9 +2494,17 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
}
}

String tableName = parseUserSQLForTableNameDW(false, false, false, false);
ArrayList<String> columnList = parseUserSQLForColumnListDW();
ArrayList<String> valueList = parseUserSQLForValueListDW(false);
if (null == bcOperationTableName) {
bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false);
}

if (null == bcOperationColumnList) {
bcOperationColumnList = parseUserSQLForColumnListDW();
}

if (null == bcOperationValueList) {
bcOperationValueList = parseUserSQLForValueListDW(false);
}

checkAdditionalQuery();

Expand All @@ -2482,25 +2513,25 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
stmtColumnEncriptionSetting);
SQLServerResultSet rs = stmt
.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM "
+ Util.escapeSingleQuotes(tableName) + " '")) {
if (null != columnList && !columnList.isEmpty()) {
if (columnList.size() != valueList.size()) {
+ Util.escapeSingleQuotes(bcOperationTableName) + " '")) {
if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) {
if (bcOperationColumnList.size() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList.size(), valueList.size()};
Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
} else {
if (rs.getColumnCount() != valueList.size()) {
if (rs.getColumnCount() != bcOperationValueList.size()) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_colNotMatchTable"));
Object[] msgArgs = {columnList != null ? columnList.size() : 0, valueList.size()};
Object[] msgArgs = {bcOperationColumnList!= null ? bcOperationColumnList.size() : 0, bcOperationValueList.size()};
throw new IllegalArgumentException(form.format(msgArgs));
}
}

SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(
batchParamValues, columnList, valueList, null);
batchParamValues, bcOperationColumnList, bcOperationValueList, null);

for (int i = 1; i <= rs.getColumnCount(); i++) {
Column c = rs.getColumn(i);
Expand All @@ -2517,15 +2548,18 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio
ti.getScale());
}

SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(tableName);
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
if (null == bcOperation) {
bcOperation = new SQLServerBulkCopy(connection);
SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions();
option.setBulkCopyTimeout(queryTimeout);
bcOperation.setBulkCopyOptions(option);
bcOperation.setDestinationTableName(bcOperationTableName);
bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting());
bcOperation.setDestinationTableMetadata(rs);
}

bcOperation.writeToServer(batchRecord);
bcOperation.close();

updateCounts = new long[batchParamValues.size()];
for (int i = 0; i < batchParamValues.size(); ++i) {
updateCounts[i] = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package com.microsoft.sqlserver.jdbc.unit.statement;

import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -109,6 +110,36 @@ public void testBatchUpdateCountTrueOnFirstPstmtSpPrepare() throws Exception {
testBatchUpdateCountWith(5, 4, true, "prepare", expectedUpdateCount);
}

@Test
public void testSqlServerBulkCopyCachingPstmtLevel() throws Exception {
Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
long ms = 1578743412000L;

try (Connection con = DriverManager.getConnection(
connectionString + ";useBulkCopyForBatchInsert=true;sendTemporalDataTypesAsStringForBulkCopy=false;");
Statement stmt = con.createStatement();
PreparedStatement pstmt = con.prepareStatement("INSERT INTO " + timestampTable1 + " VALUES(?)")) {

TestUtils.dropTableIfExists(timestampTable1, stmt);
String createSql = "CREATE TABLE " + timestampTable1 + " (c1 DATETIME2(3))";
stmt.execute(createSql);

Field cachedBulkCopyOperationField = pstmt.getClass().getDeclaredField("bcOperation");
cachedBulkCopyOperationField.setAccessible(true);
Object cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt);
assertEquals(null, cachedBulkCopyOperation, "SqlServerBulkCopy object should not be cached yet.");

Timestamp timestamp = new Timestamp(ms);

pstmt.setTimestamp(1, timestamp, gmtCal);
pstmt.addBatch();
pstmt.executeBatch();

cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt);
assertNotNull("SqlServerBulkCopy object should be cached.", cachedBulkCopyOperation);
}
}

@Test
public void testValidTimezoneForTimestampBatchInsertWithBulkCopy() throws Exception {
Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
Expand Down

0 comments on commit aae94a6

Please sign in to comment.