DefaultCredentialsProvider.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.cloud.context.core.impl;

import com.azure.spring.cloud.context.core.api.CredentialsProvider;
import com.azure.spring.cloud.context.core.config.AzureManagedIdentityProperties;
import com.azure.spring.cloud.context.core.config.AzureProperties;
import com.microsoft.azure.credentials.AppServiceMSICredentials;
import com.microsoft.azure.credentials.ApplicationTokenCredentials;
import com.microsoft.azure.credentials.AzureTokenCredentials;
import com.microsoft.azure.credentials.MSICredentials;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.lang.NonNull;
import org.springframework.util.StringUtils;

/**
 * A {@link CredentialsProvider} implementation that provides credentials based on user-provided properties and
 * defaults.
 *
 * <p>Credentials are resolved once, in the constructor, and the same instance is returned by every call to
 * {@link #getCredentials()}. When managed identity (MSI) is enabled, MSI credentials are used; otherwise the
 * credentials are loaded from the file referenced by {@link AzureProperties#getCredentialFilePath()}.
 *
 * @author Warren Zhu
 */
public class DefaultCredentialsProvider implements CredentialsProvider {

    private static final Logger LOGGER = LoggerFactory.getLogger(DefaultCredentialsProvider.class);

    private static final String TEMP_CREDENTIAL_FILE_PREFIX = "azure";
    private static final String TEMP_CREDENTIAL_FILE_SUFFIX = "credential";
    private static final String ENV_MSI_ENDPOINT = "MSI_ENDPOINT";
    private static final String ENV_MSI_SECRET = "MSI_SECRET";

    private final AzureTokenCredentials credentials;

    /**
     * Creates a provider and eagerly resolves credentials from the given properties.
     *
     * @param azureProperties the Azure configuration to read MSI settings or the credential file path from
     * @throws IllegalArgumentException if MSI is disabled and the credential file path is blank or the file
     *     cannot be read
     */
    public DefaultCredentialsProvider(AzureProperties azureProperties) {
        this.credentials = initCredentials(azureProperties);
    }

    /**
     * Copies the given stream into a new temporary file that is scheduled for deletion on JVM exit.
     *
     * @param inputStream the credential content to copy
     * @return the temporary file holding the copied content
     * @throws IOException if the file cannot be created or written
     */
    private File createTempCredentialFile(@NonNull InputStream inputStream) throws IOException {
        // Files.createTempFile restricts the file to its owner on POSIX systems, unlike File.createTempFile,
        // which matters because this file holds service-principal secrets.
        File tempCredentialFile =
            Files.createTempFile(TEMP_CREDENTIAL_FILE_PREFIX, TEMP_CREDENTIAL_FILE_SUFFIX).toFile();

        // Fallback only: the file is normally deleted right after it has been parsed.
        tempCredentialFile.deleteOnExit();
        FileUtils.copyInputStreamToFile(inputStream, tempCredentialFile);

        return tempCredentialFile;
    }

    /**
     * Resolves the credentials described by the given properties.
     *
     * @param azureProperties the Azure configuration
     * @return MSI credentials (with the configured default subscription id) when MSI is enabled, otherwise
     *     credentials parsed from the configured credential file
     * @throws IllegalArgumentException if the credential file path is blank or the file cannot be read
     */
    private AzureTokenCredentials initCredentials(AzureProperties azureProperties) {
        if (azureProperties.isMsiEnabled()) {
            AzureTokenCredentials msiCredentials = getMSIToken(azureProperties);
            msiCredentials.withDefaultSubscriptionId(azureProperties.getSubscriptionId());
            return msiCredentials;
        }

        String credentialFilePath = azureProperties.getCredentialFilePath();
        if (!StringUtils.hasText(credentialFilePath)) {
            throw new IllegalArgumentException("Credential file path must not be empty");
        }

        File credentialFile = null;
        // The resource may live on the classpath or in a jar, so it is copied to a real file before parsing.
        // try-with-resources ensures the stream is closed even if creating the temp file fails.
        try (InputStream inputStream = new DefaultResourceLoader().getResource(credentialFilePath).getInputStream()) {
            credentialFile = createTempCredentialFile(inputStream);
            return ApplicationTokenCredentials.fromFile(credentialFile);
        } catch (IOException e) {
            LOGGER.error("Credential file path not found.", e);
            throw new IllegalArgumentException("Credential file path not found", e);
        } finally {
            // Do not leave the secret-bearing copy on disk for the lifetime of the JVM.
            if (credentialFile != null && !FileUtils.deleteQuietly(credentialFile)) {
                LOGGER.warn("Failed to delete temporary credential file {}", credentialFile);
            }
        }
    }

    /**
     * Returns whether both the {@code MSI_ENDPOINT} and {@code MSI_SECRET} environment variables are non-blank,
     * which is used as the signal for running in Azure App Service.
     */
    private boolean isAppService() {
        return StringUtils.hasText(System.getenv(ENV_MSI_ENDPOINT))
            && StringUtils.hasText(System.getenv(ENV_MSI_SECRET));
    }

    /**
     * Builds managed-identity credentials.
     *
     * <p>In App Service, {@link AppServiceMSICredentials} for the configured environment is used and only the
     * client id (if set) is applied. Elsewhere, {@link MSICredentials} is used with the client id and/or object id
     * applied when present.
     *
     * @param azureProperties the Azure configuration; its managed-identity section may be {@code null}
     * @return the MSI credentials
     */
    private AzureTokenCredentials getMSIToken(AzureProperties azureProperties) {
        AzureManagedIdentityProperties msiProps = azureProperties.getManagedIdentity();

        if (isAppService()) {
            AppServiceMSICredentials appServiceCredentials =
                new AppServiceMSICredentials(azureProperties.getEnvironment());

            if (msiProps != null && StringUtils.hasText(msiProps.getClientId())) {
                appServiceCredentials.withClientId(msiProps.getClientId());
            }

            return appServiceCredentials;
        }

        MSICredentials msiCredentials = new MSICredentials();

        if (msiProps != null) {
            if (StringUtils.hasText(msiProps.getClientId())) {
                msiCredentials.withClientId(msiProps.getClientId());
            }

            if (StringUtils.hasText(msiProps.getObjectId())) {
                msiCredentials.withObjectId(msiProps.getObjectId());
            }
        }

        return msiCredentials;
    }

    /**
     * Returns the credentials resolved at construction time.
     *
     * @return the cached credentials
     */
    @Override
    public AzureTokenCredentials getCredentials() {
        return this.credentials;
    }
}