RntbdContextNegotiator.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.implementation.directconnectivity.rntbd;

import com.azure.cosmos.implementation.UserAgentContainer;
import com.azure.cosmos.implementation.Utils;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.CombinedChannelDuplexHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;

import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkArgument;
import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull;

/**
 * Negotiates the RNTBD context for a channel before application messages are written.
 * <p>
 * Combines a {@link RntbdContextDecoder} (inbound side) with a {@link RntbdContextRequestEncoder} (outbound side).
 * Until the {@link RntbdRequestManager} reports that it has an RNTBD context, outgoing {@link ByteBuf} messages are
 * handed to the manager as pending writes, and the first such write also sends a single {@link RntbdContextRequest}.
 */
public final class RntbdContextNegotiator extends CombinedChannelDuplexHandler<RntbdContextDecoder, RntbdContextRequestEncoder> {

    private static final Logger logger = LoggerFactory.getLogger(RntbdContextNegotiator.class);

    private final RntbdRequestManager manager;
    private final UserAgentContainer userAgent;

    // Cleared once the context request has been started, so later writes do not start another one.
    private volatile boolean pendingRntbdContextRequest = true;

    /**
     * Creates a negotiator bound to the given request manager.
     *
     * @param manager the {@link RntbdRequestManager} queried for the RNTBD context and used to hold pending writes
     * @param userAgent the {@link UserAgentContainer} used to build the context request
     * @throws NullPointerException if {@code manager} or {@code userAgent} is {@code null}
     */
    public RntbdContextNegotiator(final RntbdRequestManager manager, final UserAgentContainer userAgent) {

        super(new RntbdContextDecoder(), new RntbdContextRequestEncoder());

        checkNotNull(manager, "manager");
        checkNotNull(userAgent, "userAgent");

        this.manager = manager;
        this.userAgent = userAgent;
    }

    /**
     * Writes an outbound {@link ByteBuf}, deferring it while the RNTBD context is not yet available.
     * <p>
     * When the manager already has an RNTBD context, the buffer is written and flushed immediately. Otherwise the
     * buffer is handed to {@code manager.pendWrite}; the first deferred write also starts the context request.
     *
     * @param context the {@link ChannelHandlerContext} for which the write operation is made
     * @param message the message to write; must be a {@link ByteBuf}
     * @param promise the {@link ChannelPromise} to notify once the operation completes
     * @throws Exception if {@code message} is not a {@link ByteBuf} or starting the context request fails
     */
    @Override
    public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {

        checkArgument(message instanceof ByteBuf, "message: %s", message.getClass());
        final ByteBuf buffer = (ByteBuf) message;

        if (this.manager.hasRntbdContext()) {
            context.writeAndFlush(buffer, promise);
            return;
        }

        // Thread safe: netty guarantees that no channel handler methods are called concurrently
        if (this.pendingRntbdContextRequest) {
            // The flag is cleared only after the request was started, so a failure here leaves it set.
            this.startRntbdContextRequest(context);
            this.pendingRntbdContextRequest = false;
        }

        this.manager.pendWrite(buffer, promise);
    }

    // region Privates

    /**
     * Sends an {@link RntbdContextRequest} through the outbound encoder and mirrors the outcome of that write onto
     * the manager's context request future. On success the future is completed with the request. If the write is
     * cancelled, the future is cancelled. Otherwise the future is completed exceptionally with the write's cause.
     *
     * @param context the {@link ChannelHandlerContext} used to write the context request
     * @throws Exception if the underlying write throws
     */
    private void startRntbdContextRequest(final ChannelHandlerContext context) throws Exception {

        logger.debug("{} START CONTEXT REQUEST", context.channel());

        final RntbdContextRequest contextRequest = new RntbdContextRequest(Utils.randomUUID(), this.userAgent);
        final CompletableFuture<RntbdContextRequest> requestFuture = this.manager.rntbdContextRequestFuture();

        final ChannelFutureListener onWriteComplete = writeFuture -> {
            if (writeFuture.isSuccess()) {
                requestFuture.complete(contextRequest);
            } else if (writeFuture.isCancelled()) {
                requestFuture.cancel(true);
            } else {
                requestFuture.completeExceptionally(writeFuture.cause());
            }
        };

        final ChannelPromise writePromise = context.channel().newPromise();
        writePromise.addListener(onWriteComplete);

        super.write(context, contextRequest, writePromise);
    }

    // endregion
}