AzureTokenProvider.java
package no.nav.data.common.security.azure;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.microsoft.aad.msal4j.*;
import com.microsoft.graph.serviceclient.GraphServiceClient;
import com.microsoft.kiota.authentication.AccessTokenProvider;
import com.microsoft.kiota.authentication.AllowedHostsValidator;
import com.microsoft.kiota.authentication.BaseBearerTokenAuthenticationProvider;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import io.prometheus.client.Summary;
import lombok.extern.slf4j.Slf4j;
import no.nav.data.common.exceptions.TechnicalException;
import no.nav.data.common.security.AuthService;
import no.nav.data.common.security.Encryptor;
import no.nav.data.common.security.TokenProvider;
import no.nav.data.common.security.azure.support.AuthResultExpiry;
import no.nav.data.common.security.domain.Auth;
import no.nav.data.common.security.dto.Credential;
import no.nav.data.common.security.dto.OAuthState;
import no.nav.data.common.utils.Constants;
import no.nav.data.common.utils.MetricUtils;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URL;
import java.time.Duration;
import java.util.Map;
import java.util.Set;
import static java.util.Objects.requireNonNull;
import static no.nav.data.common.security.SecurityConstants.SESS_ID_LEN;
import static no.nav.data.common.security.SecurityConstants.TOKEN_TYPE;
import static no.nav.data.common.security.azure.AzureConstants.MICROSOFT_GRAPH_SCOPES;
import static no.nav.data.common.security.azure.AzureConstants.MICROSOFT_GRAPH_SCOPE_APP;
@Slf4j
@Service
public class AzureTokenProvider implements TokenProvider {
private final Cache<String, IAuthenticationResult> accessTokenCache;
private final IConfidentialClientApplication msalClient;
private final AuthService authService;
private final AADAuthenticationProperties aadAuthProps;
private final Encryptor encryptor;
private final Summary tokenMetrics;
private CustomAccessTokenProvider customAccessTokenProvider;
public AzureTokenProvider(AADAuthenticationProperties aadAuthProps,
IConfidentialClientApplication msalClient,
AuthService authService, Encryptor encryptor) {
this.aadAuthProps = aadAuthProps;
this.msalClient = msalClient;
this.authService = authService;
this.encryptor = encryptor;
this.tokenMetrics = MetricUtils.summary()
.labels("accessToken").labels("lookupGrantedAuthorities")
.labelNames("action")
.name(Constants.APP_ID.replace('-', '_') + "_token_summary")
.help("Time taken for azure token lookups")
.quantile(.5, .01).quantile(.9, .01).quantile(.99, .001)
.maxAgeSeconds(Duration.ofHours(24).getSeconds())
.ageBuckets(8)
.register();
this.accessTokenCache = Caffeine.newBuilder().recordStats()
.expireAfter(new AuthResultExpiry())
.maximumSize(2000).build();
MetricUtils.register("accessTokenCache", accessTokenCache);
customAccessTokenProvider = new CustomAccessTokenProvider(this);
}
GraphServiceClient getGraphClient() {
return new GraphServiceClient(new BaseBearerTokenAuthenticationProvider(customAccessTokenProvider));
}
@Override
public String getConsumerToken(String resource) {
return Credential.getCredential()
.filter(Credential::hasAuth)
.map(cred -> TOKEN_TYPE + getAccessTokenForResource(cred.getAuth().decryptRefreshToken(), resource))
.orElseGet(() -> TOKEN_TYPE + getApplicationTokenForResource(resource));
}
public Auth getAuth(String session) {
Assert.isTrue(session.length() > SESS_ID_LEN, "invalid session");
var sessionId = session.substring(0, SESS_ID_LEN);
var sessionKey = session.substring(SESS_ID_LEN);
var auth = authService.getAuth(sessionId, sessionKey);
try {
String accessToken = getAccessTokenForResource(auth.decryptRefreshToken(), resourceForAppId());
auth.addAccessToken(accessToken);
return auth;
} catch (RuntimeException e) {
throw new TechnicalException("Failed to get access token for userId=%s initiated=%s".formatted(auth.getUserId(), auth.getInitiated()));
}
}
@Override
public void destroySession() {
Credential.getCredential().map(Credential::getAuth).ifPresent(auth -> authService.endSession(auth.getId()));
}
@Override
public String createAuthRequestRedirectUrl(String postLoginRedirectUri, String postLoginErrorUri, String redirectUri) {
var auth = authService.createAuth();
var codeVerifier = auth.getCodeVerifier();
var s256 = DigestUtils.sha256(codeVerifier);
var codeChallenge = Base64.encodeBase64URLSafeString(s256);
URL url = msalClient.getAuthorizationRequestUrl(AuthorizationRequestUrlParameters
.builder(redirectUri, MICROSOFT_GRAPH_SCOPES)
.state(new OAuthState(auth.getId().toString(), postLoginRedirectUri, postLoginErrorUri).toJson(encryptor))
.responseMode(ResponseMode.FORM_POST)
.codeChallengeMethod(CodeChallengeMethod.S256.getValue())
.codeChallenge(codeChallenge)
.build());
return url.toString();
}
@Override
public String createSession(String sessionId, String code, String redirectUri) {
try {
log.debug("Looking up token for auth code");
var codeVerifier = authService.getCodeVerifier(sessionId);
var authResult = msalClient.acquireToken(AuthorizationCodeParameters
.builder(code, new URI(redirectUri))
.scopes(MICROSOFT_GRAPH_SCOPES)
.codeVerifier(codeVerifier)
.build()).get();
String userId = StringUtils.substringBefore(authResult.account().homeAccountId(), ".");
String refreshToken = getRefreshTokenFromAuthResult(authResult);
return authService.initAuth(userId, refreshToken, sessionId);
} catch (Exception e) {
log.error("Failed to get token for auth code", e);
throw new TechnicalException("Failed to get token for auth code", e);
}
}
private String getRefreshTokenFromAuthResult(IAuthenticationResult authResult) throws ClassNotFoundException, IllegalAccessException, InvocationTargetException {
// interface is missing refreshToken...
Method refreshTokenMethod = ReflectionUtils.findMethod(Class.forName("com.microsoft.aad.msal4j.AuthenticationResult"), "refreshToken");
Assert.notNull(refreshTokenMethod, "couldn't find refreshToken method");
refreshTokenMethod.setAccessible(true);
return (String) refreshTokenMethod.invoke(authResult);
}
private String resourceForAppId() {
return aadAuthProps.getClientId() + "/.default";
}
String getApplicationTokenForResource(String resource) {
log.trace("Getting application token for resource {}", resource);
return requireNonNull(accessTokenCache.get("credential" + resource, cacheKey -> acquireTokenByCredential(resource))).accessToken();
}
private String getAccessTokenForResource(String refreshToken, String resource) {
log.trace("Getting access token for resource {}", resource);
return requireNonNull(accessTokenCache.get("refresh" + refreshToken + resource, cacheKey -> acquireTokenByRefreshToken(refreshToken, resource))).accessToken();
}
private IAuthenticationResult acquireTokenByRefreshToken(String refreshToken, String resource) {
try (var ignored = tokenMetrics.labels("accessToken").startTimer()) {
log.debug("Looking up access token for resource {}", resource);
return msalClient.acquireToken(RefreshTokenParameters.builder(Set.of(resource), refreshToken).build()).get();
} catch (Exception e) {
throw new TechnicalException("Failed to get access token for refreshToken", e);
}
}
/**
* access token for app user
*/
private IAuthenticationResult acquireTokenByCredential(String resource) {
try {
log.debug("Looking up application token for resource {}", resource);
return msalClient.acquireToken(ClientCredentialParameters.builder(Set.of(resource)).build()).get();
} catch (Exception e) {
throw new TechnicalException("Failed to get access token for credential", e);
}
}
class CustomAccessTokenProvider implements AccessTokenProvider {
AzureTokenProvider tokenProvider;
public CustomAccessTokenProvider(AzureTokenProvider tokenProvider) {
this.tokenProvider = tokenProvider;
}
@Override
public String getAuthorizationToken(URI uri, Map<String, Object> additionalAuthenticationContex) {
return tokenProvider.getApplicationTokenForResource(MICROSOFT_GRAPH_SCOPE_APP);
}
// Make sure to have the right set of hosts
private final AllowedHostsValidator validator = new AllowedHostsValidator("graph.microsoft.com");
@Override
public AllowedHostsValidator getAllowedHostsValidator() {
// Handle allowed hosts validation logic here
return validator;
}
}
}