diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index f3a441581..a25aa2cbd 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -146,6 +146,14 @@ private void setPreparedStatementHandle(int handle) { */ private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b"); + /** + * For caching data related to batch insert with bulkcopy + */ + private SQLServerBulkCopy bcOperation = null; + private String bcOperationTableName = null; + private ArrayList bcOperationColumnList = null; + private ArrayList bcOperationValueList = null; + /** Returns the prepared statement SQL */ @Override public String toString() { @@ -392,6 +400,10 @@ final void closeInternal() { // Clean up client-side state batchParamValues = null; + + if (null != bcOperation) { + bcOperation.close(); + } } /** @@ -2287,9 +2299,17 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL } } - String tableName = parseUserSQLForTableNameDW(false, false, false, false); - ArrayList columnList = parseUserSQLForColumnListDW(); - ArrayList valueList = parseUserSQLForValueListDW(false); + if (null == bcOperationTableName) { + bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false); + } + + if (null == bcOperationColumnList) { + bcOperationColumnList = parseUserSQLForColumnListDW(); + } + + if (null == bcOperationValueList) { + bcOperationValueList = parseUserSQLForValueListDW(false); + } checkAdditionalQuery(); @@ -2298,28 +2318,28 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL stmtColumnEncriptionSetting); SQLServerResultSet rs = stmt .executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM " - + Util.escapeSingleQuotes(tableName) + " '")) { + + Util.escapeSingleQuotes(bcOperationTableName) + " '")) { Map columnMappings = null; - if (null != columnList && !columnList.isEmpty()) { - if (columnList.size() != valueList.size()) { + if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) { + if (bcOperationColumnList.size() != bcOperationValueList.size()) { MessageFormat form = new MessageFormat( SQLServerException.getErrString("R_colNotMatchTable")); - Object[] msgArgs = {columnList.size(), valueList.size()}; + Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()}; throw new IllegalArgumentException(form.format(msgArgs)); } - columnMappings = new HashMap<>(columnList.size()); + columnMappings = new HashMap<>(bcOperationColumnList.size()); } else { - if (rs.getColumnCount() != valueList.size()) { + if (rs.getColumnCount() != bcOperationValueList.size()) { MessageFormat form = new MessageFormat( SQLServerException.getErrString("R_colNotMatchTable")); - Object[] msgArgs = {rs.getColumnCount(), valueList.size()}; + Object[] msgArgs = {rs.getColumnCount(), bcOperationValueList.size()}; throw new IllegalArgumentException(form.format(msgArgs)); } } SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord( - batchParamValues, columnList, valueList, null); + batchParamValues, bcOperationColumnList, bcOperationValueList, null); for (int i = 1; i <= rs.getColumnCount(); i++) { Column c = rs.getColumn(i); @@ -2335,8 +2355,8 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL } else { jdbctype = ti.getSSType().getJDBCType().getIntValue(); } - if (null != columnList && !columnList.isEmpty()) { - int columnIndex = columnList.indexOf(c.getColumnName()); + if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) { + int columnIndex = bcOperationColumnList.indexOf(c.getColumnName()); if (columnIndex > -1) { columnMappings.put(columnIndex + 1, i); batchRecord.addColumnMetadata(columnIndex + 1, c.getColumnName(), jdbctype, @@ -2348,20 +2368,23 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL } } - SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection); - SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); - option.setBulkCopyTimeout(queryTimeout); - bcOperation.setBulkCopyOptions(option); - bcOperation.setDestinationTableName(tableName); - if (columnMappings != null) { - for (Entry pair : columnMappings.entrySet()) { - bcOperation.addColumnMapping(pair.getKey(), pair.getValue()); + if (null == bcOperation) { + bcOperation = new SQLServerBulkCopy(connection); + SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); + option.setBulkCopyTimeout(queryTimeout); + bcOperation.setBulkCopyOptions(option); + bcOperation.setDestinationTableName(bcOperationTableName); + if (columnMappings != null) { + for (Entry pair : columnMappings.entrySet()) { + bcOperation.addColumnMapping(pair.getKey(), pair.getValue()); + } } + bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); + bcOperation.setDestinationTableMetadata(rs); } - bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); - bcOperation.setDestinationTableMetadata(rs); + bcOperation.writeToServer(batchRecord); - bcOperation.close(); + updateCounts = new int[batchParamValues.size()]; for (int i = 0; i < batchParamValues.size(); ++i) { updateCounts[i] = 1; @@ -2471,9 +2494,17 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio } } - String tableName = parseUserSQLForTableNameDW(false, false, false, false); - ArrayList columnList = parseUserSQLForColumnListDW(); - ArrayList valueList = parseUserSQLForValueListDW(false); + if (null == bcOperationTableName) { + bcOperationTableName = parseUserSQLForTableNameDW(false, false, false, false); + } + + if (null == bcOperationColumnList) { + bcOperationColumnList = parseUserSQLForColumnListDW(); + } + + if (null == bcOperationValueList) { + bcOperationValueList = parseUserSQLForValueListDW(false); + } checkAdditionalQuery(); @@ -2482,25 +2513,25 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio stmtColumnEncriptionSetting); SQLServerResultSet rs = stmt .executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM " - + Util.escapeSingleQuotes(tableName) + " '")) { - if (null != columnList && !columnList.isEmpty()) { - if (columnList.size() != valueList.size()) { + + Util.escapeSingleQuotes(bcOperationTableName) + " '")) { + if (null != bcOperationColumnList && !bcOperationColumnList.isEmpty()) { + if (bcOperationColumnList.size() != bcOperationValueList.size()) { MessageFormat form = new MessageFormat( SQLServerException.getErrString("R_colNotMatchTable")); - Object[] msgArgs = {columnList.size(), valueList.size()}; + Object[] msgArgs = {bcOperationColumnList.size(), bcOperationValueList.size()}; throw new IllegalArgumentException(form.format(msgArgs)); } } else { - if (rs.getColumnCount() != valueList.size()) { + if (rs.getColumnCount() != bcOperationValueList.size()) { MessageFormat form = new MessageFormat( SQLServerException.getErrString("R_colNotMatchTable")); - Object[] msgArgs = {columnList != null ? columnList.size() : 0, valueList.size()}; + Object[] msgArgs = {bcOperationColumnList!= null ? bcOperationColumnList.size() : 0, bcOperationValueList.size()}; throw new IllegalArgumentException(form.format(msgArgs)); } } SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord( - batchParamValues, columnList, valueList, null); + batchParamValues, bcOperationColumnList, bcOperationValueList, null); for (int i = 1; i <= rs.getColumnCount(); i++) { Column c = rs.getColumn(i); @@ -2517,15 +2548,18 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio ti.getScale()); } - SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection); - SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); - option.setBulkCopyTimeout(queryTimeout); - bcOperation.setBulkCopyOptions(option); - bcOperation.setDestinationTableName(tableName); - bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); - bcOperation.setDestinationTableMetadata(rs); + if (null == bcOperation) { + bcOperation = new SQLServerBulkCopy(connection); + SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); + option.setBulkCopyTimeout(queryTimeout); + bcOperation.setBulkCopyOptions(option); + bcOperation.setDestinationTableName(bcOperationTableName); + bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); + bcOperation.setDestinationTableMetadata(rs); + } + bcOperation.writeToServer(batchRecord); - bcOperation.close(); + updateCounts = new long[batchParamValues.size()]; for (int i = 0; i < batchParamValues.size(); ++i) { updateCounts[i] = 1; diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java index 47b771dc3..7c33fc296 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java @@ -4,6 +4,7 @@ */ package com.microsoft.sqlserver.jdbc.unit.statement; +import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -109,6 +110,36 @@ public void testBatchUpdateCountTrueOnFirstPstmtSpPrepare() throws Exception { testBatchUpdateCountWith(5, 4, true, "prepare", expectedUpdateCount); } + @Test + public void testSqlServerBulkCopyCachingPstmtLevel() throws Exception { + Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); + long ms = 1578743412000L; + + try (Connection con = DriverManager.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;sendTemporalDataTypesAsStringForBulkCopy=false;"); + Statement stmt = con.createStatement(); + PreparedStatement pstmt = con.prepareStatement("INSERT INTO " + timestampTable1 + " VALUES(?)")) { + + TestUtils.dropTableIfExists(timestampTable1, stmt); + String createSql = "CREATE TABLE " + timestampTable1 + " (c1 DATETIME2(3))"; + stmt.execute(createSql); + + Field cachedBulkCopyOperationField = pstmt.getClass().getDeclaredField("bcOperation"); + cachedBulkCopyOperationField.setAccessible(true); + Object cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt); + assertEquals(null, cachedBulkCopyOperation, "SqlServerBulkCopy object should not be cached yet."); + + Timestamp timestamp = new Timestamp(ms); + + pstmt.setTimestamp(1, timestamp, gmtCal); + pstmt.addBatch(); + pstmt.executeBatch(); + + cachedBulkCopyOperation = cachedBulkCopyOperationField.get(pstmt); + assertNotNull("SqlServerBulkCopy object should be cached.", cachedBulkCopyOperation); + } + } + @Test public void testValidTimezoneForTimestampBatchInsertWithBulkCopy() throws Exception { Calendar gmtCal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));