Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for running tests with managed identity #2416

Merged
merged 14 commits into from
May 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ protected static void createTable(String tableName, String cekName, String table
TestUtils.dropTableIfExists(tableName, stmt);
sql = String.format(createSql, tableName, sql);
stmt.execute(sql);
stmt.execute("DBCC FREEPROCCACHE");
TestUtils.freeProcCache(stmt);
} catch (SQLException e) {
fail(e.getMessage());
}
Expand Down Expand Up @@ -373,7 +373,7 @@ protected static void createPrecisionTable(String tableName, String table[][], S
}
sql = String.format(createSql, tableName, sql);
stmt.execute(sql);
stmt.execute("DBCC FREEPROCCACHE");
TestUtils.freeProcCache(stmt);
} catch (SQLException e) {
fail(e.getMessage());
}
Expand Down Expand Up @@ -401,7 +401,7 @@ protected static void createScaleTable(String tableName, String table[][], Strin

sql = String.format(createSql, tableName, sql);
stmt.execute(sql);
stmt.execute("DBCC FREEPROCCACHE");
TestUtils.freeProcCache(stmt);
} catch (SQLException e) {
fail(e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2201,7 +2201,7 @@ protected static void createDateTableCallableStatement(String cekName) throws SQ
SQLServerStatement stmt = (SQLServerStatement) con.createStatement()) {
TestUtils.dropTableIfExists(DATE_TABLE_AE, stmt);
stmt.execute(sql);
stmt.execute("DBCC FREEPROCCACHE");
TestUtils.freeProcCache(stmt);
} catch (SQLException e) {
fail(e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,13 @@ public void testConnectCountInLoginAndCorrectRetryCount() {
assertTrue(con == null, TestResource.getResource("R_shouldNotConnect"));
}
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")), e.getMessage());
assertTrue(
e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase"))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && (e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase())
|| e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_MInotAvailable").toLowerCase()))),
e.getMessage());
long totalTime = System.currentTimeMillis() - timerStart;

// Maximum is unknown, but is needs to be less than longLoginTimeout or else this is an issue.
Expand Down Expand Up @@ -795,13 +801,22 @@ public void testIncorrectDatabase() throws SQLException {
assertTrue(timeDiff <= milsecs, form.format(msgArgs));
}
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase")), e.getMessage());
assertTrue(
e.getMessage().contains(TestResource.getResource("R_cannotOpenDatabase"))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null
&& e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
e.getMessage());
timerEnd = System.currentTimeMillis();
}
}

@Test
public void testIncorrectUserName() throws SQLException {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

long timerStart = 0;
long timerEnd = 0;
final long milsecs = threshHoldForNoRetryInMilliseconds;
Expand All @@ -819,13 +834,22 @@ public void testIncorrectUserName() throws SQLException {
assertTrue(timeDiff <= milsecs, form.format(msgArgs));
}
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_loginFailed")));
assertTrue(
e.getMessage().contains(TestResource.getResource("R_loginFailed"))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null
&& e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
e.getMessage());
timerEnd = System.currentTimeMillis();
}
}

@Test
public void testIncorrectPassword() throws SQLException {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

long timerStart = 0;
long timerEnd = 0;
final long milsecs = threshHoldForNoRetryInMilliseconds;
Expand All @@ -843,7 +867,12 @@ public void testIncorrectPassword() throws SQLException {
assertTrue(timeDiff <= milsecs, form.format(msgArgs));
}
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_loginFailed")));
assertTrue(
e.getMessage().contains(TestResource.getResource("R_loginFailed"))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null
&& e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_loginFailedMI").toLowerCase())),
e.getMessage());
timerEnd = System.currentTimeMillis();
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,6 @@ protected Object[][] getContents() {
{"R_unexpectedThreadCount", "Thread count is higher than expected."},
{"R_expectedClassDoesNotMatchActualClass",
"Expected column class {0} does not match actual column class {1} for column {2}."},
{"R_loginFailedMSI", "Login failed for user '<token-identified principal>'"}};
{"R_loginFailedMI", "Login failed for user '<token-identified principal>'"},
{"R_MInotAvailable", "Managed Identity authentication is not available"},};
}
10 changes: 9 additions & 1 deletion src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ public static void dropDatabaseIfExists(String databaseName, String connectionSt
*/
public static void dropSchemaIfExists(String schemaName, Statement stmt) throws SQLException {
stmt.execute("if EXISTS (SELECT * FROM sys.schemas where name = '" + escapeSingleQuotes(schemaName)
+ "') DROP SCHEMA " + AbstractSQLGenerator.escapeIdentifier(schemaName));
+ "') DROP SCHEMA" + AbstractSQLGenerator.escapeIdentifier(schemaName));
}

/**
Expand Down Expand Up @@ -1135,4 +1135,12 @@ public static String getConnectionID(
SQLServerConnection conn = (SQLServerConnection) physicalConnection.get(pc);
return (String) traceID.get(conn);
}

public static void freeProcCache(Statement stmt) {
try {
stmt.execute("DBCC FREEPROCCACHE");
} catch (Exception e) {
// ignore error - some tests fails due to permission issues from managed identity, this does not seem to affect tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ public void testConnectionPoolClose() throws SQLException {

@Test
public void testConnectionPoolClientConnectionId() throws SQLException {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

SQLServerXADataSource ds = new SQLServerXADataSource();
ds.setURL(connectionString);
PooledConnection pc = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ public void testDefaultLoginTimeout() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down Expand Up @@ -95,6 +97,8 @@ public void testURLLoginTimeout() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down Expand Up @@ -122,6 +126,8 @@ public void testDMLoginTimeoutApplied() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down Expand Up @@ -150,6 +156,9 @@ public void testDMLoginTimeoutNotApplied() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null
&& e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down Expand Up @@ -181,6 +190,8 @@ public void testConnectRetryDisable() {

assertTrue(
e.getMessage().matches(TestUtils.formatErrorMsg("R_tcpipConnectionFailed"))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand All @@ -207,6 +218,8 @@ public void testConnectRetryBadServer() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_tcpipConnectionToHost").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand All @@ -220,6 +233,10 @@ public void testConnectRetryBadServer() {
// Test connect retry for database error
@Test
public void testConnectRetryServerError() {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

long totalTime = 0;
long timerStart = System.currentTimeMillis();
int interval = defaultTimeout; // long interval so we can tell if there was a retry
Expand All @@ -237,6 +254,8 @@ public void testConnectRetryServerError() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand All @@ -252,6 +271,10 @@ public void testConnectRetryServerError() {
// Test connect retry for database error using Datasource
@Test
public void testConnectRetryServerErrorDS() {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

long totalTime = 0;
long timerStart = System.currentTimeMillis();
int interval = defaultTimeout; // long interval so we can tell if there was a retry
Expand All @@ -270,6 +293,8 @@ public void testConnectRetryServerErrorDS() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down Expand Up @@ -303,6 +328,8 @@ public void testConnectRetryTimeout() {
assertTrue(
(e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_cannotOpenDatabase").toLowerCase()))
|| (TestUtils.getProperty(connectionString, "msiClientId") != null && e.getMessage()
.toLowerCase().contains(TestResource.getResource("R_loginFailedMI").toLowerCase()))
|| ((isSqlAzure() || isSqlAzureDW()) ? e.getMessage().toLowerCase()
.contains(TestResource.getResource("R_connectTimedOut").toLowerCase()) : false),
e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ public void testGetURL() throws SQLException {
*/
@Test
public void testDBUserLogin() throws SQLException {
String auth = TestUtils.getProperty(connectionString, "authentication");
org.junit.Assume.assumeTrue(auth != null
&& (auth.equalsIgnoreCase("SqlPassword") || auth.equalsIgnoreCase("ActiveDirectoryPassword")));

try (Connection conn = getConnection()) {
DatabaseMetaData databaseMetaData = conn.getMetaData();
String connectionString = getConnectionString();
Expand All @@ -219,7 +223,8 @@ public void testDBUserLogin() throws SQLException {

assertNotNull(userName, TestResource.getResource("R_userNameNull"));
assertTrue(userName.equalsIgnoreCase(userFromConnectionString),
TestResource.getResource("R_userNameNotMatch"));
TestResource.getResource("R_userNameNotMatch") + "userName: " + userName + "from connectio string: "
+ userFromConnectionString);
} catch (Exception e) {
fail(TestResource.getResource("R_unexpectedErrorMessage") + e.getMessage());
}
Expand Down Expand Up @@ -1049,8 +1054,6 @@ public static void terminate() throws SQLException {
TestUtils.dropTableWithSchemaIfExists(tableNameWithSchema, stmt);
TestUtils.dropProcedureWithSchemaIfExists(sprocWithSchema, stmt);
TestUtils.dropSchemaIfExists(schema, stmt);
TestUtils.dropUserIfExists(newUserName, stmt);
TestUtils.dropLoginIfExists(newUserName, stmt);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ private void createCMK(String cmkName, String keyStoreName, String keyPath, Stat
private void callDbccFreeProcCache() throws SQLException {
try (Connection connection = DriverManager.getConnection(adPasswordConnectionStr);
Statement stmt = connection.createStatement()) {
stmt.execute("DBCC FREEPROCCACHE");
TestUtils.freeProcCache(stmt);
}
}

Expand Down
Loading
Loading