AADStatelessAuthenticationFilter.java

  1. package no.nav.data.common.security.azure;

  2. import com.nimbusds.jose.JOSEException;
  3. import com.nimbusds.jose.JWSAlgorithm;
  4. import com.nimbusds.jose.JWSObject;
  5. import com.nimbusds.jose.jwk.source.JWKSource;
  6. import com.nimbusds.jose.jwk.source.RemoteJWKSet;
  7. import com.nimbusds.jose.proc.BadJOSEException;
  8. import com.nimbusds.jose.proc.JWSKeySelector;
  9. import com.nimbusds.jose.proc.JWSVerificationKeySelector;
  10. import com.nimbusds.jose.proc.SecurityContext;
  11. import com.nimbusds.jose.util.ResourceRetriever;
  12. import com.nimbusds.jwt.JWTClaimsSet;
  13. import com.nimbusds.jwt.proc.BadJWTException;
  14. import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
  15. import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
  16. import com.nimbusds.jwt.proc.DefaultJWTProcessor;
  17. import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
  18. import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
  19. import io.prometheus.client.Counter;
  20. import jakarta.servlet.FilterChain;
  21. import jakarta.servlet.ServletException;
  22. import jakarta.servlet.http.Cookie;
  23. import jakarta.servlet.http.HttpServletRequest;
  24. import jakarta.servlet.http.HttpServletResponse;
  25. import lombok.extern.slf4j.Slf4j;
  26. import no.nav.data.common.security.AppIdMapping;
  27. import no.nav.data.common.security.AuthController;
  28. import no.nav.data.common.security.RoleSupport;
  29. import no.nav.data.common.security.domain.Auth;
  30. import no.nav.data.common.security.dto.Credential;
  31. import no.nav.data.common.utils.MetricUtils;
  32. import org.apache.commons.lang3.StringUtils;
  33. import org.springframework.http.HttpHeaders;
  34. import org.springframework.security.core.context.SecurityContextHolder;
  35. import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
  36. import org.springframework.web.filter.OncePerRequestFilter;

  37. import java.io.IOException;
  38. import java.net.MalformedURLException;
  39. import java.text.ParseException;
  40. import java.util.HashSet;
  41. import java.util.List;
  42. import java.util.Optional;
  43. import java.util.Set;
  44. import java.util.stream.Stream;

  45. import static no.nav.data.common.security.SecurityConstants.COOKIE_NAME;
  46. import static no.nav.data.common.security.SecurityConstants.TOKEN_TYPE;
  47. import static org.springframework.util.StringUtils.hasText;

  48. @Slf4j
  49. public class AADStatelessAuthenticationFilter extends OncePerRequestFilter {

  50.     private static final Counter counter = initCounter();

  51.     private final AzureTokenProvider azureTokenProvider;
  52.     private final RoleSupport roleSupport;
  53.     private final List<String> allowedAppIds;
  54.     private final OIDCProviderMetadata oidcProviderMetadata;
  55.     private final JWKSource<SecurityContext> keySource;

  56.     public AADStatelessAuthenticationFilter(AzureTokenProvider azureTokenProvider, RoleSupport roleSupport, AppIdMapping appIdMapping,
  57.             AADAuthenticationProperties aadAuthProps, ResourceRetriever resourceRetriever, OIDCProviderMetadata oidcProviderMetadata) {
  58.         this.azureTokenProvider = azureTokenProvider;
  59.         this.roleSupport = roleSupport;
  60.         this.allowedAppIds = List.copyOf(appIdMapping.getIds());
  61.         this.oidcProviderMetadata = oidcProviderMetadata;

  62.         // azure spring
  63.         this.validAudiences.add(aadAuthProps.getClientId());
  64.         try {
  65.             keySource = new RemoteJWKSet<>(oidcProviderMetadata.getJWKSetURI().toURL(), resourceRetriever);
  66.         } catch (MalformedURLException e) {
  67.             log.error("Failed to parse active directory key discovery uri.", e);
  68.             throw new IllegalStateException("Failed to parse active directory key discovery uri.", e);
  69.         }
  70.     }

  71.     @Override
  72.     protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
  73.         boolean cleanupRequired = false;

  74.         if (StringUtils.startsWith(request.getServletPath(), "/login")) {
  75.             counter.labels("login").inc();
  76.         } else {
  77.             cleanupRequired = authenticate(request, response);
  78.         }

  79.         try {
  80.             filterChain.doFilter(request, response);
  81.         } finally {
  82.             if (cleanupRequired) {
  83.                 SecurityContextHolder.clearContext();
  84.             }
  85.         }
  86.     }

  87.     private boolean authenticate(HttpServletRequest request, HttpServletResponse response) throws ServletException {
  88.         Credential credential = getCredential(request, response);
  89.         if (credential != null) {
  90.             try {
  91.                 var principal = buildUserPrincipal(credential.getAccessToken());
  92.                 var grantedAuthorities = roleSupport.lookupGrantedAuthorities(principal.getStringListClaim("groups"));
  93.                 var authentication = new PreAuthenticatedAuthenticationToken(principal, credential, grantedAuthorities);
  94.                 authentication.setDetails(new AzureUserInfo(principal, grantedAuthorities));
  95.                 authentication.setAuthenticated(true);
  96.                 log.trace("Request token verification success for subject {} with roles {}.", AzureUserInfo.getUserId(principal), grantedAuthorities);
  97.                 SecurityContextHolder.getContext().setAuthentication(authentication);
  98.                 return true;
  99.             } catch (BadJWTException ex) {
  100.                 String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
  101.                 log.warn(errorMessage);
  102.                 throw new ServletException(errorMessage, ex);
  103.             } catch (ParseException | BadJOSEException | JOSEException ex) {
  104.                 log.error("Failed to initialize UserPrincipal.", ex);
  105.                 throw new ServletException(ex);
  106.             }
  107.         } else {
  108.             if (!StringUtils.startsWith(request.getServletPath(), "/internal")) {
  109.                 counter.labels("no_auth").inc();
  110.             }
  111.         }
  112.         return false;
  113.     }

  114.     private Credential getCredential(HttpServletRequest request, HttpServletResponse response) {
  115.         if (request.getCookies() != null) {
  116.             Optional<Cookie> cookie = Stream.of(request.getCookies())
  117.                     .filter(c -> c.getName().equals(COOKIE_NAME))
  118.                     .findFirst();
  119.             if (cookie.isPresent()) {
  120.                 try {
  121.                     String session = cookie.get().getValue();
  122.                     Auth auth = azureTokenProvider.getAuth(session);
  123.                     counter.labels("cookie").inc();
  124.                     return new Credential(auth);
  125.                 } catch (Exception e) {
  126.                     log.warn("Invalid auth cookie", e);
  127.                     response.addCookie(AuthController.createCookie(null, 0, request));
  128.                     return null;
  129.                 }
  130.             }
  131.         }
  132.         String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
  133.         if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
  134.             String authHeader1 = request.getHeader(HttpHeaders.AUTHORIZATION);
  135.             String token = authHeader1.replace(TOKEN_TYPE, "");
  136.             counter.labels("direct_token").inc();
  137.             return new Credential(token);
  138.         }
  139.         return null;
  140.     }

  141.     private JWTClaimsSet buildUserPrincipal(String token) throws ParseException, JOSEException, BadJOSEException {
  142.         var principal = buildAndValidateClaims(token);
  143.         String appIdClaim = AzureUserInfo.getAppId(principal);
  144.         if (appIdClaim == null || !allowedAppIds.contains(appIdClaim)) {
  145.             throw new BadJWTException("Invalid token appId. Provided value " + appIdClaim + " does not match allowed appId");
  146.         }
  147.         return principal;
  148.     }

  149.     private static Counter initCounter() {
  150.         return MetricUtils.counter()
  151.                 .labels("no_auth").labels("cookie").labels("direct_token").labels("login")
  152.                 .name("team_adal_auth_counter")
  153.                 .help("Counter for authentication events")
  154.                 .labelNames("action")
  155.                 .register();
  156.     }

  157.     // From spring azure start
  158.     private final Set<String> validAudiences = new HashSet<>();

  159.     public JWTClaimsSet buildAndValidateClaims(String idToken) throws ParseException, BadJOSEException, JOSEException {
  160.         final JWSObject jwsObject = JWSObject.parse(idToken);
  161.         final ConfigurableJWTProcessor<SecurityContext> validator =
  162.                 getAadJwtTokenValidator(jwsObject.getHeader().getAlgorithm());
  163.         final JWTClaimsSet jwtClaimsSet = validator.process(idToken, null);
  164.         final JWTClaimsSetVerifier<SecurityContext> verifier = validator.getJWTClaimsSetVerifier();
  165.         verifier.verify(jwtClaimsSet, null);

  166.         return jwtClaimsSet;
  167.     }

  168.     private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
  169.         final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();

  170.         final JWSKeySelector<SecurityContext> keySelector =
  171.                 new JWSVerificationKeySelector<>(jwsAlgorithm, keySource);
  172.         jwtProcessor.setJWSKeySelector(keySelector);

  173.         jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>() {
  174.             @Override
  175.             public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
  176.                 super.verify(claimsSet, ctx);
  177.                 final String issuer = claimsSet.getIssuer();
  178.                 if (issuer == null || !issuer.equals(oidcProviderMetadata.getIssuer().getValue())) {
  179.                     throw new BadJWTException("Invalid token issuer " + issuer);
  180.                 }
  181.                 final Optional<String> matchedAudience = claimsSet.getAudience().stream().filter(validAudiences::contains).findFirst();
  182.                 if (matchedAudience.isPresent()) {
  183.                     log.trace("Matched audience [{}]", matchedAudience.get());
  184.                 } else {
  185.                     throw new BadJWTException("Invalid token audience. Provided value " + claimsSet.getAudience() + "does not match neither client-id nor AppIdUri.");
  186.                 }
  187.             }
  188.         });
  189.         return jwtProcessor;
  190.     }

  191. }