CommunicationTokenCredential.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.communication.common;

import com.azure.core.util.FluxUtil;
import com.azure.core.util.logging.ClientLogger;

import reactor.core.publisher.Mono;

import com.azure.core.credential.AccessToken;

import java.io.IOException;
import java.time.OffsetDateTime;
import java.util.Date;
import java.util.Objects;
import java.util.Timer;
import java.util.TimerTask;
import java.util.function.Supplier;

import com.azure.communication.common.implementation.TokenParser;

/**
 * Provides a user credential for the Communication service.
 *
 * <p>The credential caches an {@link AccessToken} parsed from a serialized JWT. When created from
 * {@link CommunicationTokenRefreshOptions}, it fetches a fresh token on demand whenever the cached one is
 * missing or expired. It can optionally also refresh proactively on a background {@link Timer} ahead of
 * expiry. Call {@link #close()} to stop the background refresh; the timer thread is not a daemon thread.</p>
 */
public final class CommunicationTokenCredential implements AutoCloseable {
    private static final int DEFAULT_EXPIRING_OFFSET_MINUTES = 10;

    private final ClientLogger logger = new ClientLogger(CommunicationTokenCredential.class);

    private final TokenParser tokenParser = new TokenParser();
    // The fields below are read and written both by caller threads and by the proactive-refresh timer
    // thread (and by close()), so they are volatile to guarantee cross-thread visibility.
    private volatile AccessToken accessToken;
    private volatile Supplier<Mono<String>> refresher;
    private volatile FetchingTask fetchingTask;
    private volatile boolean isClosed = false;

    /**
     * Creates a credential from a serialized JWT token. The token is never refreshed.
     *
     * @param token serialized JWT token
     * @throws NullPointerException if {@code token} is null
     */
    public CommunicationTokenCredential(String token) {
        Objects.requireNonNull(token, "'token' cannot be null.");
        setToken(token);
    }

    /**
     * Creates a credential that obtains tokens from the refresher in {@code tokenRefreshOptions}.
     *
     * <p>If an initial token is supplied, it is cached immediately. If proactive refresh is also enabled,
     * the refresher is invoked on a background timer ten minutes before the cached token expires, and again
     * before each subsequently fetched token expires. Without an initial token, tokens are fetched only on
     * demand by {@link #getToken()}.</p>
     *
     * @param tokenRefreshOptions options supplying the token refresher and, optionally, an initial serialized
     * JWT token and whether to refresh proactively
     * @throws NullPointerException if {@code tokenRefreshOptions} or its token refresher is null
     */
    public CommunicationTokenCredential(CommunicationTokenRefreshOptions tokenRefreshOptions) {
        Objects.requireNonNull(tokenRefreshOptions, "'tokenRefreshOptions' cannot be null.");
        Supplier<Mono<String>> tokenRefresher = tokenRefreshOptions.getTokenRefresher();
        Objects.requireNonNull(tokenRefresher, "'tokenRefresher' cannot be null.");
        refresher = tokenRefresher;
        String initialToken = tokenRefreshOptions.getInitialToken();
        if (initialToken != null) {
            setToken(initialToken);
            if (tokenRefreshOptions.isRefreshProactively()) {
                // Publish the task BEFORE starting its timer. If the first run fires immediately (the token is
                // already inside the offset window), setToken() must see fetchingTask in order to reschedule.
                FetchingTask task = new FetchingTask(this, nextFetchTimeFor(accessToken));
                fetchingTask = task;
                task.startTimer();
            }
        }
    }

    /**
     * Gets the Azure core access token held by this credential.
     *
     * <p>If no token is cached, or the cached one has expired, and a refresher is available, the returned
     * {@link Mono} fetches a fresh token from the refresher and caches it. When proactive refresh is active,
     * it also reschedules the background refresh for the new token. Otherwise the cached token is emitted
     * as is.</p>
     *
     * @return asynchronous result emitting the access token, or an error if this credential has been closed
     * or the refresher's {@code get()} returned null
     */
    public Mono<AccessToken> getToken() {
        if (isClosed) {
            return FluxUtil.monoError(logger,
                new RuntimeException("getToken called on closed CommunicationTokenCredential object"));
        }
        if (needsRefresh()) {
            synchronized (this) {
                // Re-check under the lock. The lock only serializes invocations of the refresher's get();
                // the returned Mono is subscribed after the lock is released, so concurrent callers may
                // still each perform a fetch.
                if (needsRefresh()) {
                    // Go through setToken so an on-demand refresh also reschedules the proactive timer.
                    return fetchFreshToken().map(this::setToken);
                }
            }
        }
        return Mono.just(accessToken);
    }

    /**
     * Stops any proactive background refresh, releases the refresher, and makes subsequent
     * {@link #getToken()} calls fail.
     *
     * @throws IOException declared for API compatibility; never thrown by this implementation
     */
    @Override
    public void close() throws IOException {
        isClosed = true;
        FetchingTask task = fetchingTask;
        fetchingTask = null;
        if (task != null) {
            // cancel() (not just stopTimer()) so an in-flight timer callback cannot start a new timer.
            task.cancel();
        }
        refresher = null;
    }

    // For test verification usage only
    boolean hasProactiveFetcher() {
        return fetchingTask != null;
    }

    /** Returns true when the cached token is missing or expired and a refresher is available. */
    private boolean needsRefresh() {
        AccessToken current = accessToken;
        return (current == null || current.isExpired()) && refresher != null;
    }

    /**
     * Parses and caches {@code freshToken}. If proactive refresh is active, it is rescheduled for the new
     * token's expiry.
     *
     * @param freshToken serialized JWT token
     * @return the parsed token that was cached
     */
    private AccessToken setToken(String freshToken) {
        AccessToken parsed = tokenParser.parseJWTToken(freshToken);
        accessToken = parsed;

        // Read the field once: close() may null it concurrently from another thread.
        FetchingTask task = fetchingTask;
        if (task != null) {
            task.setNextFetchTime(nextFetchTimeFor(parsed));
        }
        return parsed;
    }

    /** Computes when a proactive refresh for {@code token} should fire: the offset before its expiry. */
    private static OffsetDateTime nextFetchTimeFor(AccessToken token) {
        return token.getExpiresAt().minusMinutes(DEFAULT_EXPIRING_OFFSET_MINUTES);
    }

    /**
     * Invokes the refresher.
     *
     * @return the refresher's Mono, or an error Mono if the refresher returned null or has already been
     * released by {@link #close()}
     */
    private Mono<String> fetchFreshToken() {
        // Snapshot the field: close() may null it between the caller's check and this read.
        Supplier<Mono<String>> currentRefresher = refresher;
        if (currentRefresher == null) {
            return FluxUtil.monoError(logger,
                new RuntimeException("Cannot fetch a fresh token: CommunicationTokenCredential has been closed."));
        }
        Mono<String> tokenAsync = currentRefresher.get();
        if (tokenAsync == null) {
            return FluxUtil.monoError(logger,
                new RuntimeException("get() function of the token refresher should not return null."));
        }
        return tokenAsync;
    }

    /**
     * Schedules proactive token refreshes for a host credential. Each (re)schedule replaces the current
     * {@link Timer} with a new one that fires once at {@code nextFetchTime}. All state is guarded by this
     * object's monitor.
     */
    private static final class FetchingTask {
        private final CommunicationTokenCredential host;
        private Timer expiringTimer;
        private OffsetDateTime nextFetchTime;
        private boolean cancelled;

        /** Creates the task without starting it; the owner calls {@link #startTimer()} once it is published. */
        FetchingTask(CommunicationTokenCredential tokenHost, OffsetDateTime nextFetchAt) {
            host = tokenHost;
            nextFetchTime = nextFetchAt;
        }

        /** Replaces the pending timer with one firing at {@code newFetchTime}; no-op once cancelled. */
        private synchronized void setNextFetchTime(OffsetDateTime newFetchTime) {
            nextFetchTime = newFetchTime;
            stopTimer();
            startTimer();
        }

        /**
         * Starts a new timer that fires once at {@code nextFetchTime}. A time in the past fires as soon as
         * possible. This is a no-op after {@link #cancel()}.
         */
        private synchronized void startTimer() {
            if (cancelled) {
                return;
            }
            expiringTimer = new Timer();
            Date expiring = Date.from(nextFetchTime.toInstant());
            expiringTimer.schedule(new TokenExpiringTask(this), expiring);
        }

        /** Cancels and discards the current timer, if any. */
        private synchronized void stopTimer() {
            if (expiringTimer == null) {
                return;
            }

            expiringTimer.cancel();
            expiringTimer.purge();
            expiringTimer = null;
        }

        /** Permanently stops this task: cancels the pending timer and prevents any future rescheduling. */
        private synchronized void cancel() {
            cancelled = true;
            stopTimer();
        }

        private synchronized boolean isCancelled() {
            return cancelled;
        }

        /**
         * Blocks on a fresh token from the host and caches it. The host's setToken reschedules this task.
         * Skipped if the task has already been cancelled.
         */
        private void refresh() {
            if (isCancelled()) {
                return;
            }
            String freshToken = host.fetchFreshToken().block();
            host.setToken(freshToken);
        }

        /** Timer callback performing one proactive refresh; failures are logged, not propagated. */
        private static final class TokenExpiringTask extends TimerTask {
            private final ClientLogger logger = new ClientLogger(TokenExpiringTask.class);
            private final FetchingTask tokenCache;

            TokenExpiringTask(FetchingTask host) {
                tokenCache = host;
            }

            @Override
            public void run() {
                try {
                    tokenCache.refresh();
                } catch (Exception exception) {
                    logger.logExceptionAsError(new RuntimeException(exception));
                }
            }
        }
    }
}