FluxUtil.java

  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.

  3. package com.azure.core.implementation.util;

  4. import com.azure.core.http.rest.PagedFlux;
  5. import com.azure.core.http.rest.Response;
  6. import com.azure.core.util.Context;
  7. import com.azure.core.util.logging.ClientLogger;
  8. import org.reactivestreams.Subscriber;
  9. import org.reactivestreams.Subscription;
  10. import reactor.core.CoreSubscriber;
  11. import reactor.core.Exceptions;
  12. import reactor.core.publisher.Flux;
  13. import reactor.core.publisher.Mono;
  14. import reactor.core.publisher.Operators;

  15. import java.io.ByteArrayOutputStream;
  16. import java.io.IOException;
  17. import java.lang.reflect.Type;
  18. import java.nio.ByteBuffer;
  19. import java.nio.channels.AsynchronousFileChannel;
  20. import java.nio.channels.CompletionHandler;
  21. import java.util.Map;
  22. import java.util.Map.Entry;
  23. import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
  24. import java.util.concurrent.atomic.AtomicLongFieldUpdater;
  25. import java.util.function.Function;
  26. import java.util.stream.Collectors;

  27. /**
  28.  * Utility type exposing methods to deal with {@link Flux}.
  29.  */
  30. public final class FluxUtil {
  31.     /**
  32.      * Checks if a type is Flux<ByteBuffer>.
  33.      *
  34.      * @param entityType the type to check
  35.      * @return whether the type represents a Flux that emits ByteBuffer
  36.      */
  37.     public static boolean isFluxByteBuffer(Type entityType) {
  38.         if (TypeUtil.isTypeOrSubTypeOf(entityType, Flux.class)) {
  39.             final Type innerType = TypeUtil.getTypeArguments(entityType)[0];
  40.             if (TypeUtil.isTypeOrSubTypeOf(innerType, ByteBuffer.class)) {
  41.                 return true;
  42.             }
  43.         }
  44.         return false;
  45.     }

  46.     /**
  47.      * Collects ByteBuffer emitted by a Flux into a byte array.
  48.      *
  49.      * @param stream A stream which emits ByteBuffer instances.
  50.      * @return A Mono which emits the concatenation of all the ByteBuffer instances given by the source Flux.
  51.      */
  52.     public static Mono<byte[]> collectBytesInByteBufferStream(Flux<ByteBuffer> stream) {
  53.         return stream
  54.             .collect(ByteArrayOutputStream::new, FluxUtil::accept)
  55.             .map(ByteArrayOutputStream::toByteArray);
  56.     }

  57.     private static void accept(ByteArrayOutputStream byteOutputStream, ByteBuffer byteBuffer) {
  58.         try {
  59.             byteOutputStream.write(byteBufferToArray(byteBuffer));
  60.         } catch (IOException e) {
  61.             throw new RuntimeException("Error occurred writing ByteBuffer to ByteArrayOutputStream.", e);
  62.         }
  63.     }

  64.     /**
  65.      * Gets the content of the provided ByteBuffer as a byte array. This method will create a new byte array even if the
  66.      * ByteBuffer can have optionally backing array.
  67.      *
  68.      * @param byteBuffer the byte buffer
  69.      * @return the byte array
  70.      */
  71.     public static byte[] byteBufferToArray(ByteBuffer byteBuffer) {
  72.         int length = byteBuffer.remaining();
  73.         byte[] byteArray = new byte[length];
  74.         byteBuffer.get(byteArray);
  75.         return byteArray;
  76.     }

  77.     /**
  78.      * This method converts the incoming {@code subscriberContext} from {@link reactor.util.context.Context Reactor
  79.      * Context} to {@link Context Azure Context} and calls the given lambda function with this context and returns a
  80.      * single entity of type {@code T}
  81.      * <p>
  82.      * If the reactor context is empty, {@link Context#NONE} will be used to call the lambda function
  83.      * </p>
  84.      *
  85.      * <p><strong>Code samples</strong></p>
  86.      * {@codesnippet com.azure.core.implementation.util.fluxutil.withcontext}
  87.      *
  88.      * @param serviceCall The lambda function that makes the service call into which azure context will be passed
  89.      * @param <T> The type of response returned from the service call
  90.      * @return The response from service call
  91.      */
  92.     public static <T> Mono<T> withContext(Function<Context, Mono<T>> serviceCall) {
  93.         return Mono.subscriberContext()
  94.             .map(FluxUtil::toAzureContext)
  95.             .flatMap(serviceCall);
  96.     }

  97.     /**
  98.      * Converts the incoming content to Mono.
  99.      *
  100.      * @param response whose {@link Response#getValue() value} is to be converted
  101.      * @return The converted {@link Mono}
  102.      */
  103.     public static <T> Mono<T> toMono(Response<T> response) {
  104.         return Mono.justOrEmpty(response.getValue());
  105.     }

  106.     /**
  107.      * Propagates a {@link RuntimeException} through the error channel of {@link Mono}.
  108.      *
  109.      * @param logger The {@link ClientLogger} to log the exception.
  110.      * @param ex The {@link RuntimeException}.
  111.      * @param <T> The return type.
  112.      * @return A {@link Mono} that terminates with error wrapping the {@link RuntimeException}.
  113.      */
  114.     public static <T> Mono<T> monoError(ClientLogger logger, RuntimeException ex) {
  115.         return Mono.error(logger.logExceptionAsError(Exceptions.propagate(ex)));
  116.     }

  117.     /**
  118.      * Propagates a {@link RuntimeException} through the error channel of {@link Flux}.
  119.      *
  120.      * @param logger The {@link ClientLogger} to log the exception.
  121.      * @param ex The {@link RuntimeException}.
  122.      * @param <T> The return type.
  123.      * @return A {@link Flux} that terminates with error wrapping the {@link RuntimeException}.
  124.      */
  125.     public static <T> Flux<T> fluxError(ClientLogger logger, RuntimeException ex) {
  126.         return Flux.error(logger.logExceptionAsError(Exceptions.propagate(ex)));
  127.     }

  128.     /**
  129.      * Propagates a {@link RuntimeException} through the error channel of {@link PagedFlux}.
  130.      *
  131.      * @param logger The {@link ClientLogger} to log the exception.
  132.      * @param ex The {@link RuntimeException}.
  133.      * @param <T> The return type.
  134.      * @return A {@link PagedFlux} that terminates with error wrapping the {@link RuntimeException}.
  135.      */
  136.     public static <T> PagedFlux<T> pagedFluxError(ClientLogger logger, RuntimeException ex) {
  137.         return new PagedFlux<>(() -> monoError(logger, ex));
  138.     }

  139.     /**
  140.      * This method converts the incoming {@code subscriberContext} from {@link reactor.util.context.Context Reactor
  141.      * Context} to {@link Context Azure Context} and calls the given lambda function with this context and returns a
  142.      * collection of type {@code T}
  143.      * <p>
  144.      * If the reactor context is empty, {@link Context#NONE} will be used to call the lambda function
  145.      * </p>
  146.      *
  147.      * <p><strong>Code samples</strong></p>
  148.      * {@codesnippet com.azure.core.implementation.util.fluxutil.fluxcontext}
  149.      *
  150.      * @param serviceCall The lambda function that makes the service call into which the context will be passed
  151.      * @param <T> The type of response returned from the service call
  152.      * @return The response from service call
  153.      */
  154.     public static <T> Flux<T> fluxContext(Function<Context, Flux<T>> serviceCall) {
  155.         return Mono.subscriberContext()
  156.             .map(FluxUtil::toAzureContext)
  157.             .flatMapMany(serviceCall);
  158.     }

  159.     /**
  160.      * Converts a reactor context to azure context. If the reactor context is {@code null} or empty, {@link
  161.      * Context#NONE} will be returned.
  162.      *
  163.      * @param context The reactor context
  164.      * @return The azure context
  165.      */
  166.     private static Context toAzureContext(reactor.util.context.Context context) {
  167.         Map<Object, Object> keyValues = context.stream().collect(Collectors.toMap(Entry::getKey, Entry::getValue));
  168.         if (ImplUtils.isNullOrEmpty(keyValues)) {
  169.             return Context.NONE;
  170.         }
  171.         return Context.of(keyValues);
  172.     }

  173.     /**
  174.      * Writes the bytes emitted by a Flux to an AsynchronousFileChannel.
  175.      *
  176.      * @param content the Flux content
  177.      * @param outFile the file channel
  178.      * @return a Mono which performs the write operation when subscribed
  179.      */
  180.     public static Mono<Void> writeFile(Flux<ByteBuffer> content, AsynchronousFileChannel outFile) {
  181.         return writeFile(content, outFile, 0);
  182.     }

  183.     /**
  184.      * Writes the bytes emitted by a Flux to an AsynchronousFileChannel starting at the given position in the file.
  185.      *
  186.      * @param content the Flux content
  187.      * @param outFile the file channel
  188.      * @param position the position in the file to begin writing
  189.      * @return a Mono which performs the write operation when subscribed
  190.      */
  191.     public static Mono<Void> writeFile(Flux<ByteBuffer> content, AsynchronousFileChannel outFile, long position) {
  192.         return Mono.create(emitter -> content.subscribe(new Subscriber<ByteBuffer>() {
  193.             // volatile ensures that writes to these fields by one thread will be immediately visible to other threads.
  194.             // An I/O pool thread will write to isWriting and read isCompleted,
  195.             // while another thread may read isWriting and write to isCompleted.
  196.             volatile boolean isWriting = false;
  197.             volatile boolean isCompleted = false;
  198.             volatile Subscription subscription;
  199.             volatile long pos = position;

  200.             @Override
  201.             public void onSubscribe(Subscription s) {
  202.                 subscription = s;
  203.                 s.request(1);
  204.             }

  205.             @Override
  206.             public void onNext(ByteBuffer bytes) {
  207.                 isWriting = true;
  208.                 outFile.write(bytes, pos, null, onWriteCompleted);
  209.             }


  210.             CompletionHandler<Integer, Object> onWriteCompleted = new CompletionHandler<Integer, Object>() {
  211.                 @Override
  212.                 public void completed(Integer bytesWritten, Object attachment) {
  213.                     isWriting = false;
  214.                     if (isCompleted) {
  215.                         emitter.success();
  216.                     }
  217.                     //noinspection NonAtomicOperationOnVolatileField
  218.                     pos += bytesWritten;
  219.                     subscription.request(1);
  220.                 }

  221.                 @Override
  222.                 public void failed(Throwable exc, Object attachment) {
  223.                     subscription.cancel();
  224.                     emitter.error(exc);
  225.                 }
  226.             };

  227.             @Override
  228.             public void onError(Throwable throwable) {
  229.                 subscription.cancel();
  230.                 emitter.error(throwable);
  231.             }

  232.             @Override
  233.             public void onComplete() {
  234.                 isCompleted = true;
  235.                 if (!isWriting) {
  236.                     emitter.success();
  237.                 }
  238.             }
  239.         }));
  240.     }

  241.     /**
  242.      * Creates a {@link Flux} from an {@link AsynchronousFileChannel} which reads part of a file into chunks of the
  243.      * given size.
  244.      *
  245.      * @param fileChannel The file channel.
  246.      * @param chunkSize the size of file chunks to read.
  247.      * @param offset The offset in the file to begin reading.
  248.      * @param length The number of bytes to read from the file.
  249.      * @return the Flux.
  250.      */
  251.     public static Flux<ByteBuffer> readFile(AsynchronousFileChannel fileChannel, int chunkSize, long offset,
  252.                                             long length) {
  253.         return new FileReadFlux(fileChannel, chunkSize, offset, length);
  254.     }

  255.     /**
  256.      * Creates a {@link Flux} from an {@link AsynchronousFileChannel} which reads part of a file.
  257.      *
  258.      * @param fileChannel The file channel.
  259.      * @param offset The offset in the file to begin reading.
  260.      * @param length The number of bytes to read from the file.
  261.      * @return the Flux.
  262.      */
  263.     public static Flux<ByteBuffer> readFile(AsynchronousFileChannel fileChannel, long offset, long length) {
  264.         return readFile(fileChannel, DEFAULT_CHUNK_SIZE, offset, length);
  265.     }

  266.     /**
  267.      * Creates a {@link Flux} from an {@link AsynchronousFileChannel} which reads the entire file.
  268.      *
  269.      * @param fileChannel The file channel.
  270.      * @return The AsyncInputStream.
  271.      */
  272.     public static Flux<ByteBuffer> readFile(AsynchronousFileChannel fileChannel) {
  273.         try {
  274.             long size = fileChannel.size();
  275.             return readFile(fileChannel, DEFAULT_CHUNK_SIZE, 0, size);
  276.         } catch (IOException e) {
  277.             return Flux.error(new RuntimeException("Failed to read the file.", e));
  278.         }
  279.     }

  280.     private static final int DEFAULT_CHUNK_SIZE = 1024 * 64;

  281.     private static final class FileReadFlux extends Flux<ByteBuffer> {
  282.         private final AsynchronousFileChannel fileChannel;
  283.         private final int chunkSize;
  284.         private final long offset;
  285.         private final long length;

  286.         FileReadFlux(AsynchronousFileChannel fileChannel, int chunkSize, long offset, long length) {
  287.             this.fileChannel = fileChannel;
  288.             this.chunkSize = chunkSize;
  289.             this.offset = offset;
  290.             this.length = length;
  291.         }

  292.         @Override
  293.         public void subscribe(CoreSubscriber<? super ByteBuffer> actual) {
  294.             FileReadSubscription subscription =
  295.                 new FileReadSubscription(actual, fileChannel, chunkSize, offset, length);
  296.             actual.onSubscribe(subscription);
  297.         }

  298.         static final class FileReadSubscription implements Subscription, CompletionHandler<Integer, ByteBuffer> {
  299.             private static final int NOT_SET = -1;
  300.             private static final long serialVersionUID = -6831808726875304256L;
  301.             //
  302.             private final Subscriber<? super ByteBuffer> subscriber;
  303.             private volatile long position;
  304.             //
  305.             private final AsynchronousFileChannel fileChannel;
  306.             private final int chunkSize;
  307.             private final long offset;
  308.             private final long length;
  309.             //
  310.             private volatile boolean done;
  311.             private Throwable error;
  312.             private volatile ByteBuffer next;
  313.             private volatile boolean cancelled;
  314.             //
  315.             volatile int wip;
  316.             @SuppressWarnings("rawtypes")
  317.             static final AtomicIntegerFieldUpdater<FileReadSubscription> WIP =
  318.                 AtomicIntegerFieldUpdater.newUpdater(FileReadSubscription.class, "wip");
  319.             volatile long requested;
  320.             @SuppressWarnings("rawtypes")
  321.             static final AtomicLongFieldUpdater<FileReadSubscription> REQUESTED =
  322.                 AtomicLongFieldUpdater.newUpdater(FileReadSubscription.class, "requested");
  323.             //

  324.             FileReadSubscription(Subscriber<? super ByteBuffer> subscriber, AsynchronousFileChannel fileChannel,
  325.                                  int chunkSize, long offset, long length) {
  326.                 this.subscriber = subscriber;
  327.                 //
  328.                 this.fileChannel = fileChannel;
  329.                 this.chunkSize = chunkSize;
  330.                 this.offset = offset;
  331.                 this.length = length;
  332.                 //
  333.                 this.position = NOT_SET;
  334.             }

  335.             //region Subscription implementation

  336.             @Override
  337.             public void request(long n) {
  338.                 if (Operators.validate(n)) {
  339.                     Operators.addCap(REQUESTED, this, n);
  340.                     drain();
  341.                 }
  342.             }

  343.             @Override
  344.             public void cancel() {
  345.                 this.cancelled = true;
  346.             }

  347.             //endregion

  348.             //region CompletionHandler implementation

  349.             @Override
  350.             public void completed(Integer bytesRead, ByteBuffer buffer) {
  351.                 if (!cancelled) {
  352.                     if (bytesRead == -1) {
  353.                         done = true;
  354.                     } else {
  355.                         // use local variable to perform fewer volatile reads
  356.                         long pos = position;
  357.                         int bytesWanted = Math.min(bytesRead, maxRequired(pos));
  358.                         long position2 = pos + bytesWanted;
  359.                         //noinspection NonAtomicOperationOnVolatileField
  360.                         position = position2;
  361.                         buffer.position(bytesWanted);
  362.                         buffer.flip();
  363.                         next = buffer;
  364.                         if (position2 >= offset + length) {
  365.                             done = true;
  366.                         }
  367.                     }
  368.                     drain();
  369.                 }
  370.             }

  371.             @Override
  372.             public void failed(Throwable exc, ByteBuffer attachment) {
  373.                 if (!cancelled) {
  374.                     // must set error before setting done to true
  375.                     // so that is visible in drain loop
  376.                     error = exc;
  377.                     done = true;
  378.                     drain();
  379.                 }
  380.             }

  381.             //endregion

  382.             private void drain() {
  383.                 if (WIP.getAndIncrement(this) != 0) {
  384.                     return;
  385.                 }
  386.                 // on first drain (first request) we initiate the first read
  387.                 if (position == NOT_SET) {
  388.                     position = offset;
  389.                     doRead();
  390.                 }
  391.                 int missed = 1;
  392.                 for (;;) {
  393.                     if (cancelled) {
  394.                         return;
  395.                     }
  396.                     if (REQUESTED.get(this) > 0) {
  397.                         boolean emitted = false;
  398.                         // read d before next to avoid race
  399.                         boolean d = done;
  400.                         ByteBuffer bb = next;
  401.                         if (bb != null) {
  402.                             next = null;
  403.                             subscriber.onNext(bb);
  404.                             emitted = true;
  405.                         } else {
  406.                             emitted = false;
  407.                         }
  408.                         if (d) {
  409.                             if (error != null) {
  410.                                 subscriber.onError(error);
  411.                                 // exit without reducing wip so that further drains will be NOOP
  412.                                 return;
  413.                             } else {
  414.                                 subscriber.onComplete();
  415.                                 // exit without reducing wip so that further drains will be NOOP
  416.                                 return;
  417.                             }
  418.                         }
  419.                         if (emitted) {
  420.                             // do this after checking d to avoid calling read
  421.                             // when done
  422.                             Operators.produced(REQUESTED, this, 1);
  423.                             //
  424.                             doRead();
  425.                         }
  426.                     }
  427.                     missed = WIP.addAndGet(this, -missed);
  428.                     if (missed == 0) {
  429.                         return;
  430.                     }
  431.                 }
  432.             }

  433.             private void doRead() {
  434.                 // use local variable to limit volatile reads
  435.                 long pos = position;
  436.                 ByteBuffer innerBuf = ByteBuffer.allocate(Math.min(chunkSize, maxRequired(pos)));
  437.                 fileChannel.read(innerBuf, pos, innerBuf, this);
  438.             }

  439.             private int maxRequired(long pos) {
  440.                 long maxRequired = offset + length - pos;
  441.                 if (maxRequired <= 0) {
  442.                     return 0;
  443.                 } else {
  444.                     int m = (int) (maxRequired);
  445.                     // support really large files by checking for overflow
  446.                     if (m < 0) {
  447.                         return Integer.MAX_VALUE;
  448.                     } else {
  449.                         return m;
  450.                     }
  451.                 }
  452.             }
  453.         }
  454.     }


  455.     // Private Ctr
  456.     private FluxUtil() {
  457.     }
  458. }