// ClaimsBasedSecurityChannel.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.core.amqp.implementation;
import com.azure.core.amqp.AmqpRetryOptions;
import com.azure.core.amqp.ClaimsBasedSecurityNode;
import com.azure.core.amqp.exception.AmqpException;
import com.azure.core.amqp.exception.AmqpResponseCode;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.util.logging.ClientLogger;
import org.apache.qpid.proton.Proton;
import org.apache.qpid.proton.amqp.messaging.AmqpValue;
import org.apache.qpid.proton.amqp.messaging.ApplicationProperties;
import org.apache.qpid.proton.message.Message;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
import java.time.OffsetDateTime;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static com.azure.core.amqp.implementation.ExceptionUtil.amqpResponseCodeToException;
import static com.azure.core.amqp.implementation.RequestResponseUtils.getStatusCode;
import static com.azure.core.amqp.implementation.RequestResponseUtils.getStatusDescription;
public class ClaimsBasedSecurityChannel implements ClaimsBasedSecurityNode {
static final String PUT_TOKEN_TYPE = "type";
static final String PUT_TOKEN_AUDIENCE = "name";
static final String PUT_TOKEN_EXPIRY = "expiration";
private static final String PUT_TOKEN_OPERATION = "operation";
private static final String PUT_TOKEN_OPERATION_VALUE = "put-token";
private final ClientLogger logger = new ClientLogger(ClaimsBasedSecurityChannel.class);
private final TokenCredential credential;
private final Mono<RequestResponseChannel> cbsChannelMono;
private final CbsAuthorizationType authorizationType;
private final AmqpRetryOptions retryOptions;
public ClaimsBasedSecurityChannel(Mono<RequestResponseChannel> responseChannelMono, TokenCredential tokenCredential,
CbsAuthorizationType authorizationType, AmqpRetryOptions retryOptions) {
this.authorizationType = Objects.requireNonNull(authorizationType, "'authorizationType' cannot be null.");
this.retryOptions = Objects.requireNonNull(retryOptions, "'retryOptions' cannot be null.");
this.credential = Objects.requireNonNull(tokenCredential, "'tokenCredential' cannot be null.");
this.cbsChannelMono = Objects.requireNonNull(responseChannelMono, "'responseChannelMono' cannot be null.");
}
@Override
public Mono<OffsetDateTime> authorize(String tokenAudience, String scopes) {
return cbsChannelMono.flatMap(channel ->
credential.getToken(new TokenRequestContext().addScopes(scopes))
.flatMap(accessToken -> {
final Message request = Proton.message();
final Map<String, Object> properties = new HashMap<>();
properties.put(PUT_TOKEN_OPERATION, PUT_TOKEN_OPERATION_VALUE);
properties.put(PUT_TOKEN_EXPIRY, Date.from(accessToken.getExpiresAt().toInstant()));
properties.put(PUT_TOKEN_TYPE, authorizationType.getTokenType());
properties.put(PUT_TOKEN_AUDIENCE, tokenAudience);
final ApplicationProperties applicationProperties = new ApplicationProperties(properties);
request.setApplicationProperties(applicationProperties);
request.setBody(new AmqpValue(accessToken.getToken()));
return channel.sendWithAck(request)
.handle((Message message, SynchronousSink<OffsetDateTime> sink) -> {
if (RequestResponseUtils.isSuccessful(message)) {
sink.next(accessToken.getExpiresAt());
} else {
final String description = getStatusDescription(message);
final AmqpResponseCode statusCode = getStatusCode(message);
final Exception error = amqpResponseCodeToException(
statusCode.getValue(), description, channel.getErrorContext());
sink.error(error);
}
})
.switchIfEmpty(Mono.error(new AmqpException(true, String.format(
"No response received from CBS node. tokenAudience: '%s'. scopes: '%s'",
tokenAudience, scopes), channel.getErrorContext())));
}));
}
@Override
public void close() {
final RequestResponseChannel channel = cbsChannelMono.block(retryOptions.getTryTimeout());
if (channel != null) {
channel.dispose();
}
}
}