ContentDownloader.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.communication.callingserver;

import com.azure.communication.callingserver.implementation.Constants;
import com.azure.communication.callingserver.models.CallingServerErrorException;
import com.azure.communication.callingserver.models.ParallelDownloadOptions;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.HttpRange;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.rest.Response;
import com.azure.core.http.rest.SimpleResponse;
import com.azure.core.util.Context;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;
import reactor.core.scheduler.Schedulers;
import reactor.util.function.Tuple2;

import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

import static java.lang.StrictMath.toIntExact;

class ContentDownloader {
    private final String resourceEndpoint;
    private final HttpPipeline httpPipeline;
    private final ClientLogger logger = new ClientLogger(ContentDownloader.class);

    /**
     * Creates a downloader that sends its GET requests through the given pipeline.
     *
     * @param resourceEndpoint base endpoint that is combined with the path of each source endpoint to build the URL
     * stored in the request context under {@code "hmacSignatureURL"}.
     * @param httpPipeline pipeline used to send every download request.
     */
    ContentDownloader(String resourceEndpoint, HttpPipeline httpPipeline) {
        this.resourceEndpoint = resourceEndpoint;
        this.httpPipeline = httpPipeline;
    }

    /**
     * Downloads the content at {@code sourceEndpoint} and writes every received buffer to {@code destinationStream}.
     *
     * @param sourceEndpoint URL of the content to download.
     * @param destinationStream stream that receives the downloaded bytes; it is not closed by this method.
     * @param httpRange optional range to download; {@code null} downloads the whole content.
     * @param context optional request context; may be {@code null}.
     * @return a response carrying the request, status code and headers of the download response, with a
     * {@code null} value. An {@link IOException} thrown by the stream is surfaced as an
     * {@link UncheckedIOException} error signal.
     */
    Mono<Response<Void>> downloadToStreamWithResponse(
        String sourceEndpoint,
        OutputStream destinationStream,
        HttpRange httpRange,
        Context context) {
        return downloadStreamWithResponse(sourceEndpoint, httpRange, context)
            .flatMap(response -> response.getValue().reduce(destinationStream, (outputStream, buffer) -> {
                try {
                    outputStream.write(FluxUtil.byteBufferToArray(buffer));
                    return outputStream;
                } catch (IOException ex) {
                    throw logger.logExceptionAsError(Exceptions.propagate(new UncheckedIOException(ex)));
                }
            }).thenReturn(new SimpleResponse<>(response.getRequest(), response.getStatusCode(),
                response.getHeaders(), null)));
    }

    /**
     * Sends a download request and exposes the response body as a retriable {@link Flux}.
     *
     * @param sourceEndpoint URL of the content to download.
     * @param httpRange optional range to download; {@code null} downloads the whole content.
     * @param context optional request context; may be {@code null}.
     * @return a response whose value is the body; a failure while the body is streamed is resumed from the
     * first byte not yet received, up to {@code Constants.ContentDownloader.MAX_RETRIES} times.
     */
    Mono<Response<Flux<ByteBuffer>>> downloadStreamWithResponse(
        String sourceEndpoint,
        HttpRange httpRange,
        Context context) {
        Mono<HttpResponse> httpResponse = makeDownloadRequest(sourceEndpoint, httpRange, context);
        return httpResponse.map(response -> {
            Flux<ByteBuffer> result = getFluxStream(response, sourceEndpoint, httpRange, context);
            return new SimpleResponse<>(response.getRequest(), response.getStatusCode(),
                response.getHeaders(), result);
        });
    }

    /**
     * Downloads the content at {@code sourceEndpoint} into {@code destinationFile} in blocks of
     * {@code parallelDownloadOptions.getBlockSize()} bytes.
     *
     * <p>The first block is requested on its own so that the total length can be read from its
     * {@code Content-Range} header; the remaining blocks are subscribed to through {@link Flux#flatMap}, so they
     * may be downloaded concurrently, and each block is written at its own offset in the file.</p>
     *
     * @param sourceEndpoint URL of the content to download.
     * @param destinationFile channel the blocks are written to; it is not closed by this method.
     * @param parallelDownloadOptions block size and optional progress receiver.
     * @param context optional request context; may be {@code null}.
     * @return a response built from the first block's response, with a {@code null} value, emitted once every
     * block has been written.
     */
    Mono<Response<Void>> downloadToFileWithResponse(
        String sourceEndpoint,
        AsynchronousFileChannel destinationFile,
        ParallelDownloadOptions parallelDownloadOptions,
        Context context) {
        Lock progressLock = new ReentrantLock();
        AtomicLong totalProgress = new AtomicLong(0);

        Function<HttpRange, Mono<Response<Flux<ByteBuffer>>>> downloadFunc =
            range -> downloadStreamWithResponse(sourceEndpoint, range, context);

        return downloadFirstChunk(parallelDownloadOptions, downloadFunc)
            .flatMap(setupTuple2 -> {
                long newCount = setupTuple2.getT1();
                int numChunks = calculateNumBlocks(newCount, parallelDownloadOptions.getBlockSize());

                // In case it is an empty blob, this ensures we still actually perform a download operation.
                numChunks = numChunks == 0 ? 1 : numChunks;

                Response<Flux<ByteBuffer>> initialResponse = setupTuple2.getT2();
                return Flux.range(0, numChunks)
                    .flatMap(chunkNum -> downloadChunk(chunkNum, initialResponse,
                        parallelDownloadOptions, newCount, downloadFunc,
                        response ->
                            writeBodyToFile(response, destinationFile, chunkNum,
                                parallelDownloadOptions, progressLock, totalProgress).flux()))
                    .then(Mono.just(new SimpleResponse<>(initialResponse, null)));
            });
    }

    /**
     * Wraps the body of {@code httpResponse} in a flux that re-requests the missing bytes when streaming fails.
     *
     * @param httpResponse the initial response whose body is streamed first.
     * @param sourceEndpoint URL used for the resume requests.
     * @param httpRange range of the initial request, or {@code null} if the whole content was requested.
     * @param context optional request context; may be {@code null}.
     * @return the body, resumed on error at most {@code Constants.ContentDownloader.MAX_RETRIES} times.
     */
    private Flux<ByteBuffer> getFluxStream(
        HttpResponse httpResponse,
        String sourceEndpoint,
        HttpRange httpRange,
        Context context) {
        return FluxUtil.createRetriableDownloadFlux(
            () -> getResponseBody(httpResponse),
            // No start position is passed to createRetriableDownloadFlux, so the Long is the number of bytes
            // already emitted for this request (relative to httpRange's offset), not an absolute file offset.
            (Throwable throwable, Long bytesReceived) -> {
                if (throwable instanceof CallingServerErrorException) {
                    CallingServerErrorException exception = (CallingServerErrorException) throwable;
                    if (exception.getResponse().getStatusCode() == 416) {
                        // The server rejected the requested range; retry without a Range header.
                        // NOTE(review): if bytes were already emitted, this re-requests the content from the
                        // start and the already-emitted bytes are not discarded - confirm this is intended.
                        return requestBody(sourceEndpoint, null, context);
                    }
                }

                HttpRange resumeRange = computeResumeRange(httpRange, bytesReceived);
                if (resumeRange == null) {
                    // Every requested byte was already received; the error arrived before completion.
                    // Retrying would request an empty or invalid range, so finish successfully instead.
                    logger.warning("Error received after all requested content was downloaded; "
                        + "treating the download as complete. Error: {}", throwable.getMessage());
                    return Flux.empty();
                }

                return requestBody(sourceEndpoint, resumeRange, context);
            },
            Constants.ContentDownloader.MAX_RETRIES
        );
    }

    /**
     * Computes the range that covers the bytes of {@code originalRange} not yet received.
     *
     * @param originalRange range of the original request, or {@code null} if the whole content was requested.
     * @param bytesReceived number of bytes already received for the original request.
     * @return the range to request next, or {@code null} when {@code originalRange} has a length and no bytes of
     * it remain.
     */
    private static HttpRange computeResumeRange(HttpRange originalRange, long bytesReceived) {
        if (originalRange == null) {
            return new HttpRange(bytesReceived);
        }

        long resumeOffset = originalRange.getOffset() + bytesReceived;
        Long originalLength = originalRange.getLength();
        if (originalLength == null) {
            // Open-ended range ("bytes=N-"): keep it open-ended from the resume point.
            return new HttpRange(resumeOffset);
        }

        long remaining = originalLength - bytesReceived;
        return remaining > 0 ? new HttpRange(resumeOffset, remaining) : null;
    }

    /**
     * Sends a download request and maps its response to a body flux via {@link #getResponseBody(HttpResponse)}.
     *
     * @param sourceEndpoint URL of the content to download.
     * @param httpRange optional range; {@code null} sends no Range header.
     * @param context optional request context; may be {@code null}.
     * @return the response body, or an error signal for a non-2xx status.
     */
    private Flux<ByteBuffer> requestBody(String sourceEndpoint, HttpRange httpRange, Context context) {
        return makeDownloadRequest(sourceEndpoint, httpRange, context)
            .flatMapMany(this::getResponseBody);
    }

    /**
     * Returns the body of a successful (200 or 206) response.
     *
     * @param response the download response.
     * @return the body for 200/206; an error flux carrying a {@link CallingServerErrorException} for 416.
     * @throws CallingServerErrorException for any other status code.
     */
    private Flux<ByteBuffer> getResponseBody(HttpResponse response) {
        switch (response.getStatusCode()) {
            case 200:
            case 206:
                return response.getBody();
            case 416:   // Retriable with new HttpRange, potentially bytes=0-
                return FluxUtil.fluxError(logger,
                    new CallingServerErrorException(formatExceptionMessage(response), response));
            default:
                throw logger.logExceptionAsError(
                    new CallingServerErrorException(formatExceptionMessage(response), response)
                );
        }
    }

    /**
     * Builds the error message used for failed download responses.
     *
     * @param httpResponse the failed response.
     * @return a message containing the response status code.
     */
    private String formatExceptionMessage(HttpResponse httpResponse) {
        return String.format("Service Request failed!%nStatus: %s", httpResponse.getStatusCode());
    }

    /**
     * Sends a GET request for {@code sourceEndpoint}, storing the URL to sign under {@code "hmacSignatureURL"}.
     *
     * @param sourceEndpoint URL of the content to download.
     * @param httpRange optional range; {@code null} sends no Range header.
     * @param context optional request context; when {@code null} a new context is created.
     * @return the raw HTTP response.
     * @throws IllegalArgumentException (synchronously) if {@code sourceEndpoint} or the derived signing URL is
     * malformed.
     */
    private Mono<HttpResponse> makeDownloadRequest(
        String sourceEndpoint,
        HttpRange httpRange,
        Context context) {
        HttpRequest request = getHttpRequest(sourceEndpoint, httpRange);
        URL urlToSignWith = getUrlToSignRequestWith(sourceEndpoint);

        Context finalContext;
        if (context == null) {
            finalContext = new Context("hmacSignatureURL", urlToSignWith);
        } else {
            finalContext = context.addData("hmacSignatureURL", urlToSignWith);
        }

        return httpPipeline.send(request, finalContext);
    }

    /**
     * Builds the URL to sign by appending the path of {@code endpoint} (without its leading slash) to
     * {@code resourceEndpoint}.
     *
     * <p>NOTE(review): plain concatenation assumes {@code resourceEndpoint} ends with "/" - confirm with callers.</p>
     *
     * @param endpoint the request URL whose path is reused.
     * @return the URL to sign.
     * @throws IllegalArgumentException if either URL is malformed.
     */
    private URL getUrlToSignRequestWith(String endpoint) {
        try {
            String path = new URL(endpoint).getPath();

            if (path.startsWith("/")) {
                path = path.substring(1);
            }

            return new URL(resourceEndpoint + path);
        } catch (MalformedURLException ex) {
            throw logger.logExceptionAsError(new IllegalArgumentException(ex));
        }
    }

    /**
     * Creates a GET request for {@code sourceEndpoint}, adding a Range header when {@code httpRange} is non-null.
     *
     * @param sourceEndpoint request URL.
     * @param httpRange optional range.
     * @return the request.
     */
    private HttpRequest getHttpRequest(String sourceEndpoint, HttpRange httpRange) {
        HttpRequest request = new HttpRequest(HttpMethod.GET, sourceEndpoint);

        if (null != httpRange) {
            request.setHeader(Constants.HeaderNames.RANGE, httpRange.toString());
        }

        return request;
    }

    /**
     * Downloads the first block (offset 0, length {@code getBlockSize()}) on the bounded-elastic scheduler.
     *
     * @param parallelDownloadOptions supplies the block size.
     * @param downloader function that downloads a given range.
     * @return the total content length parsed from the Content-Range header (0 when absent), paired with the
     * first block's response.
     */
    private Mono<Tuple2<Long, Response<Flux<ByteBuffer>>>> downloadFirstChunk(
        ParallelDownloadOptions parallelDownloadOptions,
        Function<HttpRange, Mono<Response<Flux<ByteBuffer>>>> downloader) {
        return downloader.apply(new HttpRange(0, parallelDownloadOptions.getBlockSize()))
            .subscribeOn(Schedulers.boundedElastic())
            .flatMap(response -> {
                // Extract the total length of the blob from the contentRange header. e.g. "bytes 1-6/7"
                long totalLength = extractTotalBlobLength(
                    response.getHeaders().getValue(Constants.HeaderNames.CONTENT_RANGE)
                );

                return Mono.zip(Mono.just(totalLength), Mono.just(response));
            });
    }

    /**
     * Parses the total length from a Content-Range value such as {@code "bytes 0-6/7"}.
     *
     * @param contentRange the header value, or {@code null}.
     * @return the number after the "/", or 0 when the header is absent. An unknown length ("*") makes
     * {@link Long#parseLong(String)} throw {@link NumberFormatException}.
     */
    private long extractTotalBlobLength(String contentRange) {
        return contentRange == null ? 0 : Long.parseLong(contentRange.split("/")[1]);
    }

    /**
     * Returns the number of blocks of {@code blockLength} bytes needed to hold {@code dataSize} bytes
     * (ceiling division).
     *
     * @param dataSize total number of bytes.
     * @param blockLength bytes per block; must be positive.
     * @return the block count.
     * @throws ArithmeticException if the count does not fit in an {@code int}.
     */
    private int calculateNumBlocks(long dataSize, long blockLength) {
        // Can successfully cast to an int because MaxBlockSize is an int, which this expression must be less than.
        int numBlocks = toIntExact(dataSize / blockLength);
        // Include an extra block for trailing data.
        if (dataSize % blockLength != 0) {
            numBlocks++;
        }
        return numBlocks;
    }

    /**
     * Produces the result for one block: block 0 reuses {@code initialResponse}; any other block downloads
     * {@code [chunkNum * blockSize, chunkNum * blockSize + min(blockSize, remaining))}.
     *
     * @param chunkNum zero-based block index.
     * @param initialResponse response of the first block.
     * @param parallelDownloadOptions supplies the block size.
     * @param newCount total content length.
     * @param downloader function that downloads a given range.
     * @param returnTransformer maps a block response to the emitted values.
     * @param <T> element type produced by {@code returnTransformer}.
     * @return the transformed block result.
     */
    private <T> Flux<T> downloadChunk(
        Integer chunkNum,
        Response<Flux<ByteBuffer>> initialResponse,
        ParallelDownloadOptions parallelDownloadOptions,
        long newCount,
        Function<HttpRange, Mono<Response<Flux<ByteBuffer>>>> downloader,
        Function<Response<Flux<ByteBuffer>>, Flux<T>> returnTransformer) {
        if (chunkNum == 0) {
            return returnTransformer.apply(initialResponse);
        }

        // Calculate whether we need a full chunk or something smaller because we are at the end.
        long modifier = chunkNum.longValue() * parallelDownloadOptions.getBlockSize();
        long chunkSizeActual = Math.min(parallelDownloadOptions.getBlockSize(),
            newCount - modifier);
        HttpRange chunkRange = new HttpRange(modifier, chunkSizeActual);

        // Make the download call.
        return downloader.apply(chunkRange)
            .subscribeOn(Schedulers.boundedElastic())
            .flatMapMany(returnTransformer);
    }

    /**
     * Writes a block's body to {@code file} at offset {@code chunkNum * blockSize}, reporting progress through
     * the options' progress receiver.
     *
     * @param response block response whose value is the body.
     * @param file destination channel.
     * @param chunkNum zero-based block index.
     * @param parallelDownloadOptions supplies the block size and progress receiver.
     * @param progressLock lock shared by all blocks for progress reporting.
     * @param totalProgress progress counter shared by all blocks.
     * @return completes when the body has been written.
     */
    private static Mono<Void> writeBodyToFile(
        Response<Flux<ByteBuffer>> response,
        AsynchronousFileChannel file,
        long chunkNum,
        ParallelDownloadOptions parallelDownloadOptions,
        Lock progressLock,
        AtomicLong totalProgress) {
        // Extract the body.
        Flux<ByteBuffer> data = response.getValue();

        // Report progress as necessary.
        data = ProgressReporter.addParallelProgressReporting(data,
            parallelDownloadOptions.getProgressReceiver(), progressLock, totalProgress);

        // Write to the file.
        return FluxUtil.writeFile(data, file, chunkNum * parallelDownloadOptions.getBlockSize());
    }

    /**
     * Closes {@code channel} and, unless the download completed normally, deletes {@code filePath}.
     *
     * @param channel the channel used for the download.
     * @param filePath the destination file.
     * @param signalType the terminal signal of the download.
     * @throws UncheckedIOException if closing the channel or deleting the file fails.
     */
    void downloadToFileCleanup(AsynchronousFileChannel channel, Path filePath, SignalType signalType) {
        try {
            channel.close();
            if (!signalType.equals(SignalType.ON_COMPLETE)) {
                Files.deleteIfExists(filePath);
                logger.verbose("Downloading to file failed. Cleaning up resources.");
            }
        } catch (IOException e) {
            throw logger.logExceptionAsError(new UncheckedIOException(e));
        }
    }
}