KeyVaultOperation.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.keyvault;
import com.azure.core.exception.HttpRequestException;
import com.azure.core.exception.ResourceNotFoundException;
import com.azure.core.http.rest.PagedResponse;
import com.azure.core.http.rest.Response;
import com.azure.core.util.Context;
import com.azure.core.util.paging.ContinuablePagedIterable;
import com.azure.security.keyvault.secrets.SecretClient;
import com.azure.security.keyvault.secrets.models.KeyVaultSecret;
import com.azure.security.keyvault.secrets.models.SecretProperties;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Timer;
import java.util.TimerTask;
import java.util.regex.Pattern;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.NonNull;
/**
 * KeyVaultOperation wraps the operations to access Key Vault.
 *
 * <p>Secrets are loaded eagerly in the constructor. When a positive refresh interval is given,
 * they are reloaded periodically by a background {@link Timer}. Each refresh replaces the whole
 * name-to-value map in a single assignment, so readers see either the old or the new snapshot.
 */
@SuppressFBWarnings("ST_WRITE_TO_STATIC_FROM_INSTANCE_METHOD")
public class KeyVaultOperation {
    private static final Logger LOG = LoggerFactory.getLogger(KeyVaultOperation.class);

    /**
     * Matches names made only of letters, digits and dashes (already valid Key Vault names).
     */
    private static final Pattern KEY_VAULT_NAME = Pattern.compile("[a-z0-9A-Z-]+");

    /**
     * Matches environment-variable style names such as {@code ACME_MYPROJECT_PERSON_FIRSTNAME}.
     */
    private static final Pattern ENVIRONMENT_VARIABLE_NAME = Pattern.compile("[A-Z0-9_]+");

    /**
     * Stores the case sensitive flag.
     */
    private final boolean caseSensitive;

    /**
     * Stores the current snapshot of secret name to secret value.
     * Volatile because the refresh timer thread replaces it while other threads read it.
     */
    private volatile Map<String, String> properties = new HashMap<>();

    /**
     * Stores the secret client.
     */
    private final SecretClient secretClient;

    /**
     * Stores the secret keys.
     */
    private final List<String> secretKeys;

    /**
     * Stores the timer object to schedule refresh task.
     * NOTE: this is shared by all instances; constructing an instance with refresh enabled
     * cancels the refresh scheduled by any previously constructed instance.
     */
    private static Timer timer;

    /**
     * Constructor.
     *
     * @param secretClient the Key Vault secret client.
     * @param refreshInMillis the refresh in milliseconds (0 or less disables refresh).
     * @param secretKeys the secret keys to look for (null or empty loads every enabled secret).
     * @param caseSensitive the case sensitive flag.
     */
    public KeyVaultOperation(
        final SecretClient secretClient,
        final long refreshInMillis,
        List<String> secretKeys,
        boolean caseSensitive
    ) {
        this.caseSensitive = caseSensitive;
        this.secretClient = secretClient;
        this.secretKeys = secretKeys;
        // Initial load runs on the caller's thread, so a failure here surfaces immediately.
        refreshProperties();
        if (refreshInMillis > 0) {
            synchronized (KeyVaultOperation.class) {
                if (timer != null) {
                    try {
                        timer.cancel();
                        timer.purge();
                    } catch (RuntimeException runtimeException) {
                        LOG.error("Error of terminating Timer", runtimeException);
                    }
                }
                // Daemon thread, so a pending refresh never keeps the JVM alive after shutdown.
                timer = new Timer(true);
                final TimerTask task = new TimerTask() {
                    @Override
                    public void run() {
                        // An exception escaping run() would terminate the Timer thread and
                        // silently stop every future refresh; keep the last good snapshot instead.
                        try {
                            refreshProperties();
                        } catch (RuntimeException runtimeException) {
                            LOG.error("Failed to refresh Key Vault secrets", runtimeException);
                        }
                    }
                };
                timer.scheduleAtFixedRate(task, refreshInMillis, refreshInMillis);
            }
        }
    }

    /**
     * Get the property.
     *
     * @param property the property to get.
     * @return the property value, or null if no secret maps to the property.
     */
    public String getProperty(String property) {
        return properties.get(toKeyVaultSecretName(property));
    }

    /**
     * Get the property names.
     *
     * <p>When case insensitive, every stored name is returned both as-is and with dashes replaced
     * by dots, so both the Key Vault form and the dotted Spring form are advertised.
     *
     * @return the property names.
     */
    public String[] getPropertyNames() {
        // Read the volatile field once so both branches work on one consistent snapshot.
        final Map<String, String> snapshot = properties;
        if (caseSensitive) {
            return snapshot.keySet().toArray(new String[0]);
        }
        return snapshot
            .keySet()
            .stream()
            .flatMap(p -> Stream.of(p, p.replace('-', '.')))
            .distinct()
            .toArray(String[]::new);
    }

    /**
     * Refresh the properties by accessing key vault.
     *
     * <p>Without configured secret keys, every enabled secret is listed and fetched. Otherwise
     * only the configured keys are fetched. The result replaces {@link #properties} in one step.
     */
    private void refreshProperties() {
        if (secretKeys == null || secretKeys.isEmpty()) {
            properties = Optional.of(secretClient)
                .map(SecretClient::listPropertiesOfSecrets)
                .map(ContinuablePagedIterable::iterableByPage)
                .map(i -> StreamSupport.stream(i.spliterator(), false))
                .orElseGet(Stream::empty)
                .map(PagedResponse::getElements)
                .flatMap(i -> StreamSupport.stream(i.spliterator(), false))
                .filter(KeyVaultOperation::isEnabled)
                .map(p -> secretClient.getSecret(p.getName(), p.getVersion()))
                .filter(Objects::nonNull)
                .collect(toSecretMap());
        } else {
            properties = secretKeys.stream()
                .map(this::toKeyVaultSecretName)
                .map(secretClient::getSecret)
                .filter(Objects::nonNull)
                .collect(toSecretMap());
        }
    }

    /**
     * Returns whether the secret is enabled.
     * {@code isEnabled()} is a boxed {@code Boolean}, so a missing flag counts as disabled
     * instead of failing with a NullPointerException on unboxing.
     *
     * @param secretProperties the listed secret properties.
     * @return true only if the secret is explicitly enabled.
     */
    private static boolean isEnabled(SecretProperties secretProperties) {
        return Boolean.TRUE.equals(secretProperties.isEnabled());
    }

    /**
     * Collects secrets into a map keyed by the normalized secret name.
     * If two secrets normalize to the same key (e.g. configured keys {@code a.b} and {@code a-b}),
     * the first one is kept instead of throwing IllegalStateException.
     *
     * @return the collector building the name-to-value map.
     */
    private Collector<KeyVaultSecret, ?, Map<String, String>> toSecretMap() {
        return Collectors.toMap(
            s -> toKeyVaultSecretName(s.getName()),
            KeyVaultSecret::getValue,
            (first, second) -> first
        );
    }

    /**
     * For convention we need to support all relaxed binding formats from Spring. These may include:
     * <table>
     * <caption>Spring relaxed binding names</caption>
     * <tr><td>acme.my-project.person.first-name</td></tr>
     * <tr><td>acme.myProject.person.firstName</td></tr>
     * <tr><td>acme.my_project.person.first_name</td></tr>
     * <tr><td>ACME_MYPROJECT_PERSON_FIRSTNAME</td></tr>
     * </table>
     * But Azure Key Vault only allows ^[0-9a-zA-Z-]+$ and is case insensitive, so
     * there must be some conversion between Spring names and Key Vault
     * names. For example, the 4 properties stated above should be converted to
     * acme-myproject-person-firstname in Key Vault.
     *
     * <p>When case sensitive, the property is returned unchanged.
     *
     * @param property the Spring property name or Key Vault secret name.
     * @return the corresponding Key Vault secret name.
     */
    private String toKeyVaultSecretName(@NonNull String property) {
        if (caseSensitive) {
            return property;
        }
        final String lowerCase = property.toLowerCase(Locale.US);
        if (KEY_VAULT_NAME.matcher(property).matches()) {
            return lowerCase;
        }
        if (ENVIRONMENT_VARIABLE_NAME.matcher(property).matches()) {
            return lowerCase.replace('_', '-');
        }
        return lowerCase
            .replace("-", "")    // my-project -> myproject
            .replace("_", "")    // my_project -> myproject
            .replace('.', '-');  // acme.myproject -> acme-myproject
    }

    /**
     * Set the properties.
     *
     * @param properties the properties.
     */
    void setProperties(HashMap<String, String> properties) {
        this.properties = properties;
    }

    /**
     * Checks Key Vault connectivity by requesting the secret {@code it-is-ok-to-be-empty}.
     *
     * <p>Reports up when the response status is below 500, when the secret is not found, or when
     * an {@link HttpRequestException} is thrown (logged). Any other runtime exception reports down.
     * NOTE(review): treating HttpRequestException as up looks surprising; confirm it is intended.
     *
     * @return true if Key Vault is considered reachable.
     */
    boolean isUp() {
        boolean result;
        try {
            final Response<KeyVaultSecret> response = secretClient
                .getSecretWithResponse("it-is-ok-to-be-empty", null, Context.NONE);
            result = response.getStatusCode() < 500;
        } catch (ResourceNotFoundException resourceNotFoundException) {
            result = true;
        } catch (HttpRequestException httpRequestException) {
            LOG.error("An HTTP error occurred while checking key vault connectivity", httpRequestException);
            result = true;
        } catch (RuntimeException runtimeException) {
            LOG.error("A runtime error occurred while checking key vault connectivity", runtimeException);
            result = false;
        }
        return result;
    }
}