AADOAuth2AuthorizationRequestResolver.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.autoconfigure.aad;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
/**
 * Resolves OAuth 2.0 authorization requests and adds conditional access policy claims to them.
 *
 * <p>Resolution is delegated to a {@link DefaultOAuth2AuthorizationRequestResolver} configured with
 * {@link OAuth2AuthorizationRequestRedirectFilter#DEFAULT_AUTHORIZATION_REQUEST_BASE_URI}. If the
 * current HTTP session holds a value under {@code Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS}, that
 * value is added to the resolved request's additional parameters under {@code Constants.CLAIMS} and
 * then removed from the session.
 */
public class AADOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {

    private final OAuth2AuthorizationRequestResolver defaultResolver;

    /**
     * Creates a resolver backed by a {@link DefaultOAuth2AuthorizationRequestResolver}.
     *
     * @param clientRegistrationRepository repository used by the delegate resolver to look up client
     *                                     registrations
     */
    public AADOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository) {
        this.defaultResolver = new DefaultOAuth2AuthorizationRequestResolver(
            clientRegistrationRepository,
            OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI
        );
    }

    @Override
    public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
        return addClaims(request, defaultResolver.resolve(request));
    }

    @Override
    public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId) {
        return addClaims(request, defaultResolver.resolve(request, clientRegistrationId));
    }

    /**
     * Adds any pending conditional access policy claims to the authorization request.
     *
     * @param httpServletRequest         the current request; may be {@code null}
     * @param oAuth2AuthorizationRequest the request produced by the delegate resolver; may be {@code null}
     * @return {@code oAuth2AuthorizationRequest} unchanged when either argument is {@code null} or no
     *         claims are pending; otherwise a copy whose additional parameters also contain the claims
     */
    private OAuth2AuthorizationRequest addClaims(HttpServletRequest httpServletRequest,
                                                 OAuth2AuthorizationRequest oAuth2AuthorizationRequest) {
        if (oAuth2AuthorizationRequest == null || httpServletRequest == null) {
            return oAuth2AuthorizationRequest;
        }
        // Handle conditional access policy, step 4.
        final String conditionalAccessPolicyClaims = consumeConditionalAccessPolicyClaims(httpServletRequest);
        if (conditionalAccessPolicyClaims == null) {
            return oAuth2AuthorizationRequest;
        }
        final Map<String, Object> additionalParameters = new HashMap<>();
        additionalParameters.put(Constants.CLAIMS, conditionalAccessPolicyClaims);
        // Existing parameters are copied in after the claims entry, so on a key collision the
        // existing value wins (same precedence as before).
        final Map<String, Object> existingParameters = oAuth2AuthorizationRequest.getAdditionalParameters();
        if (existingParameters != null) {
            additionalParameters.putAll(existingParameters);
        }
        return OAuth2AuthorizationRequest.from(oAuth2AuthorizationRequest)
            .additionalParameters(additionalParameters)
            .build();
    }

    /**
     * Reads the pending conditional access policy claims from the session and removes them.
     *
     * <p>Uses {@code getSession(false)} so that no session is created just to perform this lookup.
     * A session that did not exist cannot hold the attribute.
     *
     * @param request the current request, never {@code null}
     * @return the stored claims, or {@code null} if there is no session or no claims are stored
     */
    private static String consumeConditionalAccessPolicyClaims(HttpServletRequest request) {
        return Optional.ofNullable(request.getSession(false))
            .map(session -> {
                String claims = (String) session.getAttribute(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS);
                if (claims != null) {
                    // Clear after reading so the claims are applied to a single authorization request.
                    session.removeAttribute(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS);
                }
                return claims;
            })
            .orElse(null);
    }
}