KeyVaultEnvironmentPostProcessorHelper.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.keyvault;

import com.azure.core.credential.TokenCredential;
import com.azure.core.http.policy.HttpLogOptions;
import com.azure.identity.ClientCertificateCredentialBuilder;
import com.azure.identity.ClientSecretCredentialBuilder;
import com.azure.identity.ManagedIdentityCredentialBuilder;
import com.azure.identity.implementation.IdentityClientOptions;
import com.azure.security.keyvault.secrets.SecretClient;
import com.azure.security.keyvault.secrets.SecretClientBuilder;
import com.azure.security.keyvault.secrets.SecretServiceVersion;
import com.azure.spring.keyvault.KeyVaultProperties.Property;
import com.azure.spring.telemetry.TelemetrySender;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.MutablePropertySources;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.azure.spring.telemetry.TelemetryData.SERVICE_NAME;
import static com.azure.spring.telemetry.TelemetryData.getClassPackageSimpleName;
import static com.azure.spring.utils.ApplicationId.AZURE_SPRING_KEY_VAULT;
import static com.azure.spring.utils.Constants.AZURE_KEYVAULT_PROPERTYSOURCE_NAME;
import static com.azure.spring.utils.Constants.DEFAULT_REFRESH_INTERVAL_MS;
import static org.springframework.core.env.StandardEnvironment.SYSTEM_ENVIRONMENT_PROPERTY_SOURCE_NAME;

/**
 * A helper class that builds a Key Vault secret client, using the authentication method selected by the configured
 * properties, and then adds Key Vault as a property source to the environment.
 */
class KeyVaultEnvironmentPostProcessorHelper {

    private static final Logger LOGGER = LoggerFactory.getLogger(KeyVaultEnvironmentPostProcessorHelper.class);
    /** Authority host used when none is configured; taken from the Azure Identity client defaults. */
    private static final String DEFAULT_AUTHORITY_HOST = new IdentityClientOptions().getAuthorityHost();
    private final ConfigurableEnvironment environment;

    /**
     * Creates the helper and sends a telemetry event when telemetry is allowed.
     *
     * @param environment the environment Key Vault properties are read from and property sources are added to.
     * @throws IllegalArgumentException if {@code environment} is null.
     */
    KeyVaultEnvironmentPostProcessorHelper(final ConfigurableEnvironment environment) {
        // Validate before storing so the instance never holds a null environment.
        Assert.notNull(environment, "environment must not be null!");
        this.environment = environment;
        // @PostConstruct is not available to environment post-processors, so call it explicitly.
        sendTelemetry();
    }

    /**
     * Add a key vault property source.
     *
     * <p>
     * The normalizedName is used to target a specific key vault (note if the name is the empty string it works as
     * before with only one key vault present). The normalized name is the name of the specific key vault plus a
     * trailing "." at the end.
     * </p>
     *
     * @param normalizedName The normalized name.
     * @throws IllegalArgumentException If no vault URI is configured for {@code normalizedName}.
     * @throws NumberFormatException If the configured refresh interval is not a valid {@code long}.
     * @throws IllegalStateException If KeyVaultOperations fails to initialize.
     */
    public void addKeyVaultPropertySource(String normalizedName) {
        final String vaultUri = getPropertyValue(normalizedName, Property.URI);
        // Fail fast before doing any other work when the vault is not configured.
        Assert.notNull(vaultUri, "vaultUri must not be null!");
        // A missing or unrecognized version yields null, which is passed straight to the client builder.
        final SecretServiceVersion secretServiceVersion =
                resolveSecretServiceVersion(getPropertyValue(normalizedName, Property.SECRET_SERVICE_VERSION));
        final Long refreshInterval = Optional.ofNullable(getPropertyValue(normalizedName, Property.REFRESH_INTERVAL))
                .map(Long::valueOf)
                .orElse(DEFAULT_REFRESH_INTERVAL_MS);
        final List<String> secretKeys = Binder.get(this.environment)
                .bind(
                        KeyVaultProperties.getPropertyName(normalizedName, Property.SECRET_KEYS),
                        Bindable.listOf(String.class)
                )
                .orElse(Collections.emptyList());

        final TokenCredential tokenCredential = getCredentials(normalizedName);
        final SecretClient secretClient = new SecretClientBuilder()
                .vaultUrl(vaultUri)
                .credential(tokenCredential)
                .serviceVersion(secretServiceVersion)
                .httpLogOptions(new HttpLogOptions().setApplicationId(AZURE_SPRING_KEY_VAULT))
                .buildClient();
        try {
            final MutablePropertySources sources = this.environment.getPropertySources();
            final boolean caseSensitive = Boolean
                    .parseBoolean(getPropertyValue(normalizedName, Property.CASE_SENSITIVE_KEYS));
            final KeyVaultOperation keyVaultOperation = new KeyVaultOperation(
                    secretClient,
                    refreshInterval,
                    secretKeys,
                    caseSensitive);

            final KeyVaultPropertySource keyVaultPropertySource =
                    new KeyVaultPropertySource(resolvePropertySourceName(normalizedName), keyVaultOperation);
            // Insert right after the system environment source when present, otherwise with highest precedence.
            if (sources.contains(SYSTEM_ENVIRONMENT_PROPERTY_SOURCE_NAME)) {
                sources.addAfter(
                        SYSTEM_ENVIRONMENT_PROPERTY_SOURCE_NAME,
                        keyVaultPropertySource
                );
            } else {
                sources.addFirst(keyVaultPropertySource);
            }

        } catch (final Exception ex) {
            throw new IllegalStateException("Failed to configure KeyVault property source", ex);
        }
    }

    /**
     * Finds the {@link SecretServiceVersion} whose version string equals {@code version}.
     *
     * @param version the configured version string; may be null.
     * @return the matching service version, or null when {@code version} is null or matches none.
     */
    private static SecretServiceVersion resolveSecretServiceVersion(final String version) {
        return Arrays.stream(SecretServiceVersion.values())
                .filter(candidate -> candidate.getVersion().equals(version))
                .findFirst()
                .orElse(null);
    }

    /**
     * Chooses the property source name for a vault.
     *
     * @param normalizedName the normalized vault name; may be null or blank.
     * @return the trimmed normalized name, or {@code AZURE_KEYVAULT_PROPERTYSOURCE_NAME} when it is null or blank.
     */
    private static String resolvePropertySourceName(final String normalizedName) {
        return Optional.ofNullable(normalizedName)
                .map(String::trim)
                .filter(name -> !name.isEmpty())
                .orElse(AZURE_KEYVAULT_PROPERTYSOURCE_NAME);
    }

    /**
     * Get the token credentials for the default (unnamed) key vault properties.
     *
     * @return the token credentials.
     */
    public TokenCredential getCredentials() {
        return getCredentials("");
    }

    /**
     * Get the token credentials.
     *
     * <p>
     * The first match wins: client id + tenant id + client key selects a client-secret credential. Client id +
     * tenant id + certificate path selects a client-certificate credential. A client id alone selects a managed
     * identity with that client id. Otherwise the default managed identity is used.
     * </p>
     *
     * @param normalizedName the normalized name of the key vault.
     * @return the token credentials.
     */
    public TokenCredential getCredentials(String normalizedName) {
        final String clientId = getPropertyValue(normalizedName, Property.CLIENT_ID);
        final String clientKey = getPropertyValue(normalizedName, Property.CLIENT_KEY);
        final String tenantId = getPropertyValue(normalizedName, Property.TENANT_ID);
        final String certificatePath = getPropertyValue(normalizedName, Property.CERTIFICATE_PATH);
        final String certificatePassword = getPropertyValue(normalizedName, Property.CERTIFICATE_PASSWORD);
        final String authorityHost = getPropertyValue(normalizedName, Property.AUTHORITY_HOST, DEFAULT_AUTHORITY_HOST);

        // Use a service principal with a client secret to authenticate.
        if (clientId != null && tenantId != null && clientKey != null) {
            LOGGER.debug("Will use custom credentials");
            return new ClientSecretCredentialBuilder()
                    .clientId(clientId)
                    .clientSecret(clientKey)
                    .tenantId(tenantId)
                    .authorityHost(authorityHost)
                    .build();
        }
        // Use a client certificate to authenticate. The password may be empty: without a password the certificate
        // is loaded as PEM, with a password as PFX.
        if (clientId != null && tenantId != null && certificatePath != null) {
            final ClientCertificateCredentialBuilder builder = new ClientCertificateCredentialBuilder()
                    .tenantId(tenantId)
                    .clientId(clientId)
                    .authorityHost(authorityHost);
            if (StringUtils.hasLength(certificatePassword)) {
                LOGGER.debug("Will use PFX certificate credentials");
                builder.pfxCertificate(certificatePath, certificatePassword);
            } else {
                LOGGER.debug("Will use PEM certificate credentials");
                builder.pemCertificate(certificatePath);
            }
            return builder.build();
        }
        // Use managed identity (MSI) to authenticate.
        if (clientId != null) {
            LOGGER.debug("Will use MSI credentials with specified clientId");
            return new ManagedIdentityCredentialBuilder().clientId(clientId).build();
        }
        LOGGER.debug("Will use MSI credentials");
        return new ManagedIdentityCredentialBuilder().build();
    }

    /**
     * Reads a property that is not tied to a specific vault.
     *
     * @param property the property to read.
     * @return the property value, or null when it is not set.
     */
    private String getPropertyValue(final Property property) {
        return Optional.of(property)
                .map(KeyVaultProperties::getPropertyName)
                .map(environment::getProperty)
                .orElse(null);
    }

    /**
     * Reads a property of the vault identified by {@code normalizedName}.
     *
     * @param normalizedName the normalized vault name.
     * @param property the property to read.
     * @return the property value, or null when it is not set.
     */
    private String getPropertyValue(final String normalizedName, final Property property) {
        return getPropertyValue(normalizedName, property, null);
    }

    /**
     * Reads a property of the vault identified by {@code normalizedName}, falling back to a default.
     *
     * @param normalizedName the normalized vault name.
     * @param property the property to read.
     * @param defaultValue the value returned when the property is not set.
     * @return the property value, or {@code defaultValue} when it is not set.
     */
    private String getPropertyValue(final String normalizedName, final Property property, String defaultValue) {
        return Optional.of(KeyVaultProperties.getPropertyName(normalizedName, property))
            .map(environment::getProperty)
            .orElse(defaultValue);
    }

    /**
     * Whether telemetry may be sent.
     *
     * @return true only when the allow-telemetry property parses as {@code true}; false when it is unset.
     */
    private boolean allowTelemetry() {
        return Boolean.parseBoolean(getPropertyValue(Property.ALLOW_TELEMETRY));
    }

    /**
     * Sends a single telemetry event carrying this class's package name, if telemetry is allowed.
     */
    private void sendTelemetry() {
        if (allowTelemetry()) {
            final Map<String, String> events = new HashMap<>();
            final TelemetrySender sender = new TelemetrySender();

            events.put(SERVICE_NAME, getClassPackageSimpleName(KeyVaultEnvironmentPostProcessorHelper.class));

            sender.send(ClassUtils.getUserClass(getClass()).getSimpleName(), events);
        }
    }
}