AzureAuthorizedClientRepository.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.aad.webapp;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
/**
* OAuth2AuthorizedClientRepository used for AAD oauth2 clients.
*/
public class AzureAuthorizedClientRepository implements OAuth2AuthorizedClientRepository {

    /** Value of the synthesized placeholder access token; it is created already expired. */
    private static final String PLACEHOLDER_ACCESS_TOKEN_VALUE = "non-access-token";

    private final AzureClientRegistrationRepository repo;
    private final OAuth2AuthorizedClientRepository delegate;

    /**
     * Creates a repository that stores authorized clients in a
     * {@link JacksonHttpSessionOAuth2AuthorizedClientRepository}.
     *
     * @param repo registration repository used to resolve the default registration and to
     *             decide which registration ids are authz clients
     */
    public AzureAuthorizedClientRepository(AzureClientRegistrationRepository repo) {
        this(repo, new JacksonHttpSessionOAuth2AuthorizedClientRepository());
    }

    /**
     * Creates a repository that stores authorized clients in the given delegate.
     *
     * @param repo     registration repository used to resolve the default registration and to
     *                 decide which registration ids are authz clients
     * @param delegate repository that performs the actual save/load/remove operations
     */
    public AzureAuthorizedClientRepository(AzureClientRegistrationRepository repo,
                                           OAuth2AuthorizedClientRepository delegate) {
        this.repo = repo;
        this.delegate = delegate;
    }

    /**
     * Stores the authorized client by forwarding to the delegate repository.
     */
    @Override
    public void saveAuthorizedClient(OAuth2AuthorizedClient client,
                                     Authentication principal,
                                     HttpServletRequest request,
                                     HttpServletResponse response) {
        delegate.saveAuthorizedClient(client, principal, request, response);
    }

    /**
     * Loads the authorized client for {@code id}.
     *
     * <p>The delegate is consulted first. If it has nothing stored and {@code id} is an authz
     * client according to {@code repo.isAuthzClient(id)}, a placeholder client is synthesized
     * from the refresh token of the default registration's stored client. The placeholder's
     * access token is already expired, presumably so the caller refreshes it before use
     * (TODO confirm against the authorized-client manager configuration).
     *
     * @param id        the client registration id
     * @param principal the current authentication; may be {@code null}
     * @param request   the current request
     * @return the stored or synthesized client, or {@code null} if neither is available
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String id,
                                                                     Authentication principal,
                                                                     HttpServletRequest request) {
        OAuth2AuthorizedClient stored = delegate.loadAuthorizedClient(id, principal, request);
        if (stored != null) {
            return (T) stored;
        }
        if (!repo.isAuthzClient(id)) {
            return null;
        }
        // Ask the delegate directly instead of recursing through this method: if the default
        // registration were itself an authz client with nothing stored, recursion would never end.
        OAuth2AuthorizedClient defaultClient =
            delegate.loadAuthorizedClient(defaultClientRegistrationId(), principal, request);
        return (T) createInitAuthzClient(defaultClient, id, principal);
    }

    /**
     * Returns the registration id of the default Azure client held by {@code repo}.
     */
    private String defaultClientRegistrationId() {
        return repo.getAzureClient().getClient().getRegistrationId();
    }

    /**
     * Builds a placeholder authorized client for registration {@code id} that reuses the
     * refresh token of {@code client} and carries an already-expired access token.
     *
     * @param client    the default registration's authorized client; may be {@code null}
     * @param id        the registration id of the client to synthesize
     * @param principal the current authentication; may be {@code null}, in which case the
     *                  principal name recorded on {@code client} is used
     * @return the placeholder client, or {@code null} if {@code client} is {@code null} or has
     *         no refresh token
     */
    private OAuth2AuthorizedClient createInitAuthzClient(OAuth2AuthorizedClient client,
                                                         String id,
                                                         Authentication principal) {
        if (client == null || client.getRefreshToken() == null) {
            return null;
        }
        // Issued at the earliest instant and expired 100 days ago, so it is never considered valid.
        OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken(
            OAuth2AccessToken.TokenType.BEARER,
            PLACEHOLDER_ACCESS_TOKEN_VALUE,
            Instant.MIN,
            Instant.now().minus(100, ChronoUnit.DAYS));
        String principalName = principal != null ? principal.getName() : client.getPrincipalName();
        return new OAuth2AuthorizedClient(
            repo.findByRegistrationId(id),
            principalName,
            expiredAccessToken,
            client.getRefreshToken()
        );
    }

    /**
     * Removes the authorized client by forwarding to the delegate repository.
     */
    @Override
    public void removeAuthorizedClient(String id,
                                       Authentication principal,
                                       HttpServletRequest request,
                                       HttpServletResponse response) {
        delegate.removeAuthorizedClient(id, principal, request, response);
    }
}