From cdeb46845c59d3195ddd2755d8b1accea25274be Mon Sep 17 00:00:00 2001 From: lilgreenbird Date: Fri, 10 May 2024 17:25:11 -0700 Subject: [PATCH 01/12] test updates --- .../jdbc/AlwaysEncrypted/AESetup.java | 6 ++-- .../CallableStatementTest.java | 2 +- .../jdbc/SQLServerConnectionTest.java | 36 ++++++++++++++++--- .../microsoft/sqlserver/jdbc/TestUtils.java | 33 +++++------------ .../jdbc/connection/PoolingTest.java | 4 +++ .../jdbc/connection/TimeoutTest.java | 27 ++++++++++++++ .../DatabaseMetaDataTest.java | 31 ++++++++-------- .../sqlserver/jdbc/fedauth/FedauthWithAE.java | 2 +- .../sqlserver/jdbc/fips/FipsTest.java | 19 +++++++--- .../jdbc/unit/SQLServerErrorTest.java | 7 +++- .../unit/statement/BatchExecutionTest.java | 8 +++++ .../unit/statement/PreparedStatementTest.java | 4 ++- .../sqlserver/testframework/AbstractTest.java | 28 +++++++++------ 13 files changed, 142 insertions(+), 65 deletions(-) diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AESetup.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AESetup.java index 02336f79c..3fbf0d4d8 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AESetup.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AESetup.java @@ -339,7 +339,7 @@ protected static void createTable(String tableName, String cekName, String table TestUtils.dropTableIfExists(tableName, stmt); sql = String.format(createSql, tableName, sql); stmt.execute(sql); - stmt.execute("DBCC FREEPROCCACHE"); + TestUtils.freeProcCache(stmt); } catch (SQLException e) { fail(e.getMessage()); } @@ -373,7 +373,7 @@ protected static void createPrecisionTable(String tableName, String table[][], S } sql = String.format(createSql, tableName, sql); stmt.execute(sql); - stmt.execute("DBCC FREEPROCCACHE"); + TestUtils.freeProcCache(stmt); } catch (SQLException e) { fail(e.getMessage()); } @@ -401,7 +401,7 @@ protected static void createScaleTable(String tableName, String table[][], Strin sql = String.format(createSql, tableName, sql); stmt.execute(sql); - stmt.execute("DBCC FREEPROCCACHE"); + TestUtils.freeProcCache(stmt); } catch (SQLException e) { fail(e.getMessage()); } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/CallableStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/CallableStatementTest.java index 95531d698..d259b35a5 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/CallableStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/CallableStatementTest.java @@ -2201,7 +2201,7 @@ protected static void createDateTableCallableStatement(String cekName) throws SQ SQLServerStatement stmt = (SQLServerStatement) con.createStatement()) { TestUtils.dropTableIfExists(DATE_TABLE_AE, stmt); stmt.execute(sql); - stmt.execute("DBCC FREEPROCCACHE"); + TestUtils.freeProcCache(stmt); } catch (SQLException e) { fail(e.getMessage()); } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 12eea46cf..bdac7c106 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -482,7 +482,12 @@ public void testConnectCountInLoginAndCorrectRetryCount() { assertTrue(con == null, TestResource.getResource("R_shouldNotConnect")); } } catch (Exception e) { - assertTrue(e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")), e.getMessage()); + assertTrue( + e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")) + || (TestUtils.getProperty(connectionString, "msiClientId") != null + && e.getMessage().toLowerCase() + .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())), + e.getMessage()); long totalTime = System.currentTimeMillis() - timerStart; // Maximum is unknown, but is needs to be less than longLoginTimeout or else this is an issue. @@ -795,13 +800,22 @@ public void testIncorrectDatabase() throws SQLException { assertTrue(timeDiff <= milsecs, form.format(msgArgs)); } } catch (Exception e) { - assertTrue(e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")), e.getMessage()); + assertTrue( + e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")) + || (TestUtils.getProperty(connectionString, "msiClientId") != null + && e.getMessage().toLowerCase() + .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())), + e.getMessage()); timerEnd = System.currentTimeMillis(); } } @Test public void testIncorrectUserName() throws SQLException { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + long timerStart = 0; long timerEnd = 0; final long milsecs = threshHoldForNoRetryInMilliseconds; @@ -819,13 +833,22 @@ public void testIncorrectUserName() throws SQLException { assertTrue(timeDiff <= milsecs, form.format(msgArgs)); } } catch (Exception e) { - assertTrue(e.getMessage().contains(TestResource.getResource("R_loginFailed"))); + assertTrue( + e.getMessage().contains(TestResource.getResource("R_loginFailed")) + || (TestUtils.getProperty(connectionString, "msiClientId") != null + && e.getMessage().toLowerCase() + .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())), + e.getMessage()); timerEnd = System.currentTimeMillis(); } } @Test public void testIncorrectPassword() throws SQLException { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + long timerStart = 0; long timerEnd = 0; final long milsecs = threshHoldForNoRetryInMilliseconds; @@ -843,7 +866,12 @@ public void testIncorrectPassword() throws SQLException { assertTrue(timeDiff <= milsecs, form.format(msgArgs)); } } catch (Exception e) { - assertTrue(e.getMessage().contains(TestResource.getResource("R_loginFailed"))); + assertTrue( + e.getMessage().contains(TestResource.getResource("R_loginFailed")) + || (TestUtils.getProperty(connectionString, "msiClientId") != null + && e.getMessage().toLowerCase() + .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())), + e.getMessage()); timerEnd = System.currentTimeMillis(); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java index 8db118b4d..4cbab64ec 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java @@ -535,31 +535,7 @@ public static void dropDatabaseIfExists(String databaseName, String connectionSt */ public static void dropSchemaIfExists(String schemaName, Statement stmt) throws SQLException { stmt.execute("if EXISTS (SELECT * FROM sys.schemas where name = '" + escapeSingleQuotes(schemaName) - + "') DROP SCHEMA " + AbstractSQLGenerator.escapeIdentifier(schemaName)); - } - - /** - * mimic "DROP USER..." - * - * @param userName - * @param stmt - * @throws SQLException - */ - public static void dropUserIfExists(String userName, Statement stmt) throws SQLException { - stmt.execute("IF EXISTS (SELECT * FROM sys.sysusers where name = '" + escapeSingleQuotes(userName) - + "') DROP USER " + AbstractSQLGenerator.escapeIdentifier(userName)); - } - - /** - * mimic "DROP LOGIN..." - * - * @param userName - * @param stmt - * @throws SQLException - */ - public static void dropLoginIfExists(String userName, Statement stmt) throws SQLException { - stmt.execute("IF EXISTS (SELECT * FROM sys.sysusers where name = '" + escapeSingleQuotes(userName) - + "') DROP LOGIN " + AbstractSQLGenerator.escapeIdentifier(userName)); + + "') drop schema " + AbstractSQLGenerator.escapeIdentifier(schemaName)); } /** @@ -1135,4 +1111,11 @@ public static String getConnectionID( SQLServerConnection conn = (SQLServerConnection) physicalConnection.get(pc); return (String) traceID.get(conn); } + + public static void freeProcCache(Statement stmt) { + try { + stmt.execute("DBCC FREEPROCCACHE"); + // ignore error + } catch (Exception e) {} + } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/connection/PoolingTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/connection/PoolingTest.java index d99d846ef..57a11a725 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/connection/PoolingTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/connection/PoolingTest.java @@ -152,6 +152,10 @@ public void testConnectionPoolClose() throws SQLException { @Test public void testConnectionPoolClientConnectionId() throws SQLException { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + SQLServerXADataSource ds = new SQLServerXADataSource(); ds.setURL(connectionString); PooledConnection pc = null; diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java index 268fda736..ffd9e925c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java @@ -68,6 +68,8 @@ public void testDefaultLoginTimeout() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -95,6 +97,8 @@ public void testURLLoginTimeout() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -122,6 +126,8 @@ public void testDMLoginTimeoutApplied() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -150,6 +156,9 @@ public void testDMLoginTimeoutNotApplied() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null + && e.getMessage().toLowerCase() + .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -181,6 +190,8 @@ public void testConnectRetryDisable() { assertTrue( e.getMessage().matches(TestUtils.formatErrorMsg("R_tcpipConnectionFailed")) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -207,6 +218,8 @@ public void testConnectRetryBadServer() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -220,6 +233,10 @@ public void testConnectRetryBadServer() { // Test connect retry for database error @Test public void testConnectRetryServerError() { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + long totalTime = 0; long timerStart = System.currentTimeMillis(); int interval = defaultTimeout; // long interval so we can tell if there was a retry @@ -237,6 +254,8 @@ public void testConnectRetryServerError() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -252,6 +271,10 @@ public void testConnectRetryServerError() { // Test connect retry for database error using Datasource @Test public void testConnectRetryServerErrorDS() { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + long totalTime = 0; long timerStart = System.currentTimeMillis(); int interval = defaultTimeout; // long interval so we can tell if there was a retry @@ -270,6 +293,8 @@ public void testConnectRetryServerErrorDS() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); @@ -303,6 +328,8 @@ public void testConnectRetryTimeout() { assertTrue( (e.getMessage().toLowerCase() .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase())) + || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage() + .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())) || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase() .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false), e.getMessage()); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java index c140c19f8..11cefbc2a 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java @@ -62,7 +62,6 @@ public class DatabaseMetaDataTest extends AbstractTest { private static final String uuid = UUID.randomUUID().toString().replaceAll("-", ""); private static final String tableName = RandomUtil.getIdentifier("DBMetadataTable"); private static final String functionName = RandomUtil.getIdentifier("DBMetadataFunction"); - private static final String newUserName = "newUser" + uuid; private static final String schema = "schema_demo" + uuid; private static final String escapedSchema = "schema\\_demo" + uuid; private static final String tableNameWithSchema = schema + ".resource"; @@ -196,6 +195,10 @@ public void testGetURL() throws SQLException { */ @Test public void testDBUserLogin() throws SQLException { + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + try (Connection conn = getConnection()) { DatabaseMetaData databaseMetaData = conn.getMetaData(); String connectionString = getConnectionString(); @@ -219,7 +222,8 @@ public void testDBUserLogin() throws SQLException { assertNotNull(userName, TestResource.getResource("R_userNameNull")); assertTrue(userName.equalsIgnoreCase(userFromConnectionString), - TestResource.getResource("R_userNameNotMatch")); + TestResource.getResource("R_userNameNotMatch") + "userName: " + userName + "from connectio string: " + + userFromConnectionString); } catch (Exception e) { fail(TestResource.getResource("R_unexpectedErrorMessage") + e.getMessage()); } @@ -228,13 +232,14 @@ public void testDBUserLogin() throws SQLException { @Test @Tag(Constants.xAzureSQLDW) public void testImpersonateGetUserName() throws SQLException { - String escapedNewUser = AbstractSQLGenerator.escapeIdentifier(newUserName); + String newUser = "newUser" + UUID.randomUUID(); try (Connection conn = getConnection(); Statement stmt = conn.createStatement()) { + String escapedNewUser = AbstractSQLGenerator.escapeIdentifier(newUser); String password = "password" + UUID.randomUUID(); - TestUtils.dropUserIfExists(newUserName, stmt); - TestUtils.dropLoginIfExists(newUserName, stmt); + stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP USER" + + escapedNewUser); // create new user and login try { @@ -248,17 +253,17 @@ public void testImpersonateGetUserName() throws SQLException { } DatabaseMetaData databaseMetaData = conn.getMetaData(); - try (CallableStatement asOtherUser = conn.prepareCall("EXECUTE AS USER = '" + newUserName + "'")) { + try (CallableStatement asOtherUser = conn.prepareCall("EXECUTE AS USER = '" + newUser + "'")) { asOtherUser.execute(); - assertTrue(newUserName.equalsIgnoreCase(databaseMetaData.getUserName()), + assertTrue(newUser.equalsIgnoreCase(databaseMetaData.getUserName()), TestResource.getResource("R_userNameNotMatch")); } catch (Exception e) { fail(TestResource.getResource("R_unexpectedErrorMessage") + e.getMessage()); - } - } finally { - try (Connection conn = getConnection(); Statement stmt = conn.createStatement()) { - TestUtils.dropUserIfExists(newUserName, stmt); - TestUtils.dropLoginIfExists(newUserName, stmt); + } finally { + stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP USER" + + escapedNewUser); + stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP LOGIN" + + escapedNewUser); } } } @@ -1049,8 +1054,6 @@ public static void terminate() throws SQLException { TestUtils.dropTableWithSchemaIfExists(tableNameWithSchema, stmt); TestUtils.dropProcedureWithSchemaIfExists(sprocWithSchema, stmt); TestUtils.dropSchemaIfExists(schema, stmt); - TestUtils.dropUserIfExists(newUserName, stmt); - TestUtils.dropLoginIfExists(newUserName, stmt); } } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java index bfd09d3b9..1e0112f02 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java @@ -307,7 +307,7 @@ private void createCMK(String cmkName, String keyStoreName, String keyPath, Stat private void callDbccFreeProcCache() throws SQLException { try (Connection connection = DriverManager.getConnection(adPasswordConnectionStr); Statement stmt = connection.createStatement()) { - stmt.execute("DBCC FREEPROCCACHE"); + TestUtils.freeProcCache(stmt); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java index 10853f2f6..64d1ebdf1 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java @@ -50,7 +50,7 @@ public void fipsTrustServerCertificateTest() throws Exception { Assertions.fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLException e) { Assertions.assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidFipsConfig")), - TestResource.getResource("R_invalidTrustCert")); + TestResource.getResource("R_invalidTrustCert") + ": " + e.getMessage()); } } @@ -62,13 +62,19 @@ public void fipsTrustServerCertificateTest() throws Exception { */ @Test public void fipsEncryptTest() throws Exception { + // test doesn't apply to managed identity as encrypt is set to on by default + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null && !(auth.equalsIgnoreCase("ActiveDirectoryManagedIdentity") + || auth.equalsIgnoreCase("ActiveDirectoryMSI"))); + Properties props = buildConnectionProperties(); props.setProperty(Constants.ENCRYPT, Boolean.FALSE.toString()); + System.out.println("fipsEncryptTest connectionString=" + connectionString); try (Connection con = PrepUtil.getConnection(connectionString, props)) { Assertions.fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLException e) { Assertions.assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidFipsConfig")), - TestResource.getResource("R_invalidEncrypt")); + TestResource.getResource("R_invalidTrustCert") + ": " + e.getMessage()); } } @@ -118,6 +124,11 @@ public void fipsDataSourcePropertyTest() throws Exception { */ @Test public void fipsDatSourceEncrypt() { + // test doesn't apply to managed identity as encrypt is set to on by default + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null && !(auth.equalsIgnoreCase("ActiveDirectoryManagedIdentity") + || auth.equalsIgnoreCase("ActiveDirectoryMSI"))); + SQLServerDataSource ds = new SQLServerDataSource(); setDataSourceProperties(ds); ds.setEncrypt(Constants.FALSE); @@ -126,7 +137,7 @@ public void fipsDatSourceEncrypt() { Assertions.fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLException e) { Assertions.assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidFipsConfig")), - TestResource.getResource("R_invalidEncrypt")); + TestResource.getResource("R_invalidEncrypt") + ": " + e.getMessage()); } } @@ -146,7 +157,7 @@ public void fipsDataSourceTrustServerCertificateTest() throws Exception { Assertions.fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLException e) { Assertions.assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidFipsConfig")), - TestResource.getResource("R_invalidTrustCert")); + TestResource.getResource("R_invalidTrustCert") + ": " + e.getMessage()); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/SQLServerErrorTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/SQLServerErrorTest.java index 2aef2cf43..7c7dc377e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/SQLServerErrorTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/SQLServerErrorTest.java @@ -41,7 +41,12 @@ public static void setupTests() throws Exception { @Test @Tag(Constants.xAzureSQLDW) - public void testLoginFailedError() { + public void testLoginFailedError() { + // test to remove password only valid for password auth + String auth = TestUtils.getProperty(connectionString, "authentication"); + org.junit.Assume.assumeTrue(auth != null + && (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword"))); + SQLServerDataSource ds = new SQLServerDataSource(); ds.setURL(connectionString); ds.setLoginTimeout(loginTimeOutInSeconds); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java index 0dace62b2..5b003ef15 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java @@ -186,8 +186,12 @@ public void testValidTimezonesDstTimestampBatchInsertWithBulkCopy() throws Excep Timestamp timestamp = new Timestamp(ms); pstmt.setTimestamp(1, timestamp, gmtCal); + pstmt.addBatch(); + pstmt.executeBatch(); + } catch (Exception e) { + fail(e.getMessage()); } // Insert Timestamp using bulkcopy for batch insert @@ -200,6 +204,8 @@ public void testValidTimezonesDstTimestampBatchInsertWithBulkCopy() throws Excep pstmt.setTimestamp(1, timestamp, gmtCal); pstmt.addBatch(); pstmt.executeBatch(); + } catch (Exception e) { + fail(e.getMessage()); } // Compare Timestamp values inserted, should be the same @@ -225,6 +231,8 @@ public void testValidTimezonesDstTimestampBatchInsertWithBulkCopy() throws Excep assertEquals(ts0, ts1, failureMsg); assertEquals(t0, t1, failureMsg); assertEquals(d0, d1, failureMsg); + } catch (Exception e) { + fail(e.getMessage()); } } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java index 79862dfcd..b21372c96 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java @@ -258,7 +258,9 @@ public void testBatchedUnprepare() throws SQLException { con.setStatementPoolingCacheSize(0); // Clean-up proc cache - this.executeSQL(con, "DBCC FREEPROCCACHE;"); + try (Statement stmt = con.createStatement()) { + TestUtils.freeProcCache(stmt); + } String lookupUniqueifier = UUID.randomUUID().toString(); diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 3590aa7f2..5078954c9 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -5,6 +5,8 @@ package com.microsoft.sqlserver.testframework; +import static org.junit.jupiter.api.Assertions.fail; + import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; @@ -299,19 +301,23 @@ protected static void setupConnectionString() { } protected static void setConnection() throws Exception { - setupConnectionString(); + try { + setupConnectionString(); - Assertions.assertNotNull(connectionString, TestResource.getResource("R_ConnectionStringNull")); - Class.forName(Constants.MSSQL_JDBC_PACKAGE + ".SQLServerDriver"); - if (!SQLServerDriver.isRegistered()) { - SQLServerDriver.register(); - } - if (null == connection || connection.isClosed()) { - connection = getConnection(); - } - isSqlAzureOrAzureDW(connection); + Assertions.assertNotNull(connectionString, TestResource.getResource("R_ConnectionStringNull")); + Class.forName(Constants.MSSQL_JDBC_PACKAGE + ".SQLServerDriver"); + if (!SQLServerDriver.isRegistered()) { + SQLServerDriver.register(); + } + if (null == connection || connection.isClosed()) { + connection = getConnection(); + } + isSqlAzureOrAzureDW(connection); - checkSqlOS(connection); + checkSqlOS(connection); + } catch (Exception e) { + fail("setConnection failed, connectionString=" + connectionString + "\nException: " + e.getMessage()); + } } /** From bff4507498a0bdc0a875a3958c66208a67b2e44b Mon Sep 17 00:00:00 2001 From: lilgreenbird Date: Fri, 10 May 2024 17:40:35 -0700 Subject: [PATCH 02/12] update --- .../microsoft/sqlserver/jdbc/TestUtils.java | 26 ++++++++++++++++++- .../DatabaseMetaDataTest.java | 22 ++++++++-------- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java index 4cbab64ec..8bc97ada0 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java @@ -538,6 +538,30 @@ public static void dropSchemaIfExists(String schemaName, Statement stmt) throws + "') drop schema " + AbstractSQLGenerator.escapeIdentifier(schemaName)); } + /** + * mimic "DROP USER..." + * + * @param userName + * @param stmt + * @throws SQLException + */ + public static void dropUserIfExists(String userName, Statement stmt) throws SQLException { + stmt.execute("IF EXISTS (SELECT * FROM sys.sysusers where name = '" + escapeSingleQuotes(userName) + + "') DROP USER " + AbstractSQLGenerator.escapeIdentifier(userName)); + } + + /** + * mimic "DROP LOGIN..." + * + * @param userName + * @param stmt + * @throws SQLException + */ + public static void dropLoginIfExists(String userName, Statement stmt) throws SQLException { + stmt.execute("IF EXISTS (SELECT * FROM sys.sysusers where name = '" + escapeSingleQuotes(userName) + + "') DROP LOGIN " + AbstractSQLGenerator.escapeIdentifier(userName)); + } + /** *
      * This method drops objects for below types:
@@ -1115,7 +1139,7 @@ public static String getConnectionID(
     public static void freeProcCache(Statement stmt) {
         try {
             stmt.execute("DBCC FREEPROCCACHE");
-        // ignore error
+            // ignore error - some tests fails due to permission issues from managed identity, this does not seem to affect tests
         } catch (Exception e) {}
     }
 }
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java
index 11cefbc2a..65b76ae12 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/databasemetadata/DatabaseMetaDataTest.java
@@ -62,6 +62,7 @@ public class DatabaseMetaDataTest extends AbstractTest {
     private static final String uuid = UUID.randomUUID().toString().replaceAll("-", "");
     private static final String tableName = RandomUtil.getIdentifier("DBMetadataTable");
     private static final String functionName = RandomUtil.getIdentifier("DBMetadataFunction");
+    private static final String newUserName = "newUser" + uuid;
     private static final String schema = "schema_demo" + uuid;
     private static final String escapedSchema = "schema\\_demo" + uuid;
     private static final String tableNameWithSchema = schema + ".resource";
@@ -232,14 +233,13 @@ public void testDBUserLogin() throws SQLException {
     @Test
     @Tag(Constants.xAzureSQLDW)
     public void testImpersonateGetUserName() throws SQLException {
-        String newUser = "newUser" + UUID.randomUUID();
+        String escapedNewUser = AbstractSQLGenerator.escapeIdentifier(newUserName);
 
         try (Connection conn = getConnection(); Statement stmt = conn.createStatement()) {
-            String escapedNewUser = AbstractSQLGenerator.escapeIdentifier(newUser);
             String password = "password" + UUID.randomUUID();
 
-            stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP USER"
-                    + escapedNewUser);
+            TestUtils.dropUserIfExists(newUserName, stmt);
+            TestUtils.dropLoginIfExists(newUserName, stmt);
 
             // create new user and login
             try {
@@ -253,17 +253,17 @@ public void testImpersonateGetUserName() throws SQLException {
             }
 
             DatabaseMetaData databaseMetaData = conn.getMetaData();
-            try (CallableStatement asOtherUser = conn.prepareCall("EXECUTE AS USER = '" + newUser + "'")) {
+            try (CallableStatement asOtherUser = conn.prepareCall("EXECUTE AS USER = '" + newUserName + "'")) {
                 asOtherUser.execute();
-                assertTrue(newUser.equalsIgnoreCase(databaseMetaData.getUserName()),
+                assertTrue(newUserName.equalsIgnoreCase(databaseMetaData.getUserName()),
                         TestResource.getResource("R_userNameNotMatch"));
             } catch (Exception e) {
                 fail(TestResource.getResource("R_unexpectedErrorMessage") + e.getMessage());
-            } finally {
-                stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP USER"
-                        + escapedNewUser);
-                stmt.execute("IF EXISTS (select * from sys.sysusers where name = '" + escapedNewUser + "') DROP LOGIN"
-                        + escapedNewUser);
+            }
+        } finally {
+            try (Connection conn = getConnection(); Statement stmt = conn.createStatement()) {
+                TestUtils.dropUserIfExists(newUserName, stmt);
+                TestUtils.dropLoginIfExists(newUserName, stmt);
             }
         }
     }

From ccd98d1e37d267c330b1f60cc9f2e2c9ee094807 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Fri, 10 May 2024 17:43:42 -0700
Subject: [PATCH 03/12] update

---
 .../sqlserver/jdbc/unit/statement/BatchExecutionTest.java     | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
index 5b003ef15..6f08b8618 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
@@ -184,11 +184,9 @@ public void testValidTimezonesDstTimestampBatchInsertWithBulkCopy() throws Excep
                 stmt.execute(createSql);
 
                 Timestamp timestamp = new Timestamp(ms);
-
+                
                 pstmt.setTimestamp(1, timestamp, gmtCal);
-
                 pstmt.addBatch();
-
                 pstmt.executeBatch();
             } catch (Exception e) {
                 fail(e.getMessage());

From ad9d1c4dcc75b5714cdcf5c999d2b0fb957f1824 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Mon, 13 May 2024 15:30:57 -0700
Subject: [PATCH 04/12] update

---
 .../sqlserver/jdbc/unit/statement/BatchExecutionTest.java       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
index 6f08b8618..47b771dc3 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java
@@ -184,7 +184,7 @@ public void testValidTimezonesDstTimestampBatchInsertWithBulkCopy() throws Excep
                 stmt.execute(createSql);
 
                 Timestamp timestamp = new Timestamp(ms);
-                
+
                 pstmt.setTimestamp(1, timestamp, gmtCal);
                 pstmt.addBatch();
                 pstmt.executeBatch();

From 7c962e9661a41c82d403aed2ffc5e87ba09d0f92 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Mon, 13 May 2024 18:57:29 -0700
Subject: [PATCH 05/12] update

---
 .../sqlserver/jdbc/SQLServerConnection.java   | 256 ++++++++++--------
 1 file changed, 142 insertions(+), 114 deletions(-)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
index 48275d2aa..e25bc6483 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
@@ -6030,64 +6030,66 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
         }
 
         while (true) {
-            if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) {
-                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
-                        authenticationString);
+            try {
+                if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) {
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
+                            activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
+                            authenticationString);
 
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString
-                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_MANAGED_IDENTITY.toString())) {
+                    // Break out of the retry loop in successful case.
+                    break;
+                } else if (authenticationString
+                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_MANAGED_IDENTITY.toString())) {
 
-                String managedIdentityClientId = activeConnectionProperties
-                        .getProperty(SQLServerDriverStringProperty.USER.toString());
+                    String managedIdentityClientId = activeConnectionProperties
+                            .getProperty(SQLServerDriverStringProperty.USER.toString());
+
+                    if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+                        fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
+                                managedIdentityClientId);
+                        break;
+                    }
 
-                if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
                     fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                            managedIdentityClientId);
+                            activeConnectionProperties
+                                    .getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
+
+                    // Break out of the retry loop in successful case.
                     break;
-                }
+                } else if (authenticationString
+                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString())) {
 
-                fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
+                    // aadPrincipalID and aadPrincipalSecret is deprecated replaced by username and password
+                    if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null
+                            && !aadPrincipalSecret.isEmpty()) {
+                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID,
+                                aadPrincipalSecret, authenticationString);
+                    } else {
+                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo,
+                                activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
+                                activeConnectionProperties.getProperty(
+                                        SQLServerDriverStringProperty.PASSWORD.toString()),
+                                authenticationString);
+                    }
 
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString
-                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString())) {
-
-                // aadPrincipalID and aadPrincipalSecret is deprecated replaced by username and password
-                if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null
-                        && !aadPrincipalSecret.isEmpty()) {
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID,
-                            aadPrincipalSecret, authenticationString);
-                } else {
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo,
+                    // Break out of the retry loop in successful case.
+                    break;
+                } else if (authenticationString.equalsIgnoreCase(
+                        SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL_CERTIFICATE.toString())) {
+
+                    // clientCertificate property is used to specify path to certificate file
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipalCertificate(fedAuthInfo,
                             activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
+                            servicePrincipalCertificate,
                             activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
-                            authenticationString);
-                }
-
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString
-                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL_CERTIFICATE.toString())) {
+                            servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString);
 
-                // clientCertificate property is used to specify path to certificate file
-                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipalCertificate(fedAuthInfo,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
-                        servicePrincipalCertificate,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
-                        servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString);
-
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString
-                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTEGRATED.toString())) {
-                // If operating system is windows and mssql-jdbc_auth is loaded then choose the DLL authentication.
-                if (isWindows && AuthenticationJNI.isDllLoaded()) {
-                    try {
+                    // Break out of the retry loop in successful case.
+                    break;
+                } else if (authenticationString
+                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTEGRATED.toString())) {
+                    // If operating system is windows and mssql-jdbc_auth is loaded then choose the DLL authentication.
+                    if (isWindows && AuthenticationJNI.isDllLoaded()) {
                         FedAuthDllInfo dllInfo = AuthenticationJNI.getAccessTokenForWindowsIntegrated(
                                 fedAuthInfo.stsurl, fedAuthInfo.spn, clientConnectionId.toString(),
                                 ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, 0);
@@ -6103,93 +6105,119 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
 
                         // Break out of the retry loop in successful case.
                         break;
-                    } catch (DLLException adalException) {
-
-                        // the mssql-jdbc_auth DLL return -1 for errorCategory, if unable to load the adalsql DLL
-                        int errorCategory = adalException.getCategory();
-                        if (-1 == errorCategory) {
+                    }
+                    // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix,
+                    // so we don't need to check the
+                    // OS version here.
+                    else {
+                        // Check if MSAL4J library is available
+                        if (!msalContextExists()) {
                             MessageFormat form = new MessageFormat(
-                                    SQLServerException.getErrString("R_UnableLoadADALSqlDll"));
-                            Object[] msgArgs = {Integer.toHexString(adalException.getState())};
-                            throw new SQLServerException(form.format(msgArgs), null);
+                                    SQLServerException.getErrString("R_DLLandMSALMissing"));
+                            Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
+                            throw new SQLServerException(form.format(msgArgs), null, 0, null);
                         }
+                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo,
+                                authenticationString);
+                    }
+                    // Break out of the retry loop in successful case.
+                    break;
+                } else if (authenticationString
+                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) {
+                    // interactive flow
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user,
+                            authenticationString);
+
+                    // Break out of the retry loop in successful case.
+                    break;
+                } else if (authenticationString
+                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_DEFAULT.toString())) {
+                    String managedIdentityClientId = activeConnectionProperties
+                            .getProperty(SQLServerDriverStringProperty.USER.toString());
+
+                    if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+                        fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                                managedIdentityClientId);
+                        break;
+                    }
 
-                        int millisecondsRemaining = timerRemaining(timerExpire);
-                        if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory
-                                || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) {
+                    fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                            activeConnectionProperties
+                                    .getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
 
-                            String errorStatus = Integer.toHexString(adalException.getStatus());
+                    break;
+                }
+            } catch (Exception e) {
+                int millisecondsRemaining = timerRemaining(timerExpire);
 
-                            if (connectionlogger.isLoggable(Level.FINER)) {
-                                connectionlogger.fine(
-                                        toString() + " SQLServerConnection.getFedAuthToken.AdalException category:"
-                                                + errorCategory + " error: " + errorStatus);
-                            }
+                if (e instanceof DLLException) {
+                    DLLException dllException = (DLLException) e;
+                    // the mssql-jdbc_auth DLL return -1 for errorCategory, if unable to load the adalsql DLL
+                    int errorCategory = dllException.getCategory();
+                    if (-1 == errorCategory) {
+                        MessageFormat form = new MessageFormat(
+                                SQLServerException.getErrString("R_UnableLoadADALSqlDll"));
+                        Object[] msgArgs = {Integer.toHexString(dllException.getState())};
+                        throw new SQLServerException(form.format(msgArgs), null);
+                    }
 
-                            MessageFormat form = new MessageFormat(
-                                    SQLServerException.getErrString("R_ADALAuthenticationMiddleErrorMessage"));
-                            String errorCode = Integer.toHexString(adalException.getStatus()).toUpperCase();
-                            Object[] msgArgs1 = {errorCode, adalException.getState()};
-                            SQLServerException middleException = new SQLServerException(form.format(msgArgs1),
-                                    adalException);
-
-                            form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
-                            Object[] msgArgs = {user, authenticationString};
-                            throw new SQLServerException(form.format(msgArgs), null, 0, middleException);
-                        }
+                    if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory
+                            || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) {
+
+                        String errorStatus = Integer.toHexString(dllException.getStatus());
 
                         if (connectionlogger.isLoggable(Level.FINER)) {
-                            connectionlogger.fine(toString() + " SQLServerConnection.getFedAuthToken sleeping: "
-                                    + fedauthSleepInterval + " milliseconds.");
-                            connectionlogger.fine(toString() + " SQLServerConnection.getFedAuthToken remaining: "
-                                    + millisecondsRemaining + " milliseconds.");
+                            connectionlogger
+                                    .fine(toString() + " SQLServerConnection.getFedAuthToken.AdalException category:"
+                                            + errorCategory + " error: " + errorStatus);
                         }
 
-                        sleepForInterval(fedauthSleepInterval);
-                        fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
+                        MessageFormat form = new MessageFormat(
+                                SQLServerException.getErrString("R_ADALAuthenticationMiddleErrorMessage"));
+                        String errorCode = Integer.toHexString(dllException.getStatus()).toUpperCase();
+                        Object[] msgArgs = {errorCode, dllException.getState()};
+                        SQLServerException middleException = new SQLServerException(form.format(msgArgs), dllException);
+                        throw new SQLServerException(form.format(msgArgs), null, 0, middleException);
+                    }
 
+                    if (connectionlogger.isLoggable(Level.FINER)) {
+                        connectionlogger.fine(
+                                toString() + "SQLServerConnection.getFedAuthToken sleeping: " + fedauthSleepInterval
+                                        + " milliseconds. + remaining: " + millisecondsRemaining + " milliseconds.");
                     }
-                }
-                // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix,
-                // so we don't need to check the
-                // OS version here.
-                else {
-                    // Check if MSAL4J library is available
-                    if (!msalContextExists()) {
-                        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DLLandMSALMissing"));
-                        Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
+
+                    sleepForInterval(fedauthSleepInterval);
+                    fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
+                } else if (e instanceof SQLServerException) {
+                    SQLServerException sqlServerException = (SQLServerException) e;
+                    SQLServerError sqlServerError = sqlServerException.getSQLServerError();
+                    if (!TransientError.isTransientError(sqlServerError) || timerHasExpired(timerExpire)
+                            || (fedauthSleepInterval >= millisecondsRemaining)) {
+                        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
+                        Object[] msgArgs = {user, authenticationString};
                         throw new SQLServerException(form.format(msgArgs), null, 0, null);
                     }
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString);
-                }
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString
-                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) {
-                // interactive flow
-                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user,
-                        authenticationString);
 
-                // Break out of the retry loop in successful case.
-                break;
-            } else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_DEFAULT.toString())) {
-                String managedIdentityClientId = activeConnectionProperties
-                        .getProperty(SQLServerDriverStringProperty.USER.toString());
+                    if (connectionlogger.isLoggable(Level.FINER)) {
+                        connectionlogger.fine(
+                                toString() + "SQLServerConnection.getFedAuthToken sleeping: " + fedauthSleepInterval
+                                        + " milliseconds. + remaining: " + millisecondsRemaining + " milliseconds.");
+                    }
+                    System.out.println("SQLServerConnection.getFedAuthToken sleeping: " + +fedauthSleepInterval
+                            + " milliseconds. + remaining: " + +millisecondsRemaining + " milliseconds.");
+                    sleepForInterval(fedauthSleepInterval);
+                    fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
 
-                if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-                    fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
-                            managedIdentityClientId);
-                    break;
+                } else {
+                    MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
+                    Object[] msgArgs = {user, authenticationString};
+                    throw new SQLServerException(form.format(msgArgs), null, 0, null);
                 }
-
-                fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
-
-                break;
             }
         }
 
         return fedAuthToken;
+
     }
 
     private boolean msalContextExists() {

From 16b3f125da3f14e496ffcbb1bf0007da313663c6 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Mon, 13 May 2024 21:47:26 -0700
Subject: [PATCH 06/12] update

---
 .../sqlserver/jdbc/SQLServerConnection.java   | 259 ++++++++----------
 .../sqlserver/jdbc/SQLServerMSAL4JUtils.java  | 130 ++++++++-
 .../jdbc/SQLServerSecurityUtility.java        | 127 ---------
 3 files changed, 245 insertions(+), 271 deletions(-)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
index e25bc6483..73a666596 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
@@ -6030,66 +6030,65 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
         }
 
         while (true) {
-            try {
-                if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) {
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
-                            activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
-                            authenticationString);
-
-                    // Break out of the retry loop in successful case.
-                    break;
-                } else if (authenticationString
-                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_MANAGED_IDENTITY.toString())) {
+            if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) {
+                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
+                        authenticationString);
 
-                    String managedIdentityClientId = activeConnectionProperties
-                            .getProperty(SQLServerDriverStringProperty.USER.toString());
-
-                    if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-                        fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                                managedIdentityClientId);
-                        break;
-                    }
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString
+                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_MANAGED_IDENTITY.toString())) {
 
-                    fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                            activeConnectionProperties
-                                    .getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
+                String managedIdentityClientId = activeConnectionProperties
+                        .getProperty(SQLServerDriverStringProperty.USER.toString());
 
-                    // Break out of the retry loop in successful case.
+                if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+                    fedAuthToken = SQLServerMSAL4JUtils.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
+                            managedIdentityClientId, connectRetryCount, connectRetryInterval);
                     break;
-                } else if (authenticationString
-                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString())) {
-
-                    // aadPrincipalID and aadPrincipalSecret is deprecated replaced by username and password
-                    if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null
-                            && !aadPrincipalSecret.isEmpty()) {
-                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID,
-                                aadPrincipalSecret, authenticationString);
-                    } else {
-                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo,
-                                activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
-                                activeConnectionProperties.getProperty(
-                                        SQLServerDriverStringProperty.PASSWORD.toString()),
-                                authenticationString);
-                    }
+                }
 
-                    // Break out of the retry loop in successful case.
-                    break;
-                } else if (authenticationString.equalsIgnoreCase(
-                        SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL_CERTIFICATE.toString())) {
+                fedAuthToken = SQLServerMSAL4JUtils.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()),
+                        connectRetryCount, connectRetryInterval);
 
-                    // clientCertificate property is used to specify path to certificate file
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipalCertificate(fedAuthInfo,
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString
+                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString())) {
+
+                // aadPrincipalID and aadPrincipalSecret is deprecated replaced by username and password
+                if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null
+                        && !aadPrincipalSecret.isEmpty()) {
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID,
+                            aadPrincipalSecret, authenticationString);
+                } else {
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo,
                             activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
-                            servicePrincipalCertificate,
                             activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
-                            servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString);
+                            authenticationString);
+                }
 
-                    // Break out of the retry loop in successful case.
-                    break;
-                } else if (authenticationString
-                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTEGRATED.toString())) {
-                    // If operating system is windows and mssql-jdbc_auth is loaded then choose the DLL authentication.
-                    if (isWindows && AuthenticationJNI.isDllLoaded()) {
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString
+                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL_CERTIFICATE.toString())) {
+
+                // clientCertificate property is used to specify path to certificate file
+                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipalCertificate(fedAuthInfo,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
+                        servicePrincipalCertificate,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
+                        servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString);
+
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString
+                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTEGRATED.toString())) {
+                // If operating system is windows and mssql-jdbc_auth is loaded then choose the DLL authentication.
+                if (isWindows && AuthenticationJNI.isDllLoaded()) {
+                    try {
                         FedAuthDllInfo dllInfo = AuthenticationJNI.getAccessTokenForWindowsIntegrated(
                                 fedAuthInfo.stsurl, fedAuthInfo.spn, clientConnectionId.toString(),
                                 ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, 0);
@@ -6105,119 +6104,93 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
 
                         // Break out of the retry loop in successful case.
                         break;
-                    }
-                    // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix,
-                    // so we don't need to check the
-                    // OS version here.
-                    else {
-                        // Check if MSAL4J library is available
-                        if (!msalContextExists()) {
+                    } catch (DLLException adalException) {
+
+                        // the mssql-jdbc_auth DLL return -1 for errorCategory, if unable to load the adalsql DLL
+                        int errorCategory = adalException.getCategory();
+                        if (-1 == errorCategory) {
                             MessageFormat form = new MessageFormat(
-                                    SQLServerException.getErrString("R_DLLandMSALMissing"));
-                            Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
-                            throw new SQLServerException(form.format(msgArgs), null, 0, null);
+                                    SQLServerException.getErrString("R_UnableLoadADALSqlDll"));
+                            Object[] msgArgs = {Integer.toHexString(adalException.getState())};
+                            throw new SQLServerException(form.format(msgArgs), null);
                         }
-                        fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo,
-                                authenticationString);
-                    }
-                    // Break out of the retry loop in successful case.
-                    break;
-                } else if (authenticationString
-                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) {
-                    // interactive flow
-                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user,
-                            authenticationString);
-
-                    // Break out of the retry loop in successful case.
-                    break;
-                } else if (authenticationString
-                        .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_DEFAULT.toString())) {
-                    String managedIdentityClientId = activeConnectionProperties
-                            .getProperty(SQLServerDriverStringProperty.USER.toString());
-
-                    if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-                        fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
-                                managedIdentityClientId);
-                        break;
-                    }
 
-                    fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
-                            activeConnectionProperties
-                                    .getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
+                        int millisecondsRemaining = timerRemaining(timerExpire);
+                        if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory
+                                || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) {
 
-                    break;
-                }
-            } catch (Exception e) {
-                int millisecondsRemaining = timerRemaining(timerExpire);
+                            String errorStatus = Integer.toHexString(adalException.getStatus());
 
-                if (e instanceof DLLException) {
-                    DLLException dllException = (DLLException) e;
-                    // the mssql-jdbc_auth DLL return -1 for errorCategory, if unable to load the adalsql DLL
-                    int errorCategory = dllException.getCategory();
-                    if (-1 == errorCategory) {
-                        MessageFormat form = new MessageFormat(
-                                SQLServerException.getErrString("R_UnableLoadADALSqlDll"));
-                        Object[] msgArgs = {Integer.toHexString(dllException.getState())};
-                        throw new SQLServerException(form.format(msgArgs), null);
-                    }
-
-                    if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory
-                            || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) {
+                            if (connectionlogger.isLoggable(Level.FINER)) {
+                                connectionlogger.fine(
+                                        toString() + " SQLServerConnection.getFedAuthToken.AdalException category:"
+                                                + errorCategory + " error: " + errorStatus);
+                            }
 
-                        String errorStatus = Integer.toHexString(dllException.getStatus());
+                            MessageFormat form = new MessageFormat(
+                                    SQLServerException.getErrString("R_ADALAuthenticationMiddleErrorMessage"));
+                            String errorCode = Integer.toHexString(adalException.getStatus()).toUpperCase();
+                            Object[] msgArgs1 = {errorCode, adalException.getState()};
+                            SQLServerException middleException = new SQLServerException(form.format(msgArgs1),
+                                    adalException);
+
+                            form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
+                            Object[] msgArgs = {user, authenticationString};
+                            throw new SQLServerException(form.format(msgArgs), null, 0, middleException);
+                        }
 
                         if (connectionlogger.isLoggable(Level.FINER)) {
-                            connectionlogger
-                                    .fine(toString() + " SQLServerConnection.getFedAuthToken.AdalException category:"
-                                            + errorCategory + " error: " + errorStatus);
+                            connectionlogger.fine(toString() + " SQLServerConnection.getFedAuthToken sleeping: "
+                                    + fedauthSleepInterval + " milliseconds.");
+                            connectionlogger.fine(toString() + " SQLServerConnection.getFedAuthToken remaining: "
+                                    + millisecondsRemaining + " milliseconds.");
                         }
 
-                        MessageFormat form = new MessageFormat(
-                                SQLServerException.getErrString("R_ADALAuthenticationMiddleErrorMessage"));
-                        String errorCode = Integer.toHexString(dllException.getStatus()).toUpperCase();
-                        Object[] msgArgs = {errorCode, dllException.getState()};
-                        SQLServerException middleException = new SQLServerException(form.format(msgArgs), dllException);
-                        throw new SQLServerException(form.format(msgArgs), null, 0, middleException);
-                    }
+                        sleepForInterval(fedauthSleepInterval);
+                        fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
 
-                    if (connectionlogger.isLoggable(Level.FINER)) {
-                        connectionlogger.fine(
-                                toString() + "SQLServerConnection.getFedAuthToken sleeping: " + fedauthSleepInterval
-                                        + " milliseconds. + remaining: " + millisecondsRemaining + " milliseconds.");
                     }
-
-                    sleepForInterval(fedauthSleepInterval);
-                    fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
-                } else if (e instanceof SQLServerException) {
-                    SQLServerException sqlServerException = (SQLServerException) e;
-                    SQLServerError sqlServerError = sqlServerException.getSQLServerError();
-                    if (!TransientError.isTransientError(sqlServerError) || timerHasExpired(timerExpire)
-                            || (fedauthSleepInterval >= millisecondsRemaining)) {
-                        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
-                        Object[] msgArgs = {user, authenticationString};
+                }
+                // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix,
+                // so we don't need to check the
+                // OS version here.
+                else {
+                    // Check if MSAL4J library is available
+                    if (!msalContextExists()) {
+                        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DLLandMSALMissing"));
+                        Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
                         throw new SQLServerException(form.format(msgArgs), null, 0, null);
                     }
+                    fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString);
+                }
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString
+                    .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) {
+                // interactive flow
+                fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user,
+                        authenticationString);
 
-                    if (connectionlogger.isLoggable(Level.FINER)) {
-                        connectionlogger.fine(
-                                toString() + "SQLServerConnection.getFedAuthToken sleeping: " + fedauthSleepInterval
-                                        + " milliseconds. + remaining: " + millisecondsRemaining + " milliseconds.");
-                    }
-                    System.out.println("SQLServerConnection.getFedAuthToken sleeping: " + +fedauthSleepInterval
-                            + " milliseconds. + remaining: " + +millisecondsRemaining + " milliseconds.");
-                    sleepForInterval(fedauthSleepInterval);
-                    fedauthSleepInterval = (fedauthSleepInterval < 500) ? fedauthSleepInterval * 2 : 1000;
+                // Break out of the retry loop in successful case.
+                break;
+            } else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_DEFAULT.toString())) {
+                String managedIdentityClientId = activeConnectionProperties
+                        .getProperty(SQLServerDriverStringProperty.USER.toString());
 
-                } else {
-                    MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
-                    Object[] msgArgs = {user, authenticationString};
-                    throw new SQLServerException(form.format(msgArgs), null, 0, null);
+                if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+                    fedAuthToken = SQLServerMSAL4JUtils.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                            managedIdentityClientId);
+                    break;
                 }
+
+                fedAuthToken = SQLServerMSAL4JUtils.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
+
+                break;
             }
         }
 
         return fedAuthToken;
-
     }
 
     private boolean msalContextExists() {
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
index 1ae69891b..e807f9835 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
@@ -16,10 +16,12 @@
 
 import java.security.MessageDigest;
 import java.text.MessageFormat;
-
+import java.time.Duration;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Enumeration;
 import java.util.HashSet;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -40,6 +42,12 @@
 import java.security.cert.X509Certificate;
 
 import com.microsoft.aad.msal4j.IAccount;
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.identity.DefaultAzureCredential;
+import com.azure.identity.DefaultAzureCredentialBuilder;
+import com.azure.identity.ManagedIdentityCredential;
+import com.azure.identity.ManagedIdentityCredentialBuilder;
 import com.microsoft.aad.msal4j.ClientCredentialFactory;
 import com.microsoft.aad.msal4j.ClientCredentialParameters;
 import com.microsoft.aad.msal4j.ConfidentialClientApplication;
@@ -65,6 +73,12 @@ class SQLServerMSAL4JUtils {
 
     private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap();
 
+    // Environment variable for intellij keepass database path
+    private static final String INTELLIJ_KEEPASS_PASS = "INTELLIJ_KEEPASS_PATH";
+
+    // Environment variable for additionally allowed tenants. The tenantIds are comma delimited
+    private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";
+
     private final static String LOGCONTEXT = "MSAL version "
             + com.microsoft.aad.msal4j.PublicClientApplication.class.getPackage().getImplementationVersion() + ": ";
 
@@ -433,6 +447,120 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
         }
     }
 
+    /**
+     * Get Managed Identity Authentication token through a ManagedIdentityCredential
+     * 
+     * @param resource
+     *        Token resource.
+     * @param managedIdentityClientId
+     *        Client ID of the user-assigned Managed Identity.
+     * @return fedauth token
+     * @throws SQLServerException
+     */
+    static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, String managedIdentityClientId,
+            int retryCount, int retryInterval) throws SQLServerException {
+        ManagedIdentityCredential mic = null;
+
+        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
+            logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
+        }
+
+        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+            mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).maxRetry(retryCount)
+                    .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
+        } else {
+            mic = new ManagedIdentityCredentialBuilder().maxRetry(retryCount)
+                    .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
+        }
+
+        TokenRequestContext tokenRequestContext = new TokenRequestContext();
+        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
+                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
+        tokenRequestContext.setScopes(Arrays.asList(scope));
+
+        SqlAuthenticationToken sqlFedAuthToken = null;
+
+        Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional();
+
+        if (!accessTokenOptional.isPresent()) {
+            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
+                    null);
+        } else {
+            AccessToken accessToken = accessTokenOptional.get();
+            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
+                    accessToken.getExpiresAt().toEpochSecond());
+        }
+
+        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
+            logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
+        }
+
+        return sqlFedAuthToken;
+    }
+
+    /**
+     * Get Managed Identity Authentication token through the DefaultAzureCredential
+     *
+     * @param resource
+     *        Token resource.
+     * @param managedIdentityClientId
+     *        Client ID of the user-assigned Managed Identity.
+     * @return fedauth token
+     * @throws SQLServerException
+     */
+    static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
+            String managedIdentityClientId) throws SQLServerException {
+        String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
+        String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
+
+        DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
+        DefaultAzureCredential dac = null;
+
+        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+            dacBuilder.managedIdentityClientId(managedIdentityClientId);
+        }
+
+        if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
+            dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
+        }
+
+        if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
+            dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
+        }
+
+        dac = dacBuilder.build();
+
+        TokenRequestContext tokenRequestContext = new TokenRequestContext();
+        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
+                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
+        tokenRequestContext.setScopes(Arrays.asList(scope));
+
+        SqlAuthenticationToken sqlFedAuthToken = null;
+
+        Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();
+
+        if (!accessTokenOptional.isPresent()) {
+            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
+                    null);
+        } else {
+            AccessToken accessToken = accessTokenOptional.get();
+            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
+                    accessToken.getExpiresAt().toEpochSecond());
+        }
+
+        return sqlFedAuthToken;
+    }
+
+    private static String[] getAdditonallyAllowedTenants() {
+        String additonallyAllowedTenants = System.getenv(ADDITIONALLY_ALLOWED_TENANTS);
+
+        if (null != additonallyAllowedTenants && !additonallyAllowedTenants.isEmpty()) {
+            return System.getenv(ADDITIONALLY_ALLOWED_TENANTS).split(",");
+        }
+
+        return null;
+    }
+
     // Helper function to return account containing user name from set of accounts, or null if no match
     private static IAccount getAccountByUsername(Set accounts, String username) {
         if (!accounts.isEmpty()) {
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
index 3bb2eb32d..e1e71613e 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
@@ -8,21 +8,12 @@
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.text.MessageFormat;
-import java.util.Arrays;
-import java.util.Optional;
 import java.util.Iterator;
 import java.util.List;
 
 import javax.crypto.Mac;
 import javax.crypto.spec.SecretKeySpec;
 
-import com.azure.core.credential.AccessToken;
-import com.azure.core.credential.TokenRequestContext;
-import com.azure.identity.ManagedIdentityCredential;
-import com.azure.identity.ManagedIdentityCredentialBuilder;
-import com.azure.identity.DefaultAzureCredential;
-import com.azure.identity.DefaultAzureCredentialBuilder;
-
 
 /**
  * Various SQLServer security utilities.
@@ -40,12 +31,6 @@ class SQLServerSecurityUtility {
 
     static final String WINDOWS_KEY_STORE_NAME = "MSSQL_CERTIFICATE_STORE";
 
-    // Environment variable for intellij keepass database path
-    private static final String INTELLIJ_KEEPASS_PASS = "INTELLIJ_KEEPASS_PATH";
-
-    // Environment variable for additionally allowed tenants. The tenantIds are comma delimited
-    private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";
-
     private SQLServerSecurityUtility() {
         throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
     }
@@ -318,116 +303,4 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
             throw new SQLServerException(SQLServerException.getErrString("R_VerifySignatureFailed"), null);
         }
     }
-
-    /**
-     * Get Managed Identity Authentication token through a ManagedIdentityCredential
-     * 
-     * @param resource
-     *        Token resource.
-     * @param managedIdentityClientId
-     *        Client ID of the user-assigned Managed Identity.
-     * @return fedauth token
-     * @throws SQLServerException
-     */
-    static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
-            String managedIdentityClientId) throws SQLServerException {
-        ManagedIdentityCredential mic = null;
-
-        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
-            logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
-        }
-
-        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-            mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
-        } else {
-            mic = new ManagedIdentityCredentialBuilder().build();
-        }
-
-        TokenRequestContext tokenRequestContext = new TokenRequestContext();
-        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
-                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
-        tokenRequestContext.setScopes(Arrays.asList(scope));
-
-        SqlAuthenticationToken sqlFedAuthToken = null;
-
-        Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional();
-
-        if (!accessTokenOptional.isPresent()) {
-            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
-                    null);
-        } else {
-            AccessToken accessToken = accessTokenOptional.get();
-            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
-                    accessToken.getExpiresAt().toEpochSecond());
-        }
-
-        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
-            logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
-        }
-
-        return sqlFedAuthToken;
-    }
-
-    /**
-     * Get Managed Identity Authentication token through the DefaultAzureCredential
-     *
-     * @param resource
-     *        Token resource.
-     * @param managedIdentityClientId
-     *        Client ID of the user-assigned Managed Identity.
-     * @return fedauth token
-     * @throws SQLServerException
-     */
-    static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
-            String managedIdentityClientId) throws SQLServerException {
-        String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
-        String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
-
-        DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
-        DefaultAzureCredential dac = null;
-
-        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-            dacBuilder.managedIdentityClientId(managedIdentityClientId);
-        }
-
-        if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
-            dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
-        }
-
-        if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
-            dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
-        }
-
-        dac = dacBuilder.build();
-
-        TokenRequestContext tokenRequestContext = new TokenRequestContext();
-        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
-                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
-        tokenRequestContext.setScopes(Arrays.asList(scope));
-
-        SqlAuthenticationToken sqlFedAuthToken = null;
-
-        Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();
-
-        if (!accessTokenOptional.isPresent()) {
-            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
-                    null);
-        } else {
-            AccessToken accessToken = accessTokenOptional.get();
-            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
-                    accessToken.getExpiresAt().toEpochSecond());
-        }
-
-        return sqlFedAuthToken;
-    }
-
-    private static String[] getAdditonallyAllowedTenants() {
-        String additonallyAllowedTenants = System.getenv(ADDITIONALLY_ALLOWED_TENANTS);
-
-        if (null != additonallyAllowedTenants && !additonallyAllowedTenants.isEmpty()) {
-            return System.getenv(ADDITIONALLY_ALLOWED_TENANTS).split(",");
-        }
-
-        return null;
-    }
 }

From b4bd2ba7107d2a279fd85f0da38b14ed3e463adb Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Mon, 13 May 2024 23:23:23 -0700
Subject: [PATCH 07/12] update

---
 .../com/microsoft/sqlserver/jdbc/SQLServerConnection.java     | 3 +++
 .../com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java    | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
index 73a666596..2bcf3737a 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
@@ -6043,6 +6043,9 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
                 String managedIdentityClientId = activeConnectionProperties
                         .getProperty(SQLServerDriverStringProperty.USER.toString());
 
+                // get and validate connection retry values
+                validateConnectionRetry();
+                
                 if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
                     fedAuthToken = SQLServerMSAL4JUtils.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
                             managedIdentityClientId, connectRetryCount, connectRetryInterval);
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
index e807f9835..49b999eb9 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
@@ -465,6 +465,9 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, S
             logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
         }
 
+        System.out.println("Getting Managed Identity authentication token for: " + managedIdentityClientId
+                + " retryCount=" + retryCount + " retryInterval=" + retryInterval);
+
         if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
             mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).maxRetry(retryCount)
                     .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
@@ -495,6 +498,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, S
             logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
         }
 
+        System.out.println("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
         return sqlFedAuthToken;
     }
 

From 36a07bc2d5cdf55250ca38cd3c146d56b5fa1b42 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Tue, 14 May 2024 00:44:29 -0700
Subject: [PATCH 08/12] update

---
 .../sqlserver/jdbc/SQLServerMSAL4JUtils.java   |  4 ----
 .../jdbc/SQLServerConnectionTest.java          | 13 +++++++------
 .../microsoft/sqlserver/jdbc/TestResource.java |  3 ++-
 .../sqlserver/jdbc/connection/TimeoutTest.java | 18 +++++++++---------
 4 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
index 49b999eb9..e807f9835 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
@@ -465,9 +465,6 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, S
             logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
         }
 
-        System.out.println("Getting Managed Identity authentication token for: " + managedIdentityClientId
-                + " retryCount=" + retryCount + " retryInterval=" + retryInterval);
-
         if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
             mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).maxRetry(retryCount)
                     .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
@@ -498,7 +495,6 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, S
             logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
         }
 
-        System.out.println("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
         return sqlFedAuthToken;
     }
 
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java
index bdac7c106..e024e98cb 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java
@@ -484,9 +484,10 @@ public void testConnectCountInLoginAndCorrectRetryCount() {
         } catch (Exception e) {
             assertTrue(
                     e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase"))
-                            || (TestUtils.getProperty(connectionString, "msiClientId") != null
-                                    && e.getMessage().toLowerCase()
-                                            .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())),
+                            || (TestUtils.getProperty(connectionString, "msiClientId") != null && (e.getMessage()
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase())
+                                    || e.getMessage().toLowerCase()
+                                            .contains(TestResource.getResource("R_MInotAvailable").toLowerCase()))),
                     e.getMessage());
             long totalTime = System.currentTimeMillis() - timerStart;
 
@@ -804,7 +805,7 @@ public void testIncorrectDatabase() throws SQLException {
                     e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase"))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null
                                     && e.getMessage().toLowerCase()
-                                            .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())),
+                                            .contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
                     e.getMessage());
             timerEnd = System.currentTimeMillis();
         }
@@ -837,7 +838,7 @@ public void testIncorrectUserName() throws SQLException {
                     e.getMessage().contains(TestResource.getResource("R_loginFailed"))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null
                                     && e.getMessage().toLowerCase()
-                                            .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())),
+                                            .contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
                     e.getMessage());
             timerEnd = System.currentTimeMillis();
         }
@@ -870,7 +871,7 @@ public void testIncorrectPassword() throws SQLException {
                     e.getMessage().contains(TestResource.getResource("R_loginFailed"))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null
                                     && e.getMessage().toLowerCase()
-                                            .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase())),
+                                            .contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
                     e.getMessage());
             timerEnd = System.currentTimeMillis();
         }
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
index d542434fb..cdaea574b 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
@@ -219,5 +219,6 @@ protected Object[][] getContents() {
             {"R_unexpectedThreadCount", "Thread count is higher than expected."},
             {"R_expectedClassDoesNotMatchActualClass",
                     "Expected column class {0} does not match actual column class {1} for column {2}."},
-            {"R_loginFailedMSI", "Login failed for user ''"}};
+            {"R_loginFailedMI", "Login failed for user ''"},
+            {"R_MInotAvailable", "Managed Identity authentication is not available"},};
 }
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java
index ffd9e925c..d7290b262 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/connection/TimeoutTest.java
@@ -69,7 +69,7 @@ public void testDefaultLoginTimeout() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -98,7 +98,7 @@ public void testURLLoginTimeout() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -127,7 +127,7 @@ public void testDMLoginTimeoutApplied() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -158,7 +158,7 @@ public void testDMLoginTimeoutNotApplied() {
                                 .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
                                 || (TestUtils.getProperty(connectionString, "msiClientId") != null
                                         && e.getMessage().toLowerCase()
-                                                .contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                                .contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                                 || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                         .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                         e.getMessage());
@@ -191,7 +191,7 @@ public void testConnectRetryDisable() {
             assertTrue(
                     e.getMessage().matches(TestUtils.formatErrorMsg("R_tcpipConnectionFailed"))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -219,7 +219,7 @@ public void testConnectRetryBadServer() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -255,7 +255,7 @@ public void testConnectRetryServerError() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -294,7 +294,7 @@ public void testConnectRetryServerErrorDS() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());
@@ -329,7 +329,7 @@ public void testConnectRetryTimeout() {
                     (e.getMessage().toLowerCase()
                             .contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
                             || (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
-                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMSI").toLowerCase()))
+                                    .toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
                             || ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
                                     .contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
                     e.getMessage());

From 39e010448db078d3bdc538e3ede28423f3526097 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Tue, 14 May 2024 12:37:57 -0700
Subject: [PATCH 09/12] update

---
 .../sqlserver/jdbc/SQLServerMSAL4JUtils.java  | 130 +-----------------
 .../jdbc/SQLServerSecurityUtility.java        | 127 +++++++++++++++++
 2 files changed, 128 insertions(+), 129 deletions(-)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
index e807f9835..1ae69891b 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
@@ -16,12 +16,10 @@
 
 import java.security.MessageDigest;
 import java.text.MessageFormat;
-import java.time.Duration;
-import java.util.Arrays;
+
 import java.util.Collections;
 import java.util.Enumeration;
 import java.util.HashSet;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -42,12 +40,6 @@
 import java.security.cert.X509Certificate;
 
 import com.microsoft.aad.msal4j.IAccount;
-import com.azure.core.credential.AccessToken;
-import com.azure.core.credential.TokenRequestContext;
-import com.azure.identity.DefaultAzureCredential;
-import com.azure.identity.DefaultAzureCredentialBuilder;
-import com.azure.identity.ManagedIdentityCredential;
-import com.azure.identity.ManagedIdentityCredentialBuilder;
 import com.microsoft.aad.msal4j.ClientCredentialFactory;
 import com.microsoft.aad.msal4j.ClientCredentialParameters;
 import com.microsoft.aad.msal4j.ConfidentialClientApplication;
@@ -73,12 +65,6 @@ class SQLServerMSAL4JUtils {
 
     private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap();
 
-    // Environment variable for intellij keepass database path
-    private static final String INTELLIJ_KEEPASS_PASS = "INTELLIJ_KEEPASS_PATH";
-
-    // Environment variable for additionally allowed tenants. The tenantIds are comma delimited
-    private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";
-
     private final static String LOGCONTEXT = "MSAL version "
             + com.microsoft.aad.msal4j.PublicClientApplication.class.getPackage().getImplementationVersion() + ": ";
 
@@ -447,120 +433,6 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
         }
     }
 
-    /**
-     * Get Managed Identity Authentication token through a ManagedIdentityCredential
-     * 
-     * @param resource
-     *        Token resource.
-     * @param managedIdentityClientId
-     *        Client ID of the user-assigned Managed Identity.
-     * @return fedauth token
-     * @throws SQLServerException
-     */
-    static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, String managedIdentityClientId,
-            int retryCount, int retryInterval) throws SQLServerException {
-        ManagedIdentityCredential mic = null;
-
-        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
-            logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
-        }
-
-        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-            mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).maxRetry(retryCount)
-                    .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
-        } else {
-            mic = new ManagedIdentityCredentialBuilder().maxRetry(retryCount)
-                    .retryTimeout(duration -> Duration.ofSeconds(retryInterval)).build();
-        }
-
-        TokenRequestContext tokenRequestContext = new TokenRequestContext();
-        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
-                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
-        tokenRequestContext.setScopes(Arrays.asList(scope));
-
-        SqlAuthenticationToken sqlFedAuthToken = null;
-
-        Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional();
-
-        if (!accessTokenOptional.isPresent()) {
-            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
-                    null);
-        } else {
-            AccessToken accessToken = accessTokenOptional.get();
-            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
-                    accessToken.getExpiresAt().toEpochSecond());
-        }
-
-        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
-            logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
-        }
-
-        return sqlFedAuthToken;
-    }
-
-    /**
-     * Get Managed Identity Authentication token through the DefaultAzureCredential
-     *
-     * @param resource
-     *        Token resource.
-     * @param managedIdentityClientId
-     *        Client ID of the user-assigned Managed Identity.
-     * @return fedauth token
-     * @throws SQLServerException
-     */
-    static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
-            String managedIdentityClientId) throws SQLServerException {
-        String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
-        String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
-
-        DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
-        DefaultAzureCredential dac = null;
-
-        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-            dacBuilder.managedIdentityClientId(managedIdentityClientId);
-        }
-
-        if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
-            dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
-        }
-
-        if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
-            dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
-        }
-
-        dac = dacBuilder.build();
-
-        TokenRequestContext tokenRequestContext = new TokenRequestContext();
-        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
-                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
-        tokenRequestContext.setScopes(Arrays.asList(scope));
-
-        SqlAuthenticationToken sqlFedAuthToken = null;
-
-        Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();
-
-        if (!accessTokenOptional.isPresent()) {
-            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
-                    null);
-        } else {
-            AccessToken accessToken = accessTokenOptional.get();
-            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
-                    accessToken.getExpiresAt().toEpochSecond());
-        }
-
-        return sqlFedAuthToken;
-    }
-
-    private static String[] getAdditonallyAllowedTenants() {
-        String additonallyAllowedTenants = System.getenv(ADDITIONALLY_ALLOWED_TENANTS);
-
-        if (null != additonallyAllowedTenants && !additonallyAllowedTenants.isEmpty()) {
-            return System.getenv(ADDITIONALLY_ALLOWED_TENANTS).split(",");
-        }
-
-        return null;
-    }
-
     // Helper function to return account containing user name from set of accounts, or null if no match
     private static IAccount getAccountByUsername(Set accounts, String username) {
         if (!accounts.isEmpty()) {
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
index e1e71613e..3bb2eb32d 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
@@ -8,12 +8,21 @@
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.text.MessageFormat;
+import java.util.Arrays;
+import java.util.Optional;
 import java.util.Iterator;
 import java.util.List;
 
 import javax.crypto.Mac;
 import javax.crypto.spec.SecretKeySpec;
 
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.identity.ManagedIdentityCredential;
+import com.azure.identity.ManagedIdentityCredentialBuilder;
+import com.azure.identity.DefaultAzureCredential;
+import com.azure.identity.DefaultAzureCredentialBuilder;
+
 
 /**
  * Various SQLServer security utilities.
@@ -31,6 +40,12 @@ class SQLServerSecurityUtility {
 
     static final String WINDOWS_KEY_STORE_NAME = "MSSQL_CERTIFICATE_STORE";
 
+    // Environment variable for intellij keepass database path
+    private static final String INTELLIJ_KEEPASS_PASS = "INTELLIJ_KEEPASS_PATH";
+
+    // Environment variable for additionally allowed tenants. The tenantIds are comma delimited
+    private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";
+
     private SQLServerSecurityUtility() {
         throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
     }
@@ -303,4 +318,116 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
             throw new SQLServerException(SQLServerException.getErrString("R_VerifySignatureFailed"), null);
         }
     }
+
+    /**
+     * Get Managed Identity Authentication token through a ManagedIdentityCredential
+     * 
+     * @param resource
+     *        Token resource.
+     * @param managedIdentityClientId
+     *        Client ID of the user-assigned Managed Identity.
+     * @return fedauth token
+     * @throws SQLServerException
+     */
+    static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
+            String managedIdentityClientId) throws SQLServerException {
+        ManagedIdentityCredential mic = null;
+
+        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
+            logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
+        }
+
+        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+            mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
+        } else {
+            mic = new ManagedIdentityCredentialBuilder().build();
+        }
+
+        TokenRequestContext tokenRequestContext = new TokenRequestContext();
+        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
+                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
+        tokenRequestContext.setScopes(Arrays.asList(scope));
+
+        SqlAuthenticationToken sqlFedAuthToken = null;
+
+        Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional();
+
+        if (!accessTokenOptional.isPresent()) {
+            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
+                    null);
+        } else {
+            AccessToken accessToken = accessTokenOptional.get();
+            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
+                    accessToken.getExpiresAt().toEpochSecond());
+        }
+
+        if (logger.isLoggable(java.util.logging.Level.FINEST)) {
+            logger.finest("Got fedAuth token, expiry: " + sqlFedAuthToken.getExpiresOn().toString());
+        }
+
+        return sqlFedAuthToken;
+    }
+
+    /**
+     * Get Managed Identity Authentication token through the DefaultAzureCredential
+     *
+     * @param resource
+     *        Token resource.
+     * @param managedIdentityClientId
+     *        Client ID of the user-assigned Managed Identity.
+     * @return fedauth token
+     * @throws SQLServerException
+     */
+    static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
+            String managedIdentityClientId) throws SQLServerException {
+        String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
+        String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
+
+        DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
+        DefaultAzureCredential dac = null;
+
+        if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
+            dacBuilder.managedIdentityClientId(managedIdentityClientId);
+        }
+
+        if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
+            dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
+        }
+
+        if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
+            dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
+        }
+
+        dac = dacBuilder.build();
+
+        TokenRequestContext tokenRequestContext = new TokenRequestContext();
+        String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
+                + SQLServerMSAL4JUtils.SLASH_DEFAULT;
+        tokenRequestContext.setScopes(Arrays.asList(scope));
+
+        SqlAuthenticationToken sqlFedAuthToken = null;
+
+        Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();
+
+        if (!accessTokenOptional.isPresent()) {
+            throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
+                    null);
+        } else {
+            AccessToken accessToken = accessTokenOptional.get();
+            sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
+                    accessToken.getExpiresAt().toEpochSecond());
+        }
+
+        return sqlFedAuthToken;
+    }
+
+    private static String[] getAdditonallyAllowedTenants() {
+        String additonallyAllowedTenants = System.getenv(ADDITIONALLY_ALLOWED_TENANTS);
+
+        if (null != additonallyAllowedTenants && !additonallyAllowedTenants.isEmpty()) {
+            return System.getenv(ADDITIONALLY_ALLOWED_TENANTS).split(",");
+        }
+
+        return null;
+    }
 }

From e7b89f4ad0855ca8fa53dac996b9ee4263322e77 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Tue, 14 May 2024 12:39:45 -0700
Subject: [PATCH 10/12] update

---
 .../sqlserver/jdbc/SQLServerConnection.java      | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
index 2bcf3737a..48275d2aa 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
@@ -6043,18 +6043,14 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
                 String managedIdentityClientId = activeConnectionProperties
                         .getProperty(SQLServerDriverStringProperty.USER.toString());
 
-                // get and validate connection retry values
-                validateConnectionRetry();
-                
                 if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-                    fedAuthToken = SQLServerMSAL4JUtils.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                            managedIdentityClientId, connectRetryCount, connectRetryInterval);
+                    fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
+                            managedIdentityClientId);
                     break;
                 }
 
-                fedAuthToken = SQLServerMSAL4JUtils.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
-                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()),
-                        connectRetryCount, connectRetryInterval);
+                fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn,
+                        activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
 
                 // Break out of the retry loop in successful case.
                 break;
@@ -6181,12 +6177,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw
                         .getProperty(SQLServerDriverStringProperty.USER.toString());
 
                 if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
-                    fedAuthToken = SQLServerMSAL4JUtils.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                    fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
                             managedIdentityClientId);
                     break;
                 }
 
-                fedAuthToken = SQLServerMSAL4JUtils.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
+                fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
                         activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
 
                 break;

From 5c88f0fd89b027f077f7b3047d79ae6f46a6f6ea Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Tue, 14 May 2024 16:48:01 -0700
Subject: [PATCH 11/12] update

---
 src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
index 8bc97ada0..b17e86c0d 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
@@ -535,7 +535,7 @@ public static void dropDatabaseIfExists(String databaseName, String connectionSt
      */
     public static void dropSchemaIfExists(String schemaName, Statement stmt) throws SQLException {
         stmt.execute("if EXISTS (SELECT * FROM sys.schemas where name = '" + escapeSingleQuotes(schemaName)
-                + "') drop schema " + AbstractSQLGenerator.escapeIdentifier(schemaName));
+                + "') DROP SCHEMA" + AbstractSQLGenerator.escapeIdentifier(schemaName));
     }
 
     /**
@@ -1139,7 +1139,8 @@ public static String getConnectionID(
     public static void freeProcCache(Statement stmt) {
         try {
             stmt.execute("DBCC FREEPROCCACHE");
+        } catch (Exception e) {
             // ignore error - some tests fails due to permission issues from managed identity, this does not seem to affect tests
-        } catch (Exception e) {}
+        }
     }
 }

From 9ccf88d341b7e5b58566f9dfa8e4f6ff2de4c870 Mon Sep 17 00:00:00 2001
From: lilgreenbird 
Date: Tue, 21 May 2024 14:45:55 -0700
Subject: [PATCH 12/12] update

---
 src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java
index 64d1ebdf1..2e80407d3 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/fips/FipsTest.java
@@ -69,7 +69,6 @@ public void fipsEncryptTest() throws Exception {
 
         Properties props = buildConnectionProperties();
         props.setProperty(Constants.ENCRYPT, Boolean.FALSE.toString());
-        System.out.println("fipsEncryptTest connectionString=" + connectionString);
         try (Connection con = PrepUtil.getConnection(connectionString, props)) {
             Assertions.fail(TestResource.getResource("R_expectedExceptionNotThrown"));
         } catch (SQLException e) {