// ThroughputRequestThrottler.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.implementation.throughputControl;
import com.azure.cosmos.BridgeInternal;
import com.azure.cosmos.CosmosException;
import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.OperationType;
import com.azure.cosmos.implementation.RequestRateTooLargeException;
import com.azure.cosmos.implementation.ResourceType;
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
import com.azure.cosmos.implementation.RxDocumentServiceResponse;
import com.azure.cosmos.implementation.Utils;
import com.azure.cosmos.implementation.apachecommons.lang.StringUtils;
import com.azure.cosmos.implementation.directconnectivity.StoreResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Mono;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* This is the place where we tracking the RU usage, and make decision whether we should block the request.
*/
public class ThroughputRequestThrottler {
private static final Logger logger = LoggerFactory.getLogger(ThroughputRequestThrottler.class);
private final AtomicReference<Double> availableThroughput;
private final AtomicReference<Double> scheduledThroughput;
private final ReentrantReadWriteLock.WriteLock throughputWriteLock;
private final ReentrantReadWriteLock.ReadLock throughputReadLock;
private final ConcurrentHashMap<OperationType, ThroughputControlTrackingUnit> trackingDictionary;
private final String pkRangeId;
private String cycleId;
public ThroughputRequestThrottler(double scheduledThroughput, String pkRangeId) {
this.availableThroughput = new AtomicReference<>(scheduledThroughput);
this.scheduledThroughput = new AtomicReference<>(scheduledThroughput);
ReentrantReadWriteLock throughputReadWriteLock = new ReentrantReadWriteLock();
this.throughputWriteLock = throughputReadWriteLock.writeLock();
this.throughputReadLock = throughputReadWriteLock.readLock();
this.trackingDictionary = new ConcurrentHashMap<>();
this.cycleId = UUID.randomUUID().toString();
this.pkRangeId = pkRangeId;
}
public double renewThroughputUsageCycle(double scheduledThroughput) {
try {
this.throughputWriteLock.lock();
double throughputUsagePercentage = (this.scheduledThroughput.get() - this.availableThroughput.get()) / this.scheduledThroughput.get();
this.scheduledThroughput.set(scheduledThroughput);
this.updateAvailableThroughput();
if (throughputUsagePercentage > 0) {
logger.debug(
"[CycleId: {}, pkRangeId: {}, ruUsagePercentage: {}]",
this.cycleId, this.pkRangeId, throughputUsagePercentage);
}
String newCycleId = UUID.randomUUID().toString();
for (ThroughputControlTrackingUnit trackingUnit : this.trackingDictionary.values()) {
trackingUnit.reset(newCycleId);
}
this.cycleId = newCycleId;
return throughputUsagePercentage;
} finally {
this.throughputWriteLock.unlock();
}
}
private void updateAvailableThroughput() {
// The base rule is: If RU is overused during the current cycle, the over used part will be deducted from the next cyclle
// If RU is not fully utilized during the current cycle, it will be voided.
this.availableThroughput.getAndAccumulate(this.scheduledThroughput.get(), (available, refill) -> Math.min(available,0) + refill);
}
public <T> Mono<T> processRequest(RxDocumentServiceRequest request, Mono<T> originalRequestMono) {
try {
this.throughputReadLock.lock();
ThroughputControlTrackingUnit trackingUnit =
this.trackingDictionary.compute(request.getOperationType(), ((key, value) -> {
if (value == null) {
value = new ThroughputControlTrackingUnit(request.getOperationType(), this.cycleId);
}
return value;
}));
if (this.availableThroughput.get() > 0) {
if (StringUtils.isEmpty(request.requestContext.throughputControlCycleId)) {
request.requestContext.throughputControlCycleId = this.cycleId;
}
trackingUnit.increasePassedRequest();
return originalRequestMono
.doOnSuccess(response -> this.trackRequestCharge(request, response))
.doOnError(throwable -> this.trackRequestCharge(request, throwable));
} else {
trackingUnit.increaseRejectedRequest();
// there is no enough throughput left, block request
RequestRateTooLargeException requestRateTooLargeException = new RequestRateTooLargeException();
int backoffTimeInMilliSeconds = (int)Math.ceil(Math.abs(this.availableThroughput.get() / this.scheduledThroughput.get())) * 1000;
requestRateTooLargeException.getResponseHeaders().put(
HttpConstants.HttpHeaders.RETRY_AFTER_IN_MILLISECONDS,
String.valueOf(backoffTimeInMilliSeconds));
if (isBulkRequest(request)) {
// For batch requests the BulkExecutor
requestRateTooLargeException.getResponseHeaders().put(
HttpConstants.HttpHeaders.SUB_STATUS,
String.valueOf(HttpConstants.SubStatusCodes.THROUGHPUT_CONTROL_BULK_REQUEST_RATE_TOO_LARGE));
} else {
requestRateTooLargeException.getResponseHeaders().put(
HttpConstants.HttpHeaders.SUB_STATUS,
String.valueOf(HttpConstants.SubStatusCodes.THROUGHPUT_CONTROL_REQUEST_RATE_TOO_LARGE));
}
if (request.requestContext != null) {
BridgeInternal.setResourceAddress(requestRateTooLargeException, request.requestContext.resourcePhysicalAddress);
}
return Mono.error(requestRateTooLargeException);
}
} finally {
this.throughputReadLock.unlock();
}
}
private static boolean isBulkRequest(RxDocumentServiceRequest request) {
if (request.getOperationType() != OperationType.Batch ||
request.getResourceType() != ResourceType.Document) {
return false;
}
String isAtomicBatch = request.getHeaders().get(HttpConstants.HttpHeaders.IS_BATCH_ATOMIC);
if(StringUtils.isEmpty(isAtomicBatch)) {
return true;
} else {
return !isAtomicBatch.equalsIgnoreCase(Boolean.TRUE.toString());
}
}
private <T> void trackRequestCharge (RxDocumentServiceRequest request, T response) {
try {
// Read lock is enough here.
this.throughputReadLock.lock();
double requestCharge = 0;
boolean failedRequest = false;
if (response instanceof StoreResponse) {
requestCharge = ((StoreResponse)response).getRequestCharge();
} else if (response instanceof RxDocumentServiceResponse) {
requestCharge = ((RxDocumentServiceResponse)response).getRequestCharge();
} else if (response instanceof Throwable) {
CosmosException cosmosException = Utils.as(Exceptions.unwrap((Throwable) response), CosmosException.class);
if (cosmosException != null) {
requestCharge = cosmosException.getRequestCharge();
failedRequest = true;
}
}
ThroughputControlTrackingUnit trackingUnit = trackingDictionary.get(request.getOperationType());
if (trackingUnit != null) {
if (failedRequest) {
trackingUnit.increaseFailedResponse();
} else {
trackingUnit.increaseSuccessResponse();
trackingUnit.trackRRuUsage(requestCharge);
}
}
// If the response comes back in a different cycle, discard it.
if (StringUtils.equals(this.cycleId, request.requestContext.throughputControlCycleId)) {
this.availableThroughput.getAndAccumulate(requestCharge, (available, consumed) -> available - consumed);
} else {
if (trackingUnit != null) {
trackingUnit.increaseOutOfCycleResponse();
}
}
} finally {
this.throughputReadLock.unlock();
}
}
public double getAvailableThroughput() {
return this.availableThroughput.get();
}
public double getScheduledThroughput() {
return this.scheduledThroughput.get();
}
}