AzureOAuth2ResponseErrorHandler.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.aad.webapp;
import com.nimbusds.oauth2.sdk.token.BearerTokenError;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.util.StringUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.ResponseErrorHandler;
import java.io.IOException;
import java.util.Map;
/**
 * Handles HTTP error responses received through a {@link ResponseErrorHandler}.
 * <p>
 * A 400 (Bad Request) response is translated into an {@link OAuth2AuthorizationException}. The error is taken from
 * the Bearer {@code WWW-Authenticate} challenge when one can be parsed. Otherwise it is read from the response body
 * with a converter that keeps the additional Azure AD error fields (see {@link AzureOAuth2Error}). Every other error
 * status is delegated to Spring's {@link DefaultResponseErrorHandler}.
 */
public class AzureOAuth2ResponseErrorHandler implements ResponseErrorHandler {

    /** Reads an OAuth 2.0 error from the response body; configured with the AAD-aware converter below. */
    private final OAuth2ErrorHttpMessageConverter oauth2ErrorConverter = new OAuth2ErrorHttpMessageConverter();

    /** Decides what counts as an error and handles every error status other than 400. */
    private final ResponseErrorHandler defaultErrorHandler = new DefaultResponseErrorHandler();

    /**
     * Creates a handler whose body converter maps the error parameters to {@link AzureOAuth2Error}. This keeps the
     * Azure AD specific fields (error codes, timestamp, trace id, correlation id, sub-error and claims).
     */
    protected AzureOAuth2ResponseErrorHandler() {
        this.oauth2ErrorConverter.setErrorConverter(new AADOAuth2ErrorConverter());
    }

    /**
     * Decides whether the response is an error, using the rules of {@link DefaultResponseErrorHandler}.
     *
     * @param response the HTTP response to inspect
     * @return {@code true} if the response should be passed to {@link #handleError(ClientHttpResponse)}
     * @throws IOException if the response cannot be read
     */
    @Override
    public boolean hasError(ClientHttpResponse response) throws IOException {
        return this.defaultErrorHandler.hasError(response);
    }

    /**
     * Translates an error response into an exception.
     * <p>
     * Non-400 responses are handed to {@link DefaultResponseErrorHandler}. For a 400 response, the OAuth 2.0 error is
     * read from the {@code WWW-Authenticate} header first, then from the body, and thrown as an
     * {@link OAuth2AuthorizationException}.
     *
     * @param response the error response
     * @throws IOException if the response cannot be read
     * @throws OAuth2AuthorizationException for a 400 response carrying an OAuth 2.0 error
     */
    @Override
    public void handleError(ClientHttpResponse response) throws IOException {
        // Compare the raw status code. getStatusCode() throws IllegalArgumentException for codes that have no
        // HttpStatus constant (e.g. 520), even though hasError() reports them as errors. That would hide the real
        // failure from the caller.
        if (response.getRawStatusCode() != HttpStatus.BAD_REQUEST.value()) {
            // DefaultResponseErrorHandler is expected to throw here (client, server or unknown-status exception).
            this.defaultErrorHandler.handleError(response);
        }
        // A Bearer Token Error may be in the WWW-Authenticate response header (RFC 6750, section 3).
        OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
        if (oauth2Error == null) {
            oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
        }
        throw new OAuth2AuthorizationException(oauth2Error);
    }

    /**
     * Extracts an OAuth 2.0 error from a Bearer {@code WWW-Authenticate} challenge.
     *
     * @param headers the response headers
     * @return the error described by the challenge, or {@code null} if the header is absent or blank, or cannot be
     *         parsed as a Bearer challenge
     */
    private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
        String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
        if (!StringUtils.hasText(wwwAuthenticateHeader)) {
            return null;
        }
        BearerTokenError bearerTokenError;
        try {
            bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
        } catch (Exception ignored) {
            // Deliberate best effort: an unparsable challenge makes the caller fall back to reading the body.
            return null;
        }
        // A challenge without an error code is still a failure; report it as a generic server error.
        String errorCode = bearerTokenError.getCode() != null
            ? bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
        String errorDescription = bearerTokenError.getDescription();
        String errorUri = bearerTokenError.getURI() != null
            ? bearerTokenError.getURI().toString() : null;
        return new OAuth2Error(errorCode, errorDescription, errorUri);
    }

    /**
     * Converts the error parameters read from a response body into an {@link AzureOAuth2Error}.
     * <p>
     * In addition to the standard {@code error}, {@code error_description} and {@code error_uri} parameters, it
     * keeps {@code error_codes}, {@code timestamp}, {@code trace_id}, {@code correlation_id}, {@code suberror} and
     * {@code claims}.
     */
    private static class AADOAuth2ErrorConverter implements Converter<Map<String, String>, OAuth2Error> {

        /**
         * Builds an {@link AzureOAuth2Error} from the parameter map. Missing parameters are passed as {@code null}.
         *
         * @param parameters error parameters keyed by their response field names
         * @return the converted error
         */
        @Override
        public OAuth2Error convert(Map<String, String> parameters) {
            String errorCode = parameters.get("error");
            String description = parameters.get("error_description");
            String errorCodes = parameters.get("error_codes");
            String timestamp = parameters.get("timestamp");
            String traceId = parameters.get("trace_id");
            String correlationId = parameters.get("correlation_id");
            String uri = parameters.get("error_uri");
            String subError = parameters.get("suberror");
            String claims = parameters.get("claims");
            return new AzureOAuth2Error(errorCode, description, errorCodes, timestamp, traceId, correlationId,
                uri, subError, claims);
        }
    }
}