AADAuthenticationFilter.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.autoconfigure.aad;

import com.microsoft.aad.msal4j.MsalServiceException;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.proc.BadJWTException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.naming.ServiceUnavailableException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
import java.util.Optional;

import static com.azure.spring.autoconfigure.aad.Constants.BEARER_PREFIX;

/**
 * A stateful authentication filter which uses Microsoft Graph groups to authorize. Both ID token and access token are
 * supported. In the case of access token, only access token issued for the exact same application this filter used for
 * could be accepted, e.g. access token issued for Microsoft Graph could not be processed by users' application.
 */
public class AADAuthenticationFilter extends OncePerRequestFilter {
    private static final Logger LOGGER = LoggerFactory.getLogger(AADAuthenticationFilter.class);
    private static final String CURRENT_USER_PRINCIPAL = "CURRENT_USER_PRINCIPAL";

    private final UserPrincipalManager userPrincipalManager;
    private final AzureADGraphClient azureADGraphClient;

    public AADAuthenticationFilter(AADAuthenticationProperties aadAuthenticationProperties,
                                   ServiceEndpointsProperties serviceEndpointsProperties,
                                   ResourceRetriever resourceRetriever) {
        this(
            aadAuthenticationProperties,
            serviceEndpointsProperties,
            new UserPrincipalManager(
                serviceEndpointsProperties,
                aadAuthenticationProperties,
                resourceRetriever,
                false
            )
        );
    }

    public AADAuthenticationFilter(AADAuthenticationProperties aadAuthenticationProperties,
                                   ServiceEndpointsProperties serviceEndpointsProperties,
                                   ResourceRetriever resourceRetriever,
                                   JWKSetCache jwkSetCache) {
        this(
            aadAuthenticationProperties,
            serviceEndpointsProperties,
            new UserPrincipalManager(
                serviceEndpointsProperties,
                aadAuthenticationProperties,
                resourceRetriever,
                false,
                jwkSetCache
            )
        );
    }

    public AADAuthenticationFilter(AADAuthenticationProperties aadAuthenticationProperties,
                                   ServiceEndpointsProperties serviceEndpointsProperties,
                                   UserPrincipalManager userPrincipalManager) {
        this.userPrincipalManager = userPrincipalManager;
        this.azureADGraphClient = new AzureADGraphClient(
            aadAuthenticationProperties.getClientId(),
            aadAuthenticationProperties.getClientSecret(),
            aadAuthenticationProperties,
            serviceEndpointsProperties
        );
    }

    @Override
    protected void doFilterInternal(HttpServletRequest httpServletRequest,
                                    HttpServletResponse httpServletResponse,
                                    FilterChain filterChain) throws ServletException, IOException {
        String aadIssuedBearerToken = Optional.of(httpServletRequest)
                                              .map(r -> r.getHeader(HttpHeaders.AUTHORIZATION))
                                              .map(String::trim)
                                              .filter(s -> s.startsWith(BEARER_PREFIX))
                                              .map(s -> s.replace(BEARER_PREFIX, ""))
                                              .filter(userPrincipalManager::isTokenIssuedByAAD)
                                              .orElse(null);
        if (aadIssuedBearerToken == null || alreadyAuthenticated()) {
            filterChain.doFilter(httpServletRequest, httpServletResponse);
            return;
        }
        try {
            HttpSession httpSession = httpServletRequest.getSession();
            UserPrincipal userPrincipal = (UserPrincipal) httpSession.getAttribute(CURRENT_USER_PRINCIPAL);
            if (userPrincipal == null
                || !userPrincipal.getAadIssuedBearerToken().equals(aadIssuedBearerToken)
                || userPrincipal.getAccessTokenForGraphApi() == null
            ) {
                userPrincipal = userPrincipalManager.buildUserPrincipal(aadIssuedBearerToken);
                String tenantId = userPrincipal.getClaim(AADTokenClaim.TID).toString();
                String accessTokenForGraphApi = azureADGraphClient
                    .acquireTokenForGraphApi(aadIssuedBearerToken, tenantId)
                    .accessToken();
                userPrincipal.setAccessTokenForGraphApi(accessTokenForGraphApi);
                userPrincipal.setGroups(azureADGraphClient.getGroups(accessTokenForGraphApi));
                httpSession.setAttribute(CURRENT_USER_PRINCIPAL, userPrincipal);
            }
            final Authentication authentication = new PreAuthenticatedAuthenticationToken(
                userPrincipal,
                null,
                azureADGraphClient.toGrantedAuthoritySet(userPrincipal.getGroups())
            );
            LOGGER.info("Request token verification success. {}", authentication);
            SecurityContextHolder.getContext().setAuthentication(authentication);
        } catch (BadJWTException ex) {
            // Invalid JWT. Either expired or not yet valid.
            httpServletResponse.sendError(HttpStatus.UNAUTHORIZED.value());
            return;
        } catch (MalformedURLException | ParseException | JOSEException | BadJOSEException ex) {
            LOGGER.error("Failed to initialize UserPrincipal.", ex);
            throw new ServletException(ex);
        } catch (ServiceUnavailableException ex) {
            LOGGER.error("Failed to acquire graph api token.", ex);
            throw new ServletException(ex);
        } catch (MsalServiceException ex) {
            // Handle conditional access policy, step 2.
            // No step 3 any more, because ServletException will not be caught.
            // TODO: Do we need to return 401 instead of 500?
            if (ex.claims() != null && !ex.claims().isEmpty()) {
                throw new ServletException("Handle conditional access policy", ex);
            } else {
                throw ex;
            }
        }
        filterChain.doFilter(httpServletRequest, httpServletResponse);
    }

    private boolean alreadyAuthenticated() {
        return Optional.of(SecurityContextHolder.getContext())
                       .map(SecurityContext::getAuthentication)
                       .map(Authentication::isAuthenticated)
                       .orElse(false);
    }
}