AADStatelessAuthenticationFilter.java

package no.nav.data.common.security.azure;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.*;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import io.prometheus.client.Counter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import no.nav.data.common.security.AppIdMapping;
import no.nav.data.common.security.AuthController;
import no.nav.data.common.security.RoleSupport;
import no.nav.data.common.security.domain.Auth;
import no.nav.data.common.security.dto.Credential;
import no.nav.data.common.utils.MetricUtils;
import org.apache.commons.lang3.Strings;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static no.nav.data.common.security.SecurityConstants.COOKIE_NAME;
import static no.nav.data.common.security.SecurityConstants.TOKEN_TYPE;
import static org.springframework.util.StringUtils.hasText;

/**
 * Stateless servlet filter that authenticates requests with Azure AD access tokens.
 * <p>
 * A credential is taken either from the session cookie named {@code COOKIE_NAME} (resolved through
 * {@link AzureTokenProvider#getAuth(String)}) or from an {@code Authorization} header starting with {@code TOKEN_TYPE}.
 * The access token is validated against:
 * <ul>
 *   <li>the provider's JWKS signature;</li>
 *   <li>the issuer from the OIDC metadata;</li>
 *   <li>the configured client id as audience;</li>
 *   <li>{@code exp}/{@code nbf};</li>
 *   <li>the list of allowed calling app ids.</li>
 * </ul>
 * On success a {@link PreAuthenticatedAuthenticationToken} is stored in the {@link SecurityContextHolder}, and it is
 * cleared again once the rest of the filter chain has run. Requests under {@code /login} bypass authentication.
 */
@Slf4j
public class AADStatelessAuthenticationFilter extends OncePerRequestFilter {

    private static final Counter counter = initCounter();

    private final AzureTokenProvider azureTokenProvider;
    private final RoleSupport roleSupport;
    private final List<String> allowedAppIds;
    private final OIDCProviderMetadata oidcProviderMetadata;
    private final JWKSource<SecurityContext> keySource;
    // Audiences ("aud") accepted on incoming tokens; filled with the configured client id in the constructor.
    private final Set<String> validAudiences = new HashSet<>();

    /**
     * Creates the filter and a JWK source for the provider's signing keys.
     *
     * @param azureTokenProvider   resolves the session cookie value into an {@link Auth}
     * @param roleSupport          maps the token's {@code groups} claim to granted authorities
     * @param appIdMapping         source of the app ids allowed to call this service
     * @param aadAuthProps         supplies the client id used as the only accepted audience
     * @param resourceRetriever    used by the JWK source to fetch the JWKS document
     * @param oidcProviderMetadata supplies the JWKS URI and the expected issuer
     * @throws IllegalStateException if the JWKS URI cannot be converted to a URL
     */
    public AADStatelessAuthenticationFilter(AzureTokenProvider azureTokenProvider, RoleSupport roleSupport, AppIdMapping appIdMapping,
            AADAuthenticationProperties aadAuthProps, ResourceRetriever resourceRetriever, OIDCProviderMetadata oidcProviderMetadata) {
        this.azureTokenProvider = azureTokenProvider;
        this.roleSupport = roleSupport;
        this.allowedAppIds = List.copyOf(appIdMapping.getIds());
        this.oidcProviderMetadata = oidcProviderMetadata;

        // azure spring
        this.validAudiences.add(aadAuthProps.getClientId());
        try {
            keySource = JWKSourceBuilder.create(oidcProviderMetadata.getJWKSetURI().toURL(), resourceRetriever).build();
        } catch (MalformedURLException e) {
            log.error("Failed to parse active directory key discovery uri.", e);
            throw new IllegalStateException("Failed to parse active directory key discovery uri.", e);
        }
    }

    /**
     * Counts {@code /login} requests and lets them through. Every other request is authenticated first.
     * If this filter put an authentication into the security context, that context is cleared after the chain completes.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        boolean cleanupRequired = false;

        if (Strings.CI.startsWith(request.getServletPath(), "/login")) {
            counter.labels("login").inc();
        } else {
            cleanupRequired = authenticate(request, response);
        }

        try {
            filterChain.doFilter(request, response);
        } finally {
            if (cleanupRequired) {
                SecurityContextHolder.clearContext();
            }
        }
    }

    /**
     * Validates the request's credential, if any, and installs the resulting authentication.
     *
     * @return {@code true} if an authentication was placed in the security context, {@code false} if the request had no usable credential
     * @throws ServletException if a credential was present but the token failed validation
     */
    private boolean authenticate(HttpServletRequest request, HttpServletResponse response) throws ServletException {
        Credential credential = getCredential(request, response);
        if (credential != null) {
            try {
                var principal = buildUserPrincipal(credential.getAccessToken());
                var grantedAuthorities = roleSupport.lookupGrantedAuthorities(principal.getStringListClaim("groups"));
                var authentication = new PreAuthenticatedAuthenticationToken(principal, credential, grantedAuthorities);
                authentication.setDetails(new AzureUserInfo(principal, grantedAuthorities));
                authentication.setAuthenticated(true);
                log.trace("Request token verification success for subject {} with roles {}.", AzureUserInfo.getUserId(principal), grantedAuthorities);
                SecurityContextHolder.getContext().setAuthentication(authentication);
                return true;
            } catch (BadJWTException ex) {
                // Raised for time-claim failures as well as issuer, audience and appId mismatches.
                String errorMessage = "Invalid JWT (expired, not yet valid, or untrusted issuer/audience/appId). " + ex.getMessage();
                log.warn(errorMessage);
                throw new ServletException(errorMessage, ex);
            } catch (ParseException | BadJOSEException | JOSEException ex) {
                log.error("Failed to initialize UserPrincipal.", ex);
                throw new ServletException(ex);
            }
        } else {
            if (!Strings.CI.startsWith(request.getServletPath(), "/internal")) {
                counter.labels("no_auth").inc();
            }
        }
        return false;
    }

    /**
     * Extracts the caller's credential. The session cookie takes precedence over the {@code Authorization} header.
     * <p>
     * If the cookie is present but cannot be resolved, an expiring replacement cookie is added to the response and
     * {@code null} is returned. The header is not consulted in that case.
     *
     * @return the credential, or {@code null} if none was found
     */
    private Credential getCredential(HttpServletRequest request, HttpServletResponse response) {
        if (request.getCookies() != null) {
            Optional<Cookie> cookie = Stream.of(request.getCookies())
                    .filter(c -> c.getName().equals(COOKIE_NAME))
                    .findFirst();
            if (cookie.isPresent()) {
                try {
                    String session = cookie.get().getValue();
                    Auth auth = azureTokenProvider.getAuth(session);
                    counter.labels("cookie").inc();
                    return new Credential(auth);
                } catch (Exception e) {
                    log.warn("Invalid auth cookie", e);
                    response.addCookie(AuthController.createCookie(null, 0, request));
                    return null;
                }
            }
        }
        String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
        if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
            // Strip only the leading token type; String.replace would remove every occurrence in the header.
            String token = authHeader.substring(TOKEN_TYPE.length());
            counter.labels("direct_token").inc();
            return new Credential(token);
        }
        return null;
    }

    /**
     * Validates the token and additionally requires its app id to be one of the allowed app ids.
     *
     * @throws BadJWTException if the app id is missing or not allowed
     */
    private JWTClaimsSet buildUserPrincipal(String token) throws ParseException, JOSEException, BadJOSEException {
        var principal = buildAndValidateClaims(token);
        String appIdClaim = AzureUserInfo.getAppId(principal);
        if (appIdClaim == null || !allowedAppIds.contains(appIdClaim)) {
            throw new BadJWTException("Invalid token appId. Provided value " + appIdClaim + " does not match allowed appId");
        }
        return principal;
    }

    // NOTE(review): the chained labels(...) calls presumably pre-register the known label values; confirm in MetricUtils.
    private static Counter initCounter() {
        return MetricUtils.counter()
                .labels("no_auth").labels("cookie").labels("direct_token").labels("login")
                .name("team_adal_auth_counter")
                .help("Counter for authentication events")
                .labelNames("action")
                .register();
    }

    // From spring azure start

    /**
     * Parses the token, verifies its signature with the provider's JWKS, and checks its claims.
     * The claim checks cover issuer, audience, {@code exp} and {@code nbf}.
     *
     * @param idToken the serialized JWS token
     * @return the verified claims
     * @throws ParseException    if the token is not a parseable JWS
     * @throws BadJOSEException  if the algorithm is not accepted, the signature is invalid, or a claim check fails
     * @throws JOSEException     on internal signature verification errors
     */
    public JWTClaimsSet buildAndValidateClaims(String idToken) throws ParseException, BadJOSEException, JOSEException {
        final JWSAlgorithm algorithm = JWSObject.parse(idToken).getHeader().getAlgorithm();
        // The alg header is attacker-controlled. Do not let it pick the verification algorithm freely:
        // only RSA signatures are accepted (Azure AD signs its tokens with RS256).
        if (!JWSAlgorithm.Family.RSA.contains(algorithm)) {
            throw new BadJOSEException("Unsupported token signing algorithm " + algorithm);
        }
        // process() verifies the signature and then runs the configured JWTClaimsSetVerifier,
        // so a second verify() call is not needed.
        return getAadJwtTokenValidator(algorithm).process(idToken, null);
    }

    /**
     * Builds a processor for the given algorithm. It resolves keys from the shared JWK source and verifies the claims.
     * The claim checks are:
     * <ul>
     *   <li>{@code iss}, {@code aud} and {@code exp} are present;</li>
     *   <li>the time claims are valid;</li>
     *   <li>the issuer equals the OIDC metadata issuer;</li>
     *   <li>at least one audience is in {@code validAudiences}.</li>
     * </ul>
     */
    private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
        final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();

        final JWSKeySelector<SecurityContext> keySelector =
                new JWSVerificationKeySelector<>(jwsAlgorithm, keySource);
        jwtProcessor.setJWSKeySelector(keySelector);

        // exp is required so that a signed token without an expiry is not accepted indefinitely.
        jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>(null, Set.of("iss", "aud", "exp")) {
            @Override
            public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
                super.verify(claimsSet, ctx);
                final String issuer = claimsSet.getIssuer();
                if (issuer == null || !issuer.equals(oidcProviderMetadata.getIssuer().getValue())) {
                    throw new BadJWTException("Invalid token issuer " + issuer);
                }
                final Optional<String> matchedAudience = claimsSet.getAudience().stream().filter(validAudiences::contains).findFirst();
                if (matchedAudience.isPresent()) {
                    log.trace("Matched audience [{}]", matchedAudience.get());
                } else {
                    throw new BadJWTException("Invalid token audience. Provided value " + claimsSet.getAudience() + " does not match the configured client-id.");
                }
            }
        });
        return jwtProcessor;
    }

}