// AADStatelessAuthenticationFilter.java
package no.nav.data.common.security.azure;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.*;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import io.prometheus.client.Counter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import no.nav.data.common.security.AppIdMapping;
import no.nav.data.common.security.AuthController;
import no.nav.data.common.security.RoleSupport;
import no.nav.data.common.security.domain.Auth;
import no.nav.data.common.security.dto.Credential;
import no.nav.data.common.utils.MetricUtils;
import org.apache.commons.lang3.Strings;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static no.nav.data.common.security.SecurityConstants.COOKIE_NAME;
import static no.nav.data.common.security.SecurityConstants.TOKEN_TYPE;
import static org.springframework.util.StringUtils.hasText;
/**
 * Stateless servlet filter that authenticates a request from an Azure Active Directory token.
 *
 * <p>For every request whose servlet path does not start with {@code /login}, the filter looks for a
 * credential (first the session cookie named {@code COOKIE_NAME}, then an {@code Authorization} header
 * starting with {@code TOKEN_TYPE}), validates the access token as a signed JWT against the key set
 * published at the provider's JWK set URI, checks issuer, audience and app id, and stores a
 * {@link PreAuthenticatedAuthenticationToken} in the {@link SecurityContextHolder}. The security context
 * is cleared again once the rest of the filter chain has run.
 */
@Slf4j
public class AADStatelessAuthenticationFilter extends OncePerRequestFilter {

    /** Counts authentication events, labelled by action: no_auth, cookie, direct_token or login. */
    private static final Counter counter = initCounter();

    private final AzureTokenProvider azureTokenProvider;
    private final RoleSupport roleSupport;
    /** App ids a token's app id claim must be one of. */
    private final List<String> allowedAppIds;
    private final OIDCProviderMetadata oidcProviderMetadata;
    /** Source of the signature verification keys, loaded from the provider's JWK set URI. */
    private final JWKSource<SecurityContext> keySource;
    /** Audiences accepted in the {@code aud} claim; populated once in the constructor. */
    private final Set<String> validAudiences = new HashSet<>();

    /**
     * Creates the filter and prepares the JWK source used for signature verification.
     *
     * @param azureTokenProvider   resolves a session cookie value to an {@link Auth}
     * @param roleSupport          maps the token's {@code groups} claim to granted authorities
     * @param appIdMapping         supplies the app ids that are allowed to call this application
     * @param aadAuthProps         supplies the client id, which is registered as a valid audience
     * @param resourceRetriever    retriever handed to the JWK source builder
     * @param oidcProviderMetadata provider metadata supplying the JWK set URI and the expected issuer
     * @throws IllegalStateException if the JWK set URI cannot be converted to a URL
     */
    public AADStatelessAuthenticationFilter(AzureTokenProvider azureTokenProvider, RoleSupport roleSupport, AppIdMapping appIdMapping,
            AADAuthenticationProperties aadAuthProps, ResourceRetriever resourceRetriever, OIDCProviderMetadata oidcProviderMetadata) {
        this.azureTokenProvider = azureTokenProvider;
        this.roleSupport = roleSupport;
        this.allowedAppIds = List.copyOf(appIdMapping.getIds());
        this.oidcProviderMetadata = oidcProviderMetadata;
        // azure spring
        this.validAudiences.add(aadAuthProps.getClientId());
        try {
            keySource = JWKSourceBuilder.create(oidcProviderMetadata.getJWKSetURI().toURL(), resourceRetriever).build();
        } catch (MalformedURLException e) {
            log.error("Failed to parse active directory key discovery uri.", e);
            throw new IllegalStateException("Failed to parse active directory key discovery uri.", e);
        }
    }

    /**
     * Authenticates the request (unless it targets {@code /login}), continues the filter chain and
     * clears the security context afterwards if this filter populated it.
     *
     * @throws ServletException if a credential was supplied but its token could not be validated
     * @throws IOException      propagated from the rest of the filter chain
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        boolean cleanupRequired = false;
        if (Strings.CI.startsWith(request.getServletPath(), "/login")) {
            // Login requests are only counted, never authenticated here.
            counter.labels("login").inc();
        } else {
            cleanupRequired = authenticate(request, response);
        }
        try {
            filterChain.doFilter(request, response);
        } finally {
            // Only clear what this filter set, so the authentication does not outlive the request.
            if (cleanupRequired) {
                SecurityContextHolder.clearContext();
            }
        }
    }

    /**
     * Validates the request's credential, if any, and stores the resulting authentication in the
     * security context.
     *
     * @return {@code true} if an authentication was stored in the security context (and must be
     *         cleared by the caller), {@code false} if the request carried no usable credential
     * @throws ServletException if the token is present but invalid or cannot be processed
     */
    private boolean authenticate(HttpServletRequest request, HttpServletResponse response) throws ServletException {
        Credential credential = getCredential(request, response);
        if (credential == null) {
            // Unauthenticated calls to /internal paths are not counted as "no_auth".
            if (!Strings.CI.startsWith(request.getServletPath(), "/internal")) {
                counter.labels("no_auth").inc();
            }
            return false;
        }
        try {
            var principal = buildUserPrincipal(credential.getAccessToken());
            var grantedAuthorities = roleSupport.lookupGrantedAuthorities(principal.getStringListClaim("groups"));
            var authentication = new PreAuthenticatedAuthenticationToken(principal, credential, grantedAuthorities);
            authentication.setDetails(new AzureUserInfo(principal, grantedAuthorities));
            authentication.setAuthenticated(true);
            log.trace("Request token verification success for subject {} with roles {}.", AzureUserInfo.getUserId(principal), grantedAuthorities);
            SecurityContextHolder.getContext().setAuthentication(authentication);
            return true;
        } catch (BadJWTException ex) {
            // Must be caught before BadJOSEException, which is handled by the broader catch below.
            String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
            log.warn(errorMessage);
            throw new ServletException(errorMessage, ex);
        } catch (ParseException | BadJOSEException | JOSEException ex) {
            log.error("Failed to initialize UserPrincipal.", ex);
            throw new ServletException(ex);
        }
    }

    /**
     * Extracts the credential from the request.
     *
     * <p>The session cookie takes precedence. If the cookie is present but cannot be resolved, an
     * expiring replacement cookie is added to the response and {@code null} is returned without
     * consulting the {@code Authorization} header. Without a session cookie, an {@code Authorization}
     * header starting with {@code TOKEN_TYPE} is used.
     *
     * @return the credential, or {@code null} if the request carries none that can be used
     */
    private Credential getCredential(HttpServletRequest request, HttpServletResponse response) {
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            Optional<Cookie> cookie = Stream.of(cookies)
                    .filter(c -> c.getName().equals(COOKIE_NAME))
                    .findFirst();
            if (cookie.isPresent()) {
                try {
                    String session = cookie.get().getValue();
                    Auth auth = azureTokenProvider.getAuth(session);
                    counter.labels("cookie").inc();
                    return new Credential(auth);
                } catch (Exception e) {
                    // Best effort: any failure resolving the session is treated as "no credential".
                    log.warn("Invalid auth cookie", e);
                    response.addCookie(AuthController.createCookie(null, 0, request));
                    return null;
                }
            }
        }
        String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
        if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
            // Strip only the leading prefix that was just matched; String.replace would also remove
            // any later occurrence of the prefix text inside the header value.
            String token = authHeader.substring(TOKEN_TYPE.length());
            counter.labels("direct_token").inc();
            return new Credential(token);
        }
        return null;
    }

    /**
     * Validates the token and additionally requires its app id claim to be one of the allowed app ids.
     *
     * @param token the serialized JWT
     * @return the validated claims
     * @throws BadJWTException if claim validation fails or the app id is missing or not allowed
     */
    private JWTClaimsSet buildUserPrincipal(String token) throws ParseException, JOSEException, BadJOSEException {
        var principal = buildAndValidateClaims(token);
        String appIdClaim = AzureUserInfo.getAppId(principal);
        if (appIdClaim == null || !allowedAppIds.contains(appIdClaim)) {
            throw new BadJWTException("Invalid token appId. Provided value " + appIdClaim + " does not match allowed appId");
        }
        return principal;
    }

    /** Registers the authentication event counter with its {@code action} label. */
    private static Counter initCounter() {
        return MetricUtils.counter()
                .labels("no_auth").labels("cookie").labels("direct_token").labels("login")
                .name("team_adal_auth_counter")
                .help("Counter for authentication events")
                .labelNames("action")
                .register();
    }

    // From spring azure start

    /**
     * Parses the token, verifies its signature against the configured key source and validates its
     * claims (required {@code iss} and {@code aud}, expected issuer, accepted audience).
     *
     * @param idToken the serialized, signed JWT
     * @return the claims of the validated token
     * @throws ParseException   if the token cannot be parsed
     * @throws BadJOSEException if the token is rejected, including {@link BadJWTException} for claim failures
     * @throws JOSEException    if processing fails for an internal reason
     */
    public JWTClaimsSet buildAndValidateClaims(String idToken) throws ParseException, BadJOSEException, JOSEException {
        final JWSObject jwsObject = JWSObject.parse(idToken);
        // NOTE(review): the accepted JWS algorithm is taken from the token's own header; consider
        // pinning the algorithm(s) the tenant is expected to sign with - confirm before changing.
        final ConfigurableJWTProcessor<SecurityContext> validator =
                getAadJwtTokenValidator(jwsObject.getHeader().getAlgorithm());
        final JWTClaimsSet jwtClaimsSet = validator.process(idToken, null);
        // NOTE(review): process(...) appears to apply the configured claims verifier already; the
        // explicit call is kept so behaviour is unchanged - confirm before removing.
        final JWTClaimsSetVerifier<SecurityContext> verifier = validator.getJWTClaimsSetVerifier();
        verifier.verify(jwtClaimsSet, null);
        return jwtClaimsSet;
    }

    /**
     * Builds a JWT processor that selects verification keys for the given algorithm from the key
     * source and verifies issuer and audience on top of the default claim checks.
     *
     * @param jwsAlgorithm the JWS algorithm the key selector should accept
     * @return a newly configured processor
     */
    private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
        final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
        final JWSKeySelector<SecurityContext> keySelector =
                new JWSVerificationKeySelector<>(jwsAlgorithm, keySource);
        jwtProcessor.setJWSKeySelector(keySelector);
        jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>(null, Set.of("iss", "aud")) {
            @Override
            public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
                super.verify(claimsSet, ctx);
                final String issuer = claimsSet.getIssuer();
                if (issuer == null || !issuer.equals(oidcProviderMetadata.getIssuer().getValue())) {
                    throw new BadJWTException("Invalid token issuer " + issuer);
                }
                // At least one audience of the token must be an accepted audience.
                final Optional<String> matchedAudience = claimsSet.getAudience().stream().filter(validAudiences::contains).findFirst();
                if (matchedAudience.isPresent()) {
                    log.trace("Matched audience [{}]", matchedAudience.get());
                } else {
                    throw new BadJWTException("Invalid token audience. Provided value " + claimsSet.getAudience() + " does not match neither client-id nor AppIdUri.");
                }
            }
        });
        return jwtProcessor;
    }
}