JacksonHttpSessionOAuth2AuthorizedClientRepository.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.aad.webapp;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.springframework.security.core.Authentication;
import org.springframework.security.jackson2.CoreJackson2Module;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
/**
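* An implementation of an {@link OAuth2AuthorizedClientRepository} that stores
* {@link OAuth2AuthorizedClient}s in the {@code HttpSession} as a JSON string (written with Jackson)
* instead of as serialized Java objects, so that the stored session data stays compatible across
* different Spring Security versions.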
* Refs: https://github.com/spring-projects/spring-security/issues/9204
*
* @see OAuth2AuthorizedClientRepository
* @see OAuth2AuthorizedClient
*/
public class JacksonHttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {

    /**
     * Session attribute key. It is the same key that Spring's
     * {@link HttpSessionOAuth2AuthorizedClientRepository} uses. That class stores a {@code Map} under it,
     * whereas this class stores a JSON {@code String}; see {@link #getAuthorizedClients}.
     */
    private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME =
        HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS";

    private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME;

    /** Mapper configured with the Spring Security Jackson modules needed to (de)serialize clients. */
    private final ObjectMapper objectMapper;

    /** Target type used when reading the JSON back: registration id to authorized client. */
    private static final TypeReference<Map<String, OAuth2AuthorizedClient>> TYPE_REFERENCE =
        new TypeReference<Map<String, OAuth2AuthorizedClient>>() {
        };

    /**
     * Creates a repository whose {@link ObjectMapper} has the OAuth2 client, Spring Security core and
     * {@code java.time} Jackson modules registered.
     */
    public JacksonHttpSessionOAuth2AuthorizedClientRepository() {
        objectMapper = new ObjectMapper();
        objectMapper.registerModule(new OAuth2ClientJackson2Module());
        objectMapper.registerModule(new CoreJackson2Module());
        objectMapper.registerModule(new JavaTimeModule());
    }

    /**
     * Looks up the authorized client stored in the current session for the given registration.
     *
     * @param clientRegistrationId the registration id; must not be empty
     * @param principal not used by this implementation
     * @param request the current request; must not be {@code null}
     * @param <T> the expected client type (unchecked cast)
     * @return the stored client, or {@code null} if there is no session or no client for the id
     * @throws IllegalStateException if the stored JSON cannot be deserialized
     */
    @SuppressWarnings("unchecked")
    @Override
    public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId,
                                                                     Authentication principal,
                                                                     HttpServletRequest request) {
        Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
        Assert.notNull(request, "request cannot be null");
        return (T) this.getAuthorizedClients(request).get(clientRegistrationId);
    }

    /**
     * Adds or replaces the given client in the session, keyed by its registration id, and writes the
     * whole collection back as JSON. A session is created if none exists.
     *
     * @param authorizedClient the client to store; must not be {@code null}
     * @param principal not used by this implementation
     * @param request the current request; must not be {@code null}
     * @param response the current response; must not be {@code null}
     * @throws IllegalStateException if the clients cannot be serialized or deserialized
     */
    @Override
    public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
                                     HttpServletRequest request, HttpServletResponse response) {
        Assert.notNull(authorizedClient, "authorizedClient cannot be null");
        Assert.notNull(request, "request cannot be null");
        Assert.notNull(response, "response cannot be null");
        Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
        authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient);
        request.getSession().setAttribute(this.sessionAttributeName, toString(authorizedClients));
    }

    /**
     * Removes the client for the given registration from the session. If it was the last one, the
     * session attribute is removed entirely; if nothing was stored for the id, the session is left as is.
     *
     * @param clientRegistrationId the registration id; must not be empty
     * @param principal not used by this implementation
     * @param request the current request; must not be {@code null}
     * @param response not used by this implementation
     * @throws IllegalStateException if the clients cannot be serialized or deserialized
     */
    @Override
    public void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
                                       HttpServletRequest request, HttpServletResponse response) {
        Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
        Assert.notNull(request, "request cannot be null");
        Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
        if (authorizedClients.remove(clientRegistrationId) == null) {
            // Nothing stored for this id (also covers "no session"): leave the session untouched.
            return;
        }
        // A non-null removal means the map came from an existing session, so getSession() won't create one.
        HttpSession session = request.getSession();
        if (authorizedClients.isEmpty()) {
            session.removeAttribute(this.sessionAttributeName);
        } else {
            session.setAttribute(this.sessionAttributeName, toString(authorizedClients));
        }
    }

    /**
     * Serializes the clients map to a JSON string.
     *
     * @throws IllegalStateException wrapping the Jackson failure
     */
    private String toString(Map<String, OAuth2AuthorizedClient> authorizedClients) {
        try {
            return objectMapper.writeValueAsString(authorizedClients);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to serialize authorized clients to JSON", e);
        }
    }

    /**
     * Reads the clients map from the current session without creating a session.
     *
     * @return a mutable map; empty when there is no session or no readable JSON value stored
     * @throws IllegalStateException if the stored JSON cannot be deserialized
     */
    private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(HttpServletRequest request) {
        HttpSession session = request.getSession(false);
        Object attribute = Optional.ofNullable(session)
            .map(s -> s.getAttribute(this.sessionAttributeName))
            .orElse(null);
        // The key is shared with Spring's HttpSessionOAuth2AuthorizedClientRepository, which stores a Map.
        // Treat any non-String value as absent rather than failing every request with a ClassCastException;
        // the next save overwrites it with JSON.
        if (!(attribute instanceof String)) {
            return new HashMap<>();
        }
        Map<String, OAuth2AuthorizedClient> authorizedClients;
        try {
            authorizedClients = objectMapper.readValue((String) attribute, TYPE_REFERENCE);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to deserialize authorized clients from JSON", e);
        }
        // A literal JSON "null" deserializes to null; callers always need a mutable map.
        return authorizedClients != null ? authorizedClients : new HashMap<>();
    }
}