AzureTokenProvider.java

  1. package no.nav.data.common.security.azure;

  2. import com.github.benmanes.caffeine.cache.Cache;
  3. import com.github.benmanes.caffeine.cache.Caffeine;
  4. import com.microsoft.aad.msal4j.*;
  5. import com.microsoft.graph.serviceclient.GraphServiceClient;
  6. import com.microsoft.kiota.authentication.AccessTokenProvider;
  7. import com.microsoft.kiota.authentication.AllowedHostsValidator;
  8. import com.microsoft.kiota.authentication.BaseBearerTokenAuthenticationProvider;
  9. import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
  10. import io.prometheus.client.Summary;
  11. import lombok.extern.slf4j.Slf4j;
  12. import no.nav.data.common.exceptions.TechnicalException;
  13. import no.nav.data.common.security.AuthService;
  14. import no.nav.data.common.security.Encryptor;
  15. import no.nav.data.common.security.TokenProvider;
  16. import no.nav.data.common.security.azure.support.AuthResultExpiry;
  17. import no.nav.data.common.security.domain.Auth;
  18. import no.nav.data.common.security.dto.Credential;
  19. import no.nav.data.common.security.dto.OAuthState;
  20. import no.nav.data.common.utils.Constants;
  21. import no.nav.data.common.utils.MetricUtils;
  22. import org.apache.commons.codec.binary.Base64;
  23. import org.apache.commons.codec.digest.DigestUtils;
  24. import org.apache.commons.lang3.StringUtils;
  25. import org.springframework.stereotype.Service;
  26. import org.springframework.util.Assert;
  27. import org.springframework.util.ReflectionUtils;

  28. import java.lang.reflect.InvocationTargetException;
  29. import java.lang.reflect.Method;
  30. import java.net.URI;
  31. import java.net.URL;
  32. import java.time.Duration;
  33. import java.util.Map;
  34. import java.util.Set;

  35. import static java.util.Objects.requireNonNull;
  36. import static no.nav.data.common.security.SecurityConstants.SESS_ID_LEN;
  37. import static no.nav.data.common.security.SecurityConstants.TOKEN_TYPE;
  38. import static no.nav.data.common.security.azure.AzureConstants.MICROSOFT_GRAPH_SCOPES;
  39. import static no.nav.data.common.security.azure.AzureConstants.MICROSOFT_GRAPH_SCOPE_APP;

  40. @Slf4j
  41. @Service
  42. public class AzureTokenProvider implements TokenProvider {
  43.     private final Cache<String, IAuthenticationResult> accessTokenCache;

  44.     private final IConfidentialClientApplication msalClient;
  45.     private final AuthService authService;

  46.     private final AADAuthenticationProperties aadAuthProps;
  47.     private final Encryptor encryptor;

  48.     private final Summary tokenMetrics;
  49.     private CustomAccessTokenProvider customAccessTokenProvider;

  50.     public AzureTokenProvider(AADAuthenticationProperties aadAuthProps,
  51.             IConfidentialClientApplication msalClient,
  52.             AuthService authService, Encryptor encryptor) {
  53.         this.aadAuthProps = aadAuthProps;
  54.         this.msalClient = msalClient;
  55.         this.authService = authService;
  56.         this.encryptor = encryptor;
  57.         this.tokenMetrics = MetricUtils.summary()
  58.                 .labels("accessToken").labels("lookupGrantedAuthorities")
  59.                 .labelNames("action")
  60.                 .name(Constants.APP_ID.replace('-', '_') + "_token_summary")
  61.                 .help("Time taken for azure token lookups")
  62.                 .quantile(.5, .01).quantile(.9, .01).quantile(.99, .001)
  63.                 .maxAgeSeconds(Duration.ofHours(24).getSeconds())
  64.                 .ageBuckets(8)
  65.                 .register();

  66.         this.accessTokenCache = Caffeine.newBuilder().recordStats()
  67.                 .expireAfter(new AuthResultExpiry())
  68.                 .maximumSize(1000).build();
  69.         MetricUtils.register("accessTokenCache", accessTokenCache);

  70.         customAccessTokenProvider = new CustomAccessTokenProvider(this);
  71.     }

  72.     GraphServiceClient getGraphClient() {
  73.         return new GraphServiceClient(new BaseBearerTokenAuthenticationProvider(customAccessTokenProvider));
  74.     }

  75.     @Override
  76.     public String getConsumerToken(String resource) {
  77.         return Credential.getCredential()
  78.                 .filter(Credential::hasAuth)
  79.                 .map(cred -> TOKEN_TYPE + getAccessTokenForResource(cred.getAuth().decryptRefreshToken(), resource))
  80.                 .orElseGet(() -> TOKEN_TYPE + getApplicationTokenForResource(resource));
  81.     }

  82.     public Auth getAuth(String session) {
  83.         Assert.isTrue(session.length() > SESS_ID_LEN, "invalid session");
  84.         var sessionId = session.substring(0, SESS_ID_LEN);
  85.         var sessionKey = session.substring(SESS_ID_LEN);
  86.         var auth = authService.getAuth(sessionId, sessionKey);
  87.         try {
  88.             String accessToken = getAccessTokenForResource(auth.decryptRefreshToken(), resourceForAppId());
  89.             auth.addAccessToken(accessToken);
  90.             return auth;
  91.         } catch (RuntimeException e) {
  92.             throw new TechnicalException("Failed to get access token for userId=%s initiated=%s".formatted(auth.getUserId(), auth.getInitiated()));
  93.         }
  94.     }

  95.     @Override
  96.     public void destroySession() {
  97.         Credential.getCredential().map(Credential::getAuth).ifPresent(auth -> authService.endSession(auth.getId()));
  98.     }

  99.     @Override
  100.     public String createAuthRequestRedirectUrl(String postLoginRedirectUri, String postLoginErrorUri, String redirectUri) {
  101.         var auth = authService.createAuth();
  102.         var codeVerifier = auth.getCodeVerifier();
  103.         var s256 = DigestUtils.sha256(codeVerifier);
  104.         var codeChallenge = Base64.encodeBase64URLSafeString(s256);
  105.         URL url = msalClient.getAuthorizationRequestUrl(AuthorizationRequestUrlParameters
  106.                 .builder(redirectUri, MICROSOFT_GRAPH_SCOPES)
  107.                 .state(new OAuthState(auth.getId().toString(), postLoginRedirectUri, postLoginErrorUri).toJson(encryptor))
  108.                 .responseMode(ResponseMode.FORM_POST)
  109.                 .codeChallengeMethod(CodeChallengeMethod.S256.getValue())
  110.                 .codeChallenge(codeChallenge)
  111.                 .build());
  112.         return url.toString();
  113.     }

  114.     @Override
  115.     public String createSession(String sessionId, String code, String redirectUri) {
  116.         try {
  117.             log.debug("Looking up token for auth code");
  118.             var codeVerifier = authService.getCodeVerifier(sessionId);
  119.             var authResult = msalClient.acquireToken(AuthorizationCodeParameters
  120.                     .builder(code, new URI(redirectUri))
  121.                     .scopes(MICROSOFT_GRAPH_SCOPES)
  122.                     .codeVerifier(codeVerifier)
  123.                     .build()).get();
  124.             String userId = StringUtils.substringBefore(authResult.account().homeAccountId(), ".");
  125.             String refreshToken = getRefreshTokenFromAuthResult(authResult);
  126.             return authService.initAuth(userId, refreshToken, sessionId);
  127.         } catch (Exception e) {
  128.             log.error("Failed to get token for auth code", e);
  129.             throw new TechnicalException("Failed to get token for auth code", e);
  130.         }
  131.     }

  132.     private String getRefreshTokenFromAuthResult(IAuthenticationResult authResult) throws ClassNotFoundException, IllegalAccessException, InvocationTargetException {
  133.         // interface is missing refreshToken...
  134.         Method refreshTokenMethod = ReflectionUtils.findMethod(Class.forName("com.microsoft.aad.msal4j.AuthenticationResult"), "refreshToken");
  135.         Assert.notNull(refreshTokenMethod, "couldn't find refreshToken method");
  136.         refreshTokenMethod.setAccessible(true);
  137.         return (String) refreshTokenMethod.invoke(authResult);
  138.     }

  139.     private String resourceForAppId() {
  140.         return aadAuthProps.getClientId() + "/.default";
  141.     }

  142.     String getApplicationTokenForResource(String resource) {
  143.         log.trace("Getting application token for resource {}", resource);
  144.         return requireNonNull(accessTokenCache.get("credential" + resource, cacheKey -> acquireTokenByCredential(resource))).accessToken();
  145.     }

  146.     private String getAccessTokenForResource(String refreshToken, String resource) {
  147.         log.trace("Getting access token for resource {}", resource);
  148.         return requireNonNull(accessTokenCache.get("refresh" + refreshToken + resource, cacheKey -> acquireTokenByRefreshToken(refreshToken, resource))).accessToken();
  149.     }

  150.     private IAuthenticationResult acquireTokenByRefreshToken(String refreshToken, String resource) {
  151.         try (var ignored = tokenMetrics.labels("accessToken").startTimer()) {
  152.             log.debug("Looking up access token for resource {}", resource);
  153.             return msalClient.acquireToken(RefreshTokenParameters.builder(Set.of(resource), refreshToken).build()).get();
  154.         } catch (Exception e) {
  155.             throw new TechnicalException("Failed to get access token for refreshToken", e);
  156.         }
  157.     }

  158.     /**
  159.      * access token for app user
  160.      */
  161.     private IAuthenticationResult acquireTokenByCredential(String resource) {
  162.         try {
  163.             log.debug("Looking up application token for resource {}", resource);
  164.             return msalClient.acquireToken(ClientCredentialParameters.builder(Set.of(resource)).build()).get();
  165.         } catch (Exception e) {
  166.             throw new TechnicalException("Failed to get access token for credential", e);
  167.         }
  168.     }

  169.     class CustomAccessTokenProvider implements AccessTokenProvider {
  170.         AzureTokenProvider tokenProvider;

  171.         public CustomAccessTokenProvider(AzureTokenProvider tokenProvider) {
  172.             this.tokenProvider = tokenProvider;
  173.         }

  174.         @Override
  175.         public String getAuthorizationToken(URI uri, Map<String, Object> additionalAuthenticationContex) {
  176.             return tokenProvider.getApplicationTokenForResource(MICROSOFT_GRAPH_SCOPE_APP);
  177.         }

  178.         // Make sure to have the right set of hosts
  179.         private final AllowedHostsValidator validator = new AllowedHostsValidator("graph.microsoft.com");

  180.         @Override
  181.         public AllowedHostsValidator getAllowedHostsValidator() {
  182.             // Handle allowed hosts validation logic here
  183.             return validator;
  184.         }
  185.     }
  186. }