RntbdServiceEndpoint.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.implementation.directconnectivity.rntbd;

import com.azure.cosmos.BridgeInternal;
import com.azure.cosmos.CosmosException;
import com.azure.cosmos.implementation.GoneException;
import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.directconnectivity.IAddressResolver;
import com.azure.cosmos.implementation.directconnectivity.RntbdTransportClient;
import com.azure.cosmos.implementation.directconnectivity.TransportException;
import com.azure.cosmos.implementation.guava25.collect.ImmutableMap;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import io.micrometer.core.instrument.Tag;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.AdaptiveRecvByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollChannelOption;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.ssl.SslContext;
import io.netty.util.concurrent.DefaultEventExecutor;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.SocketAddress;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;

import static com.azure.cosmos.implementation.HttpConstants.HttpHeaders;
import static com.azure.cosmos.implementation.directconnectivity.RntbdTransportClient.Options;
import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull;
import static com.azure.cosmos.implementation.guava27.Strings.lenientFormat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

@JsonSerialize(using = RntbdServiceEndpoint.JsonSerializer.class)
public final class RntbdServiceEndpoint implements RntbdEndpoint {

    // region Fields

    private static final String TAG_NAME = RntbdServiceEndpoint.class.getSimpleName();
    private static final long QUIET_PERIOD = 2_000_000_000L; // 2 seconds

    private static final AtomicLong instanceCount = new AtomicLong();
    private static final Logger logger = LoggerFactory.getLogger(RntbdServiceEndpoint.class);
    private static final AdaptiveRecvByteBufAllocator receiveBufferAllocator = new AdaptiveRecvByteBufAllocator();

    private final RntbdClientChannelPool channelPool;
    private final AtomicBoolean closed;
    private final AtomicInteger concurrentRequests;
    private final long id;
    private final AtomicLong lastRequestNanoTime;
    private final AtomicLong lastSuccessfulRequestNanoTime;

    private final Instant createdTime;
    private final RntbdMetrics metrics;
    private final Provider provider;
    private final URI serverKey;
    private final SocketAddress remoteAddress;
    private final RntbdRequestTimer requestTimer;
    private final Tag tag;
    private final int maxConcurrentRequests;
    private final boolean channelAcquisitionContextEnabled;

    private final RntbdConnectionStateListener connectionStateListener;

    // endregion

    // region Constructors

    private RntbdServiceEndpoint(
        final Provider provider,
        final Config config,
        final EventLoopGroup group,
        final RntbdRequestTimer timer,
        final URI physicalAddress) {

        this.serverKey = RntbdUtils.getServerKey(physicalAddress);

        final Bootstrap bootstrap = this.getBootStrap(group, config);

        this.createdTime = Instant.now();
        this.channelPool = new RntbdClientChannelPool(this, bootstrap, config);
        this.remoteAddress = bootstrap.config().remoteAddress();
        this.concurrentRequests = new AtomicInteger();
        // if no request has been sent over this endpoint we want to make sure we don't trigger a connection close
        // due to elapsedTimeInNanos being negative.
        // if no request has been sent initially over this endpoint, the below calculation can result in a very big difference
        // long elapsedTimeInNanos = System.nanoTime() - endpoint.lastRequestNanoTime()
        // which can cause endpoint to close unnecessary.
        this.lastRequestNanoTime = new AtomicLong(System.nanoTime());
        this.lastSuccessfulRequestNanoTime = new AtomicLong(System.nanoTime());
        this.closed = new AtomicBoolean();
        this.requestTimer = timer;

        this.tag = Tag.of(TAG_NAME, RntbdMetrics.escape(this.remoteAddress.toString()));
        this.id = instanceCount.incrementAndGet();
        this.provider = provider;

        this.metrics = new RntbdMetrics(provider.transportClient, this);
        this.maxConcurrentRequests = config.maxConcurrentRequestsPerEndpoint();

        this.connectionStateListener = this.provider.addressResolver != null && config.isConnectionEndpointRediscoveryEnabled()
            ? new RntbdConnectionStateListener(this.provider.addressResolver, this)
            : null;

        this.channelAcquisitionContextEnabled = config.isChannelAcquisitionContextEnabled();
    }

    private Bootstrap getBootStrap(EventLoopGroup eventLoopGroup, Config config) {
        checkNotNull(eventLoopGroup, "expected non-null eventLoopGroup");
        checkNotNull(config, "expected non-null config");

        RntbdLoop rntbdLoop = RntbdLoopNativeDetector.getRntbdLoop(config.preferTcpNative());
        Bootstrap bootstrap = new Bootstrap()
            .group(eventLoopGroup)
            .channel(rntbdLoop.getChannelClass())
            .option(ChannelOption.ALLOCATOR, config.allocator())
            .option(ChannelOption.AUTO_READ, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.connectTimeoutInMillis())
            .option(ChannelOption.RCVBUF_ALLOCATOR, receiveBufferAllocator)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .remoteAddress(this.serverKey.getHost(), this.serverKey.getPort());

        if (rntbdLoop instanceof RntbdLoopEpoll) {
            // Override the default Linux os config for tcp keep-alive, so a broken connection can be detected faster.
            // see man 7 tcp for more details.
            bootstrap
                .option(EpollChannelOption.TCP_KEEPINTVL, config.tcpKeepIntvl()) // default value 75s
                .option(EpollChannelOption.TCP_KEEPIDLE, config.tcpKeepIdle()); // default value 2hrs
        }

        return bootstrap;
    }

    // endregion

    // region Accessors

    /**
     * @return approximate number of acquired channels.
     */
    @Override
    public int channelsAcquiredMetric() {
        return this.channelPool.channelsAcquiredMetrics();
    }

    /**
     * @return approximate number of available channels.
     */
    @Override
    public int channelsAvailableMetric() {
        return this.channelPool.channelsAvailableMetrics();
    }

    @Override
    public int concurrentRequests() {
        return this.concurrentRequests.get();
    }

    @Override
    public int gettingEstablishedConnectionsMetrics() {
        return this.channelPool.attemptingToConnectMetrics();
    }

    @Override
    public long id() {
        return this.id;
    }

    @Override
    public boolean isClosed() {
        return this.closed.get();
    }

    @Override
    public int maxChannels() {
        return this.channelPool.channels(true);
    }

    public long lastRequestNanoTime() {
        return this.lastRequestNanoTime.get();
    }

    @Override
    public long lastSuccessfulRequestNanoTime() {
        return this.lastSuccessfulRequestNanoTime.get();
    }

    @Override
    public int channelsMetrics() {
        return this.channelPool.channels(true);
    }

    @Override
    public int executorTaskQueueMetrics() {
        return this.channelPool.executorTaskQueueMetrics();
    }

    public Instant getCreatedTime() {
        return this.createdTime;
    }

    @Override
    public SocketAddress remoteAddress() {
        return this.remoteAddress;
    }

    @Override
    public URI serverKey() { return this.serverKey; }

    @Override
    public int requestQueueLength() {
        return this.channelPool.requestQueueLength();
    }

    @Override
    public Tag tag() {
        return this.tag;
    }

    @Override
    public long usedDirectMemory() {
        return this.channelPool.usedDirectMemory();
    }

    @Override
    public long usedHeapMemory() {
        return this.channelPool.usedHeapMemory();
    }

    // endregion

    // region Methods

    @Override
    public void close() {
        if (this.closed.compareAndSet(false, true)) {
            this.provider.evict(this);
            this.channelPool.close();
        }
    }

    public RntbdRequestRecord request(final RntbdRequestArgs args) {

        this.throwIfClosed();

        int concurrentRequestSnapshot = this.concurrentRequests.incrementAndGet();

        RntbdEndpointStatistics stat = endpointMetricsSnapshot(concurrentRequestSnapshot);

        if (concurrentRequestSnapshot > this.maxConcurrentRequests) {
            try {
                FailFastRntbdRequestRecord requestRecord = FailFastRntbdRequestRecord.createAndFailFast(
                    args,
                    concurrentRequestSnapshot,
                    metrics,
                    remoteAddress);
                requestRecord.serviceEndpointStatistics(stat);
                return requestRecord;
            }
            finally {
                concurrentRequests.decrementAndGet();
            }
        }

        this.lastRequestNanoTime.set(args.nanoTimeCreated());

        final RntbdRequestRecord record = this.write(args);
        record.serviceEndpointStatistics(stat);

        record.whenComplete((response, error) -> {
            this.concurrentRequests.decrementAndGet();
            this.metrics.markComplete(record);
            onResponse(error, record);
        });

        return record;
    }

    private void onResponse(Throwable exception, RntbdRequestRecord record) {
        if (exception == null) {
            this.lastSuccessfulRequestNanoTime.set(System.nanoTime());
            return;
        }

        if (this.connectionStateListener != null) {
            this.connectionStateListener.onException(record.args().serviceRequest(), exception);
        }

        // exception != null
        if (exception instanceof CosmosException) {
            CosmosException cosmosException = (CosmosException) exception;
            switch (cosmosException.getStatusCode()) {
                // non 200 status codes representing business logic success
                case HttpConstants.StatusCodes.CONFLICT:
                case HttpConstants.StatusCodes.NOTFOUND:
                    this.lastSuccessfulRequestNanoTime.set(System.nanoTime());
                    return;
                default:
                    return;
            }
        }
    }

    private RntbdEndpointStatistics endpointMetricsSnapshot(int concurrentRequestSnapshot) {
        RntbdEndpointStatistics stats = new RntbdEndpointStatistics()
            .availableChannels(this.channelsAvailableMetric())
            .acquiredChannels(this.channelsAcquiredMetric())
            .executorTaskQueueSize(this.executorTaskQueueMetrics())
            .lastSuccessfulRequestNanoTime(this.lastSuccessfulRequestNanoTime())
            .createdTime(this.createdTime)
            .lastRequestNanoTime(this.lastRequestNanoTime())
            .closed(this.closed.get())
            .inflightRequests(concurrentRequestSnapshot);
        return stats;
    }

    @Override
    public String toString() {
        return RntbdObjectMapper.toString(this);
    }

    // endregion

    // region Privates

    private void ensureSuccessWhenReleasedToPool(Channel channel, Future<Void> released) {
        if (released.isSuccess()) {
            logger.debug("\n  [{}]\n  {}\n  release succeeded", this, channel);
        } else {
            logger.debug("\n  [{}]\n  {}\n  release failed due to {}", this, channel, released.cause());
        }
    }

    private void releaseToPool(final Channel channel) {

        logger.debug("\n  [{}]\n  {}\n  RELEASE", this, channel);
        final Future<Void> released = this.channelPool.release(channel);

        if (logger.isDebugEnabled()) {
            if (released.isDone()) {
                ensureSuccessWhenReleasedToPool(channel, released);
            } else {
                released.addListener(ignored -> ensureSuccessWhenReleasedToPool(channel, released));
            }
        }
    }

    private void throwIfClosed() {
        if (this.closed.get()) {
            throw new TransportException(lenientFormat("%s is closed", this), new IllegalStateException());
        }
    }

    private RntbdRequestRecord write(final RntbdRequestArgs requestArgs) {

        final RntbdRequestRecord requestRecord = new AsyncRntbdRequestRecord(requestArgs, this.requestTimer);
        requestRecord.channelAcquisitionContextEnabled(this.channelAcquisitionContextEnabled);
        requestRecord.stage(RntbdRequestRecord.Stage.CHANNEL_ACQUISITION_STARTED);
        final Future<Channel> connectedChannel = this.channelPool.acquire(requestRecord.getChannelAcquisitionTimeline());

        logger.debug("\n  [{}]\n  {}\n  WRITE WHEN CONNECTED {}", this, requestArgs, connectedChannel);

        if (connectedChannel.isDone()) {
            return writeWhenConnected(requestRecord, connectedChannel);
        } else {
            connectedChannel.addListener(ignored -> writeWhenConnected(requestRecord, connectedChannel));
        }

        return requestRecord;
    }

    private RntbdRequestRecord writeWhenConnected(
        final RntbdRequestRecord requestRecord, final Future<? super Channel> connected) {

        if (connected.isSuccess()) {
            final Channel channel = (Channel) connected.getNow();
            assert channel != null : "impossible";
            this.releaseToPool(channel);
            requestRecord.channelTaskQueueLength(RntbdUtils.tryGetExecutorTaskQueueSize(channel.eventLoop()));
            channel.write(requestRecord.stage(RntbdRequestRecord.Stage.PIPELINED));
            return requestRecord;
        }

        final RntbdRequestArgs requestArgs = requestRecord.args();
        final UUID activityId = requestArgs.activityId();
        final Throwable cause = connected.cause();

        if (connected.isCancelled()) {

            logger.debug("\n  [{}]\n  {}\n  write cancelled: {}", this, requestArgs, cause);
            requestRecord.cancel(true);

        } else {

            logger.debug("\n  [{}]\n  {}\n  write failed due to {} ", this, requestArgs, cause);
            final String reason = cause.toString();

            final GoneException goneException = new GoneException(
                lenientFormat("failed to establish connection to %s due to %s", this.remoteAddress, reason),
                cause instanceof Exception ? (Exception) cause : new IOException(reason, cause),
                ImmutableMap.of(HttpHeaders.ACTIVITY_ID, activityId.toString()),
                requestArgs.replicaPath()
            );

            BridgeInternal.setRequestHeaders(goneException, requestArgs.serviceRequest().getHeaders());
            requestRecord.completeExceptionally(goneException);
        }

        return requestRecord;
    }

    // endregion

    // region Types

    static final class JsonSerializer extends StdSerializer<RntbdServiceEndpoint> {

        private static final long serialVersionUID = -5764954918168771152L;

        public JsonSerializer() {
            super(RntbdServiceEndpoint.class);
        }

        @Override
        public void serialize(RntbdServiceEndpoint value, JsonGenerator generator, SerializerProvider provider)
            throws IOException {

            final RntbdTransportClient transportClient = value.provider.transportClient;

            generator.writeStartObject();
            generator.writeNumberField("id", value.id);
            generator.writeBooleanField("closed", value.isClosed());
            generator.writeNumberField("concurrentRequests", value.concurrentRequests());
            generator.writeStringField("remoteAddress", value.remoteAddress.toString());
            generator.writeObjectField("channelPool", value.channelPool);
            generator.writeObjectFieldStart("transportClient");
            generator.writeNumberField("id", transportClient.id());
            generator.writeBooleanField("closed", transportClient.isClosed());
            generator.writeNumberField("endpointCount", transportClient.endpointCount());
            generator.writeNumberField("endpointEvictionCount", transportClient.endpointEvictionCount());
            generator.writeEndObject();
            generator.writeEndObject();
        }
    }

    public static final class Provider implements RntbdEndpoint.Provider {

        private static final Logger logger = LoggerFactory.getLogger(Provider.class);

        private final AtomicBoolean closed;
        private final Config config;
        private final ConcurrentHashMap<String, RntbdEndpoint> endpoints;
        private final EventLoopGroup eventLoopGroup;
        private final AtomicInteger evictions;
        private final RntbdEndpointMonitoringProvider monitoring;
        private final RntbdRequestTimer requestTimer;
        private final RntbdTransportClient transportClient;
        private final IAddressResolver addressResolver;

        public Provider(
            final RntbdTransportClient transportClient,
            final Options options,
            final SslContext sslContext,
            final IAddressResolver addressResolver) {

            checkNotNull(transportClient, "expected non-null provider");
            checkNotNull(options, "expected non-null options");
            checkNotNull(sslContext, "expected non-null sslContext");

            final LogLevel wireLogLevel;

            if (logger.isDebugEnabled()) {
                wireLogLevel = LogLevel.TRACE;
            } else {
                wireLogLevel = null;
            }

            this.addressResolver = addressResolver;
            this.transportClient = transportClient;
            this.config = new Config(options, sslContext, wireLogLevel);

            this.requestTimer = new RntbdRequestTimer(
                config.tcpNetworkRequestTimeoutInNanos(),
                config.requestTimerResolutionInNanos());

            this.eventLoopGroup = this.getEventLoopGroup(options);
            this.endpoints = new ConcurrentHashMap<>();
            this.evictions = new AtomicInteger();
            this.closed = new AtomicBoolean();
            this.monitoring = new RntbdEndpointMonitoringProvider(this);
            this.monitoring.init();
        }

        private EventLoopGroup getEventLoopGroup(Options options) {
            checkNotNull(options, "expected non-null options");

            RntbdLoop rntbdEventLoop = RntbdLoopNativeDetector.getRntbdLoop(options.preferTcpNative());
            DefaultThreadFactory threadFactory =
                new DefaultThreadFactory(
                    "cosmos-rntbd-" + rntbdEventLoop.getName(),
                    true,
                    options.ioThreadPriority());

            return rntbdEventLoop.newEventLoopGroup(options.threadCount(), threadFactory);
        }

        @Override
        public void close() {

            if (this.closed.compareAndSet(false, true)) {
                this.monitoring.close();

                for (final RntbdEndpoint endpoint : this.endpoints.values()) {
                    endpoint.close();
                }

                this.eventLoopGroup.shutdownGracefully(QUIET_PERIOD, this.config.shutdownTimeoutInNanos(), NANOSECONDS)
                    .addListener(future -> {

                        this.requestTimer.close();

                        if (future.isSuccess()) {
                            logger.debug("\n  [{}]\n  closed endpoints", this);
                            return;
                        }

                        logger.error("\n  [{}]\n  failed to close endpoints due to ", this, future.cause());
                    });

                return;
            }

            logger.debug("\n  [{}]\n  already closed", this);
        }

        @Override
        public Config config() {
            return this.config;
        }

        @Override
        public int count() {
            return this.endpoints.size();
        }

        @Override
        public int evictions() {
            return this.evictions.get();
        }

        @Override
        public RntbdEndpoint get(final URI physicalAddress) {
            return endpoints.computeIfAbsent(physicalAddress.getAuthority(), authority -> new RntbdServiceEndpoint(
                this,
                this.config,
                this.eventLoopGroup,
                this.requestTimer,
                physicalAddress));
        }

        @Override
        public IAddressResolver getAddressResolver() {
            return this.addressResolver;
        }

        @Override
        public Stream<RntbdEndpoint> list() {
            return this.endpoints.values().stream();
        }

        private void evict(final RntbdEndpoint endpoint) {
            if (this.endpoints.remove(endpoint.serverKey().getAuthority()) != null) {
                this.evictions.incrementAndGet();
            }
        }
    }

    public static class RntbdEndpointMonitoringProvider implements AutoCloseable {
        private final Logger logger = LoggerFactory.getLogger(RntbdEndpointMonitoringProvider.class);
        // this is only for debugging monitoring of the health of RNTBD
        // TODO: once we are certain no task gets stuck in the rntbd queue remove this
        private static final EventExecutor monitoringRntbdChannelPool = new DefaultEventExecutor(new RntbdThreadFactory(
            "monitoring-rntbd-endpoints",
            true,
            Thread.MIN_PRIORITY));
        private static final Duration MONITORING_PERIOD = Duration.ofSeconds(60);
        private final Provider provider;
        private final static int MAX_TASK_LIMIT = 5_000;

        private ScheduledFuture<?> future;

        RntbdEndpointMonitoringProvider(Provider provider) {
            this.provider = provider;
        }

        synchronized void init() {
            logger.info("Starting RntbdClientChannelPoolMonitoringProvider ...");
            this.future = RntbdEndpointMonitoringProvider.monitoringRntbdChannelPool.scheduleAtFixedRate(() -> {
                logAllPools();
            }, 0, MONITORING_PERIOD.toMillis(), TimeUnit.MILLISECONDS);
        }

        @Override
        public synchronized void close() {
            logger.info("Shutting down RntbdClientChannelPoolMonitoringProvider ...");
            this.future.cancel(false);
            this.future = null;
        }

        synchronized void logAllPools() {
            try {
                logger.debug("Total number of RntbdClientChannelPool [{}].", provider.endpoints.size());
                for (RntbdEndpoint endpoint : provider.endpoints.values()) {
                    logEndpoint(endpoint);
                }
            } catch (Exception e) {
                logger.error("monitoring unexpected failure", e);
            }
        }

        private void logEndpoint(RntbdEndpoint endpoint) {
            if (this.logger.isWarnEnabled() &&
                (endpoint.executorTaskQueueMetrics() > MAX_TASK_LIMIT ||
                endpoint.requestQueueLength() > MAX_TASK_LIMIT ||
                endpoint.gettingEstablishedConnectionsMetrics() > 0 ||
                endpoint.channelsMetrics() > endpoint.maxChannels())) {
                logger.warn("RntbdEndpoint Identifier {}, Stat {}", getPoolId(endpoint), getPoolStat(endpoint));
            } else if (this.logger.isDebugEnabled()) {
                logger.debug("RntbdEndpoint Identifier {}, Stat {}", getPoolId(endpoint), getPoolStat(endpoint));
            }
        }

        private String getPoolStat(RntbdEndpoint endpoint) {
            return "[ "
                + "poolTaskExecutorSize " + endpoint.executorTaskQueueMetrics()
                + ", lastRequestNanoTime " + Instant.now().minusNanos(
                System.nanoTime() - endpoint.lastRequestNanoTime())
                + ", connecting " + endpoint.gettingEstablishedConnectionsMetrics()
                + ", acquiredChannel " + endpoint.channelsAcquiredMetric()
                + ", availableChannel " + endpoint.channelsAvailableMetric()
                + ", pendingAcquisitionSize " + endpoint.requestQueueLength()
                + ", closed " + endpoint.isClosed()
                + " ]";
        }

        private String getPoolId(RntbdEndpoint endpoint) {
            if (endpoint == null) {
                return "null";
            }

            return "[RntbdEndpoint" +
                ", id " + endpoint.id() +
                ", remoteAddress " + endpoint.remoteAddress() +
                ", creationTime " + endpoint.getCreatedTime() +
                ", hashCode " + endpoint.hashCode() +
                "]";
        }
    }

    // endregion
}