diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 1ae69891b..9bb05a5a9 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; @@ -89,11 +90,21 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str try { String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); - PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret); + PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, + hashedSecret); + // check if account password was changed if (null == persistentTokenCacheAccessAspect) { persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect(); TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect); + + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": cache token for user: " + user); + } + } else { + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": retrieved cached token for user: " + user); + } } final PublicClientApplication pca = PublicClientApplication @@ -145,11 +156,21 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth try { String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); - PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret); + PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, + hashedSecret); + // check if principal secret was changed if (null == persistentTokenCacheAccessAspect) { persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect(); TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect); + + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": cache token for principal id: " + aadPrincipalID); + } + } else { + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": retrieved cached token for principal id: " + aadPrincipalID); + } } IClientCredential credential = ClientCredentialFactory.createFromSecret(aadPrincipalSecret); @@ -202,11 +223,21 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI try { String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); - PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret); + PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, + hashedSecret); + // check if cert was changed if (null == persistentTokenCacheAccessAspect) { persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect(); TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect); + + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": cache token for principal id: " + aadPrincipalID); + } + } else { + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": retrieved cached token for principal id: " + aadPrincipalID); + } } ConfidentialClientApplication clientApplication = null; @@ -493,18 +524,25 @@ private static String getHashedSecret(String[] secrets) throws SQLServerExceptio private static class TokenCacheMap { private ConcurrentHashMap tokenCacheMap = new ConcurrentHashMap<>(); - PersistentTokenCacheAccessAspect getEntry(String key) { + PersistentTokenCacheAccessAspect getEntry(String value, String key) { PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = tokenCacheMap.get(key); if (null != persistentTokenCacheAccessAspect) { - if (System.currentTimeMillis() > persistentTokenCacheAccessAspect.getExpiryTime()) { + long currentTime = System.currentTimeMillis(); + + if (currentTime > persistentTokenCacheAccessAspect.getExpiryTime()) { tokenCacheMap.remove(key); persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect(); persistentTokenCacheAccessAspect - .setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE); + .setExpiryTime(currentTime + PersistentTokenCacheAccessAspect.TIME_TO_LIVE); tokenCacheMap.put(key, persistentTokenCacheAccessAspect); + + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": entry expired for: " + value + " new entry will expire in: " + + TimeUnit.MILLISECONDS.toSeconds(PersistentTokenCacheAccessAspect.TIME_TO_LIVE) + "s"); + } } } @@ -514,6 +552,10 @@ PersistentTokenCacheAccessAspect getEntry(String key) { void addEntry(String key, PersistentTokenCacheAccessAspect value) { value.setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE); tokenCacheMap.put(key, value); + if (logger.isLoggable(Level.FINEST)) { + logger.finest(LOGCONTEXT + ": add entry for: " + value + ", will expire in: " + + TimeUnit.MILLISECONDS.toSeconds(PersistentTokenCacheAccessAspect.TIME_TO_LIVE) + "s"); + } } } }