// AzureADGraphClient.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.autoconfigure.aad;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ConfidentialClientApplication;
import com.microsoft.aad.msal4j.IAuthenticationResult;
import com.microsoft.aad.msal4j.IClientCredential;
import com.microsoft.aad.msal4j.MsalServiceException;
import com.microsoft.aad.msal4j.OnBehalfOfParameters;
import com.microsoft.aad.msal4j.UserAssertion;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import javax.naming.ServiceUnavailableException;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import static com.azure.spring.autoconfigure.aad.Constants.DEFAULT_AUTHORITY_SET;
import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX;
/**
 * Graph client encapsulation used to look up the group memberships of a signed-in user.
 * <p>
 * Two API flavours are supported. When the configured environment contains {@code "v2-graph"} the
 * client uses the Microsoft Graph scope and request headers; otherwise it uses the Azure AD Graph
 * scope and sends {@code api-version: 1.6}. The class also performs the on-behalf-of token exchange
 * that yields the access token for the selected API.
 */
public class AzureADGraphClient {
    private static final Logger LOGGER = LoggerFactory.getLogger(AzureADGraphClient.class);
    private static final String MICROSOFT_GRAPH_SCOPE = "https://graph.microsoft.com/user.read";
    private static final String AAD_GRAPH_API_SCOPE = "https://graph.windows.net/user.read";
    // We use "aadfeed5" as suffix when client library is ADAL, upgrade to "aadfeed6" for MSAL
    private static final String REQUEST_ID_SUFFIX = "aadfeed6";
    private static final String V2_VERSION_ENV_FLAG = "v2-graph";

    private final String clientId;
    private final String clientSecret;
    private final ServiceEndpoints serviceEndpoints;
    private final AADAuthenticationProperties aadAuthenticationProperties;
    private final boolean graphApiVersionIsV2;

    /**
     * Creates a client bound to one application registration and one environment.
     *
     * @param clientId the application (client) id used for the on-behalf-of request.
     * @param clientSecret the client secret used to build the confidential client credential.
     * @param aadAuthenticationProperties the AAD properties; its environment selects the service
     * endpoints and decides which Graph API flavour is used.
     * @param serviceEndpointsProps the source of the service endpoints for the configured environment.
     */
    public AzureADGraphClient(String clientId,
                              String clientSecret,
                              AADAuthenticationProperties aadAuthenticationProperties,
                              ServiceEndpointsProperties serviceEndpointsProps) {
        this.clientId = clientId;
        this.clientSecret = clientSecret;
        this.aadAuthenticationProperties = aadAuthenticationProperties;
        this.serviceEndpoints = serviceEndpointsProps.getServiceEndpoints(aadAuthenticationProperties.getEnvironment());
        // A missing environment is treated as "not v2".
        this.graphApiVersionIsV2 = Optional.of(aadAuthenticationProperties)
                                           .map(AADAuthenticationProperties::getEnvironment)
                                           .map(environment -> environment.contains(V2_VERSION_ENV_FLAG))
                                           .orElse(false);
    }

    /**
     * Issues a single GET request for one page of the user's memberships.
     * <p>
     * Note that the body is read before the status code is inspected. With the JDK's default HTTP
     * handler, {@link HttpURLConnection#getInputStream()} itself fails with an {@link IOException}
     * for error statuses (400 and above), so the {@link IllegalStateException} below is only
     * reached for non-200 responses that still expose a readable body.
     *
     * @param accessToken the bearer token sent in the {@code Authorization} header.
     * @param urlString the absolute URL of the page to fetch.
     * @return the raw response body.
     * @throws IOException if the URL is malformed, the connection fails or the body cannot be read.
     * @throws IllegalStateException if the body was read but the status code is not 200.
     */
    private String getUserMemberships(String accessToken, String urlString) throws IOException {
        final URL url = new URL(urlString);
        final HttpURLConnection connection = (HttpURLConnection) url.openConnection();

        // Settings shared by both Graph API flavours.
        connection.setRequestMethod(HttpMethod.GET.toString());
        connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken));

        // Flavour-specific header fields.
        if (this.graphApiVersionIsV2) {
            connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
            connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
        } else {
            connection.setRequestProperty("api-version", "1.6");
            connection.setRequestProperty(HttpHeaders.ACCEPT, "application/json;odata=minimalmetadata");
        }

        final String responseInJson = getResponseString(connection);
        final int responseCode = connection.getResponseCode();
        if (responseCode != HTTPResponse.SC_OK) {
            throw new IllegalStateException(
                "Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson);
        }
        return responseInJson;
    }

    /**
     * Converts the next-page link of a memberships response into the URL to request next.
     * <p>
     * In v2 mode the link is used as is. Otherwise only the part after {@code "/memberOf?"} is kept
     * and appended, with {@code "&"}, to the configured membership REST URI.
     *
     * @param odataNextLink the non-null next-page link taken from the response.
     * @return the URL of the next page.
     * @throws IllegalStateException if, outside v2 mode, the link has nothing after {@code "/memberOf?"}.
     */
    private String getUrlStringFromODataNextLink(String odataNextLink) {
        if (this.graphApiVersionIsV2) {
            return odataNextLink;
        }
        final String[] linkParts = odataNextLink.split("/memberOf\\?");
        if (linkParts.length < 2) {
            // Fail with a readable message instead of a bare ArrayIndexOutOfBoundsException.
            throw new IllegalStateException(
                "Unexpected odata next link, nothing found after \"/memberOf?\": " + odataNextLink);
        }
        return serviceEndpoints.getAadMembershipRestUri() + "&" + linkParts[1];
    }

    /**
     * Reads the whole response body of the connection as UTF-8 text.
     * <p>
     * The body is read line by line and the lines are concatenated without line terminators.
     *
     * @param connection the connection whose input stream is read and then closed.
     * @return the concatenated body.
     * @throws IOException if the input stream cannot be opened or read.
     */
    private static String getResponseString(HttpURLConnection connection) throws IOException {
        try (BufferedReader reader = new BufferedReader(
            new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
            final StringBuilder body = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }
            return body.toString();
        }
    }

    /**
     * Collects the display names of all group memberships of the token's user, following the
     * next-page links of the response until no further page is announced.
     *
     * @param graphApiToken token used to access graph api.
     * @return the group display names, without duplicates, in the order they were first returned.
     * @throws IOException throw exception if get groups failed by IOException.
     * @throws IllegalStateException if a page is answered with a readable non-200 response or a
     * next-page link has an unexpected format.
     */
    public Set<String> getGroups(String graphApiToken) throws IOException {
        final Set<String> groups = new LinkedHashSet<>();
        final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
        String aadMembershipRestUri = serviceEndpoints.getAadMembershipRestUri();
        while (aadMembershipRestUri != null) {
            final String membershipsJson = getUserMemberships(graphApiToken, aadMembershipRestUri);
            final Memberships memberships = objectMapper.readValue(membershipsJson, Memberships.class);
            memberships.getValue()
                       .stream()
                       .filter(this::isGroupObject)
                       .map(Membership::getDisplayName)
                       .forEach(groups::add);
            // A missing next link ends the paging loop.
            final String odataNextLink = memberships.getOdataNextLink();
            aadMembershipRestUri = odataNextLink == null ? null : getUrlStringFromODataNextLink(odataNextLink);
        }
        return groups;
    }

    /**
     * Tells whether a membership entry is a group, i.e. whether its object type equals the
     * configured user group value. An entry without an object type is not a group.
     *
     * @param membership the membership entry to inspect.
     * @return {@code true} if the entry's object type matches the configured value.
     */
    private boolean isGroupObject(final Membership membership) {
        final Object objectType = membership.getObjectType();
        return objectType != null
            && objectType.equals(aadAuthenticationProperties.getUserGroup().getValue());
    }

    /**
     * Maps group names to granted authorities.
     * <p>
     * Only groups accepted by {@code AADAuthenticationProperties#isAllowedGroup} are kept; each one
     * becomes an authority named {@code ROLE_PREFIX + group}. When no group remains, the default
     * authority set is returned instead.
     *
     * @param groups the group names of the user.
     * @return the authorities for the allowed groups, or {@code DEFAULT_AUTHORITY_SET} if there are none.
     */
    public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) {
        final Set<SimpleGrantedAuthority> grantedAuthoritySet =
            groups.stream()
                  .filter(aadAuthenticationProperties::isAllowedGroup)
                  .map(group -> new SimpleGrantedAuthority(ROLE_PREFIX + group))
                  .collect(Collectors.toSet());
        return grantedAuthoritySet.isEmpty() ? DEFAULT_AUTHORITY_SET : grantedAuthoritySet;
    }

    /**
     * Acquire access token for calling Graph API.
     * <p>
     * If the calling thread is interrupted while waiting for the token, its interrupt status is
     * restored before the failure is reported.
     *
     * @param idToken The token used to perform an OBO request.
     * @param tenantId The tenant id.
     * @return The access token for Graph service.
     * @throws ServiceUnavailableException If fail to acquire the token. The underlying failure, if
     * any, is attached as the cause.
     * @throws MsalServiceException If the token request failed with a {@link MsalServiceException}
     * that carries non-empty claims.
     */
    public IAuthenticationResult acquireTokenForGraphApi(String idToken, String tenantId)
        throws ServiceUnavailableException {
        final IClientCredential clientCredential =
            ClientCredentialFactory.createFromSecret(clientSecret);
        final UserAssertion assertion = new UserAssertion(idToken);
        IAuthenticationResult result = null;
        Exception failure = null;
        try {
            final ConfidentialClientApplication application = ConfidentialClientApplication
                .builder(clientId, clientCredential)
                .authority(serviceEndpoints.getAadSigninUri() + tenantId + "/")
                .correlationId(getCorrelationId())
                .build();
            final Set<String> scopes = new HashSet<>();
            scopes.add(graphApiVersionIsV2 ? MICROSOFT_GRAPH_SCOPE : AAD_GRAPH_API_SCOPE);
            final OnBehalfOfParameters onBehalfOfParameters =
                OnBehalfOfParameters.builder(scopes, assertion).build();
            result = application.acquireToken(onBehalfOfParameters).get();
        } catch (InterruptedException e) {
            // Keep the interruption visible to the caller's thread-management code.
            Thread.currentThread().interrupt();
            failure = e;
        } catch (ExecutionException | MalformedURLException e) {
            // Handle conditional access policy, step 1.
            rethrowIfClaimsPresent(e.getCause());
            failure = e;
        }
        if (failure != null) {
            LOGGER.error("acquire on behalf of token for graph api error", failure);
        }
        if (result == null) {
            final ServiceUnavailableException unavailable = new ServiceUnavailableException(
                "unable to acquire on-behalf-of token for client " + clientId);
            if (failure != null) {
                unavailable.initCause(failure);
            }
            throw unavailable;
        }
        return result;
    }

    /**
     * Rethrows the cause when it is a {@link MsalServiceException} carrying non-empty claims, so
     * that the caller can react to them; does nothing otherwise.
     *
     * @param cause the cause of the failed token request, may be {@code null}.
     */
    private static void rethrowIfClaimsPresent(Throwable cause) {
        if (cause instanceof MsalServiceException) {
            final MsalServiceException exception = (MsalServiceException) cause;
            if (exception.claims() != null && !exception.claims().isEmpty()) {
                throw exception;
            }
        }
    }

    /**
     * Builds the correlation id sent with the token request: a random UUID string whose trailing
     * characters are overwritten with {@code REQUEST_ID_SUFFIX}, keeping the overall length.
     *
     * @return the correlation id.
     */
    private static String getCorrelationId() {
        final String uuid = UUID.randomUUID().toString();
        return uuid.substring(0, uuid.length() - REQUEST_ID_SUFFIX.length()) + REQUEST_ID_SUFFIX;
    }
}