UserPrincipalManager.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.autoconfigure.aad;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import net.minidev.json.JSONArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * A user principal manager to load user info from JWT.
 */
public class UserPrincipalManager {
    private static final Logger LOGGER = LoggerFactory.getLogger(UserPrincipalManager.class);
    private static final String LOGIN_MICROSOFT_ONLINE_ISSUER = "https://login.microsoftonline.com/";
    private static final String STS_WINDOWS_ISSUER = "https://sts.windows.net/";
    private static final String STS_CHINA_CLOUD_API_ISSUER = "https://sts.chinacloudapi.cn/";

    private final JWKSource<SecurityContext> keySource;
    private final AADAuthenticationProperties aadAuthenticationProperties;
    private final boolean explicitAudienceCheck;
    /** Audiences accepted when {@link #explicitAudienceCheck} is enabled (client-id and App ID URI). */
    private final Set<String> validAudiences = new HashSet<>();

    /**
     * Creates a new {@link UserPrincipalManager} with a predefined {@link JWKSource}.
     * <p>
     * This is helpful in cases the JWK is not a remote JWKSet or for unit testing. Instances created this way do
     * not perform an explicit audience check.
     *
     * @param keySource - {@link JWKSource} containing at least one key
     */
    public UserPrincipalManager(JWKSource<SecurityContext> keySource) {
        this(null, false, keySource);
    }

    /**
     * Create a new {@link UserPrincipalManager} based on the {@link ServiceEndpoints#getAadKeyDiscoveryUri()} and
     * {@link AADAuthenticationProperties#getEnvironment()}.
     *
     * @param serviceEndpointsProps - used to retrieve the JWKS URL
     * @param aadAuthenticationProperties - used to retrieve the environment.
     * @param resourceRetriever - configures the {@link RemoteJWKSet} call.
     * @param explicitAudienceCheck Whether explicitly check the audience.
     * @throws IllegalArgumentException If AAD key discovery URI is malformed.
     */
    public UserPrincipalManager(ServiceEndpointsProperties serviceEndpointsProps,
                                AADAuthenticationProperties aadAuthenticationProperties,
                                ResourceRetriever resourceRetriever,
                                boolean explicitAudienceCheck) {
        this(aadAuthenticationProperties,
            explicitAudienceCheck,
            new RemoteJWKSet<SecurityContext>(
                toAadKeyDiscoveryUrl(serviceEndpointsProps, aadAuthenticationProperties),
                resourceRetriever));
    }

    /**
     * Create a new {@link UserPrincipalManager} based on the {@link ServiceEndpoints#getAadKeyDiscoveryUri()} and
     * {@link AADAuthenticationProperties#getEnvironment()}.
     *
     * @param serviceEndpointsProps - used to retrieve the JWKS URL
     * @param aadAuthenticationProperties - used to retrieve the environment.
     * @param resourceRetriever - configures the {@link RemoteJWKSet} call.
     * @param explicitAudienceCheck Whether explicitly check the audience.
     * @param jwkSetCache - used to cache the JWK set for a finite time, default set to 5 minutes which matches
     * constructor above if no jwkSetCache is passed in
     * @throws IllegalArgumentException If AAD key discovery URI is malformed.
     */
    public UserPrincipalManager(ServiceEndpointsProperties serviceEndpointsProps,
                                AADAuthenticationProperties aadAuthenticationProperties,
                                ResourceRetriever resourceRetriever,
                                boolean explicitAudienceCheck,
                                JWKSetCache jwkSetCache) {
        this(aadAuthenticationProperties,
            explicitAudienceCheck,
            new RemoteJWKSet<SecurityContext>(
                toAadKeyDiscoveryUrl(serviceEndpointsProps, aadAuthenticationProperties),
                resourceRetriever,
                jwkSetCache));
    }

    /**
     * Common initialization shared by all public constructors.
     *
     * @param aadAuthenticationProperties properties supplying client-id and App ID URI; may be {@code null} only
     * when {@code explicitAudienceCheck} is {@code false}
     * @param explicitAudienceCheck whether the token audience must match the client-id or App ID URI
     * @param keySource source of the keys used to verify token signatures
     */
    private UserPrincipalManager(AADAuthenticationProperties aadAuthenticationProperties,
                                 boolean explicitAudienceCheck,
                                 JWKSource<SecurityContext> keySource) {
        this.aadAuthenticationProperties = aadAuthenticationProperties;
        this.explicitAudienceCheck = explicitAudienceCheck;
        this.keySource = keySource;
        if (explicitAudienceCheck) {
            // client-id for "normal" check
            this.validAudiences.add(aadAuthenticationProperties.getClientId());
            // app id uri for client credentials flow (server to server communication)
            this.validAudiences.add(aadAuthenticationProperties.getAppIdUri());
        }
    }

    /**
     * Resolves the AAD key discovery URL for the configured environment.
     * <p>
     * If no service endpoints are found for the environment, an empty string is parsed, which fails and is
     * reported as an {@link IllegalArgumentException}.
     *
     * @param serviceEndpointsProps endpoints per environment
     * @param aadAuthenticationProperties properties supplying the environment
     * @return the parsed key discovery URL
     * @throws IllegalArgumentException If AAD key discovery URI is malformed.
     */
    private static URL toAadKeyDiscoveryUrl(ServiceEndpointsProperties serviceEndpointsProps,
                                            AADAuthenticationProperties aadAuthenticationProperties) {
        String aadKeyDiscoveryUri = Optional.of(aadAuthenticationProperties)
            .map(AADAuthenticationProperties::getEnvironment)
            .map(serviceEndpointsProps::getServiceEndpoints)
            .map(ServiceEndpoints::getAadKeyDiscoveryUri)
            .orElse("");
        try {
            return new URL(aadKeyDiscoveryUri);
        } catch (MalformedURLException e) {
            LOGGER.error("Failed to parse active directory key discovery uri.", e);
            throw new IllegalArgumentException("Failed to parse active directory key discovery uri.", e);
        }
    }

    /**
     * Parse the id token to {@link UserPrincipal}.
     *
     * @param aadIssuedBearerToken The token issued by AAD.
     * @return The parsed {@link UserPrincipal}.
     * @throws ParseException If the token couldn't be parsed to a valid JWS object.
     * @throws JOSEException If an internal processing exception is encountered.
     * @throws BadJOSEException If the JWT is rejected.
     */
    public UserPrincipal buildUserPrincipal(String aadIssuedBearerToken) throws ParseException, JOSEException,
        BadJOSEException {
        final JWSObject jwsObject = JWSObject.parse(aadIssuedBearerToken);
        final ConfigurableJWTProcessor<SecurityContext> validator = getValidator(jwsObject.getHeader().getAlgorithm());
        // process() verifies the signature and also runs the configured JWTClaimsSetVerifier,
        // so no separate claims verification is needed here.
        final JWTClaimsSet jwtClaimsSet = validator.process(aadIssuedBearerToken, null);
        UserPrincipal userPrincipal = new UserPrincipal(aadIssuedBearerToken, jwsObject, jwtClaimsSet);
        Set<String> roles = Optional.of(userPrincipal)
            .map(p -> p.getClaim(AADTokenClaim.ROLES))
            .map(r -> (JSONArray) r)
            .map(Collection<Object>::stream)
            .orElseGet(Stream::empty)
            .map(Object::toString)
            .collect(Collectors.toSet());
        userPrincipal.setRoles(roles);
        return userPrincipal;
    }

    /**
     * Checks whether the given token is a JWT whose issuer is one of the known AAD issuers.
     * <p>
     * Only the issuer claim is inspected; the signature is NOT verified here.
     *
     * @param token the serialized token
     * @return {@code true} if the token parses as a JWT and its issuer is an AAD issuer, {@code false} otherwise
     */
    public boolean isTokenIssuedByAAD(String token) {
        try {
            final JWT jwt = JWTParser.parse(token);
            return isAADIssuer(jwt.getJWTClaimsSet().getIssuer());
        } catch (ParseException e) {
            // Never log the raw token: it is a bearer credential.
            LOGGER.info("Failed to parse JWT.", e);
        }
        return false;
    }

    /**
     * Returns whether {@code issuer} starts with one of the known AAD issuer prefixes.
     *
     * @param issuer the issuer claim, may be {@code null}
     * @return {@code true} for a recognized AAD issuer, {@code false} otherwise (including {@code null})
     */
    private static boolean isAADIssuer(String issuer) {
        if (issuer == null) {
            return false;
        }
        return issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER)
            || issuer.startsWith(STS_WINDOWS_ISSUER)
            || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER);
    }

    /**
     * Builds a JWT processor that verifies signatures with {@link #keySource} for the given algorithm.
     * <p>
     * On top of the default claims checks, the processor rejects tokens from non-AAD issuers. When explicit audience
     * checking is enabled, it also rejects tokens whose audience matches neither the client-id nor the App ID URI.
     *
     * @param jwsAlgorithm the algorithm taken from the token header
     * @return a configured processor
     */
    private ConfigurableJWTProcessor<SecurityContext> getValidator(JWSAlgorithm jwsAlgorithm) {
        final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
        final JWSKeySelector<SecurityContext> keySelector = new JWSVerificationKeySelector<>(jwsAlgorithm, keySource);
        jwtProcessor.setJWSKeySelector(keySelector);
        //TODO: would it make sense to inject it? and make it configurable or even allow to provide own implementation
        jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<SecurityContext>() {
            @Override
            public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
                super.verify(claimsSet, ctx);
                final String issuer = claimsSet.getIssuer();
                if (!isAADIssuer(issuer)) {
                    throw new BadJWTException("Invalid token issuer");
                }
                if (explicitAudienceCheck) {
                    Optional<String> matchedAudience = claimsSet.getAudience()
                        .stream()
                        .filter(validAudiences::contains)
                        .findFirst();
                    if (matchedAudience.isPresent()) {
                        LOGGER.debug("Matched audience: [{}]", matchedAudience.get());
                    } else {
                        throw new BadJWTException("Invalid token audience. Provided value " + claimsSet.getAudience()
                            + " does not match either client-id or AppIdUri.");
                    }
                }
            }
        });
        return jwtProcessor;
    }
}