AggregateDocumentQueryExecutionContext.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.implementation.query;
import com.azure.cosmos.BridgeInternal;
import com.azure.cosmos.implementation.ClientSideRequestStatistics;
import com.azure.cosmos.implementation.Document;
import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.QueryMetrics;
import com.azure.cosmos.implementation.Resource;
import com.azure.cosmos.implementation.query.aggregation.AggregateOperator;
import com.azure.cosmos.models.FeedResponse;
import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiFunction;
/**
 * Query execution component that drains every page produced by an underlying component and folds
 * the rewritten aggregate projections of all results into a single aggregate result.
 *
 * <p>All pages of the source component are collected before anything is emitted, so the
 * {@link Flux} returned by {@link #drainAsync(int)} emits exactly one {@link FeedResponse}. The
 * request charge, query metrics and client-side diagnostics of the consumed pages are carried over
 * to that response.
 *
 * @param <T> the resource type of the query; the results drained here are cast to {@link Document}.
 */
public class AggregateDocumentQueryExecutionContext<T extends Resource> implements IDocumentQueryExecutionComponent<T>{
    /** Name of the property that wraps the aggregate values of non-VALUE aggregate results. */
    public static final String PAYLOAD_PROPERTY_NAME = "payload";

    private final boolean isValueAggregateQuery;
    private final IDocumentQueryExecutionComponent<T> component;
    // Query metrics of all drained pages, merged per key as reported by each page.
    private final ConcurrentMap<String, QueryMetrics> queryMetricsMap = new ConcurrentHashMap<>();
    private final SingleGroupAggregator singleGroupAggregator;

    /**
     * Creates an aggregate component on top of {@code component}.
     *
     * <p>QueryInfo (used in PipelinedDocumentQueryExecutionContext) returns a Collection of
     * AggregateOperators; {@link #createAsync} copies it into the list expected here.
     *
     * @param component the source component whose pages are aggregated
     * @param aggregateOperators the aggregate operators, forwarded to {@code SingleGroupAggregator.create}
     * @param groupByAliasToAggregateType alias-to-operator mapping, forwarded to {@code SingleGroupAggregator.create}
     * @param orderedAliases ordered aliases, forwarded to {@code SingleGroupAggregator.create}
     * @param hasSelectValue whether the query uses {@code SELECT VALUE}; controls how each result
     *                       document is unwrapped before aggregation
     * @param continuationToken continuation token, forwarded to {@code SingleGroupAggregator.create}
     */
    public AggregateDocumentQueryExecutionContext(IDocumentQueryExecutionComponent<T> component,
                                                  List<AggregateOperator> aggregateOperators,
                                                  Map<String, AggregateOperator> groupByAliasToAggregateType,
                                                  List<String> orderedAliases,
                                                  boolean hasSelectValue,
                                                  String continuationToken) {
        this.component = component;
        this.isValueAggregateQuery = hasSelectValue;
        this.singleGroupAggregator = SingleGroupAggregator.create(aggregateOperators,
                                                                  groupByAliasToAggregateType,
                                                                  orderedAliases,
                                                                  hasSelectValue,
                                                                  continuationToken);
    }

    /**
     * Drains the source component completely and emits one page holding the aggregate result.
     *
     * <p>The emitted page contains the aggregator's result document (or no document when the
     * aggregator returns {@code null}), the summed request charge of every consumed page, the
     * merged query metrics and the collected client-side diagnostics.
     *
     * @param maxPageSize page size passed through to the source component
     * @return a {@link Flux} emitting a single aggregated {@link FeedResponse}
     */
    @SuppressWarnings("unchecked") // results are always Documents; the response is re-typed to T
    @Override
    public Flux<FeedResponse<T>> drainAsync(int maxPageSize) {
        return this.component.drainAsync(maxPageSize)
            .collectList()
            .map(pages -> {
                double requestCharge = 0;
                List<ClientSideRequestStatistics> diagnosticsList = new ArrayList<>();

                for (FeedResponse<T> page : pages) {
                    diagnosticsList.addAll(BridgeInternal
                        .getClientSideRequestStatisticsList(page.getCosmosDiagnostics()));
                    // Account for the page's charge and metrics before looking at its results, so an
                    // empty page is never left out of the reported totals.
                    requestCharge += page.getRequestCharge();
                    this.mergeQueryMetrics(page);

                    if (page.getResults().isEmpty()) {
                        // NOTE(review): kept from the original implementation - an empty page ends the
                        // drain with an empty result, ignoring values aggregated from earlier pages and
                        // any remaining pages. Confirm this is intended.
                        return (FeedResponse<T>) this.createResponse(new ArrayList<>(), requestCharge, diagnosticsList);
                    }

                    for (T result : page.getResults()) {
                        RewrittenAggregateProjections rewrittenAggregateProjections =
                            new RewrittenAggregateProjections(this.isValueAggregateQuery,
                                                              (Document) result); // result is always a Document
                        this.singleGroupAggregator.addValues(rewrittenAggregateProjections.getPayload());
                    }
                }

                List<Document> aggregateResults = new ArrayList<>();
                Document aggregateDocument = this.singleGroupAggregator.getResult();
                if (aggregateDocument != null) {
                    aggregateResults.add(aggregateDocument);
                }
                return (FeedResponse<T>) this.createResponse(aggregateResults, requestCharge, diagnosticsList);
            }).flux();
    }

    /**
     * Merges the query metrics reported by {@code page} into {@link #queryMetricsMap}.
     *
     * <p>For a key already present, the value returned by {@code QueryMetrics.add} is stored back
     * into the map instead of being discarded, so metrics of later pages are retained.
     */
    private void mergeQueryMetrics(FeedResponse<T> page) {
        Map<String, QueryMetrics> pageMetrics = BridgeInternal.queryMetricsFromFeedResponse(page);
        for (Map.Entry<String, QueryMetrics> entry : pageMetrics.entrySet()) {
            this.queryMetricsMap.merge(entry.getKey(),
                                       entry.getValue(),
                                       (existing, incoming) -> existing.add(incoming));
        }
    }

    /**
     * Builds the single response emitted by {@link #drainAsync(int)}: the given results, a
     * request-charge header, all merged query metrics and the collected diagnostics.
     */
    private FeedResponse<Document> createResponse(List<Document> results,
                                                  double requestCharge,
                                                  List<ClientSideRequestStatistics> diagnosticsList) {
        HashMap<String, String> headers = new HashMap<>();
        headers.put(HttpConstants.HttpHeaders.REQUEST_CHARGE, Double.toString(requestCharge));
        FeedResponse<Document> response = BridgeInternal.createFeedResponse(results, headers);
        for (Map.Entry<String, QueryMetrics> entry : this.queryMetricsMap.entrySet()) {
            BridgeInternal.putQueryMetricsIntoMap(response, entry.getKey(), entry.getValue());
        }
        BridgeInternal.addClientSideDiagnosticsToFeed(response.getCosmosDiagnostics(), diagnosticsList);
        return response;
    }

    /**
     * Creates the source component via {@code createSourceComponentFunction} and wraps each emitted
     * component in an {@link AggregateDocumentQueryExecutionContext}.
     *
     * @param createSourceComponentFunction factory for the source component, called with the
     *                                      continuation token and query params
     * @param aggregates aggregate operators; copied into a new list for the constructor
     * @param groupByAliasToAggregateType alias-to-operator mapping
     * @param groupByAliases ordered aliases
     * @param hasSelectValue whether the query uses {@code SELECT VALUE}
     * @param continuationToken continuation token, passed both to the factory and the constructor
     * @param documentQueryParams query params passed to the factory
     * @return a {@link Flux} of aggregate components
     */
    public static <T extends Resource> Flux<IDocumentQueryExecutionComponent<T>> createAsync(
        BiFunction<String, PipelinedDocumentQueryParams<T>, Flux<IDocumentQueryExecutionComponent<T>>> createSourceComponentFunction,
        Collection<AggregateOperator> aggregates,
        Map<String, AggregateOperator> groupByAliasToAggregateType,
        List<String> groupByAliases,
        boolean hasSelectValue,
        String continuationToken,
        PipelinedDocumentQueryParams<T> documentQueryParams) {

        return createSourceComponentFunction
            .apply(continuationToken, documentQueryParams)
            .map(component -> new AggregateDocumentQueryExecutionContext<T>(component,
                                                                             new ArrayList<>(aggregates),
                                                                             groupByAliasToAggregateType,
                                                                             groupByAliases,
                                                                             hasSelectValue,
                                                                             continuationToken));
    }

    /** Returns the source component this aggregate component drains. */
    public IDocumentQueryExecutionComponent<T> getComponent() {
        return this.component;
    }

    /**
     * Extracts the value to feed into the aggregator from one result document.
     *
     * <p>For {@code SELECT VALUE} queries the whole document is used; otherwise the document must
     * carry a {@value #PAYLOAD_PROPERTY_NAME} property whose object value is used.
     */
    class RewrittenAggregateProjections {
        private final Document payload;

        /**
         * @param isValueAggregateQuery whether the query uses {@code SELECT VALUE}
         * @param document one result document; must not be {@code null}
         * @throws IllegalArgumentException if {@code document} is {@code null}
         * @throws IllegalStateException if a non-VALUE result lacks the payload property
         */
        public RewrittenAggregateProjections(boolean isValueAggregateQuery, Document document) {
            if (document == null) {
                throw new IllegalArgumentException("document cannot be null");
            }

            if (isValueAggregateQuery) {
                this.payload = new Document(document.getPropertyBag());
            } else {
                if (!document.has(PAYLOAD_PROPERTY_NAME)) {
                    throw new IllegalStateException("Underlying object does not have an 'payload' field.");
                }
                Object rawPayload = document.get(PAYLOAD_PROPERTY_NAME);
                // NOTE(review): a payload that is not a JSON object leaves payload null, which is then
                // handed to the aggregator as-is; confirm the aggregator tolerates null values.
                this.payload = rawPayload instanceof ObjectNode ? new Document((ObjectNode) rawPayload) : null;
            }
        }

        /** Returns the extracted payload, or {@code null} when the payload property was not an object. */
        public Document getPayload() {
            return payload;
        }
    }
}