// IdentitySslUtil.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.identity.implementation.util;
import com.azure.core.util.logging.ClientLogger;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.security.KeyManagementException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
/**
 * Static helpers for pinning an {@link HttpsURLConnection} to a server certificate identified by its
 * thumbprint (the hexadecimal SHA-1 digest of the encoded certificate).
 */
public final class IdentitySslUtil {
    /**
     * A {@link HostnameVerifier} that accepts every host name without inspecting the session.
     * <p>
     * It is installed by {@link #addTrustedCertificateThumbprint(String, HttpsURLConnection, String)}, where the
     * server is authenticated by certificate thumbprint rather than by host name.
     */
    public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER = new HostnameVerifier() {
        @SuppressWarnings("BadHostnameVerifier")
        @Override
        public boolean verify(String hostname, SSLSession session) {
            return true;
        }
    };

    private IdentitySslUtil() {
        // Utility class; not meant to be instantiated.
    }

    /**
     * Pins the specified HTTPS URL connection to work against a specific server-side certificate with
     * the specified thumbprint only.
     * <p>
     * The connection's hostname verifier is replaced with {@link #ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER} and its
     * SSL socket factory is replaced with one whose trust manager accepts the server only when at least one of
     * the certificates it presents has the expected thumbprint. The comparison is case-insensitive. A
     * {@code null} thumbprint never matches, so the server is always rejected in that case.
     *
     * @param className The class calling the method; used as the name of the logger that reports failures.
     * @param httpsUrlConnection The https url connection to configure
     * @param certificateThumbprint The thumbprint of the certificate
     * @throws RuntimeException If the SSL context cannot be created or initialized.
     */
    public static void addTrustedCertificateThumbprint(String className, HttpsURLConnection httpsUrlConnection,
        String certificateThumbprint) {
        ClientLogger logger = new ClientLogger(className);

        // We expect the connection to work against a specific server side certificate only, so it's safe to
        // disable the host name verification.
        if (httpsUrlConnection.getHostnameVerifier() != ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER) {
            httpsUrlConnection.setHostnameVerifier(ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER);
        }

        // Pin the connection to a specific certificate with specified thumbprint. The factory comes from a
        // context created for this call, so it can never already be the one installed on the connection.
        httpsUrlConnection.setSSLSocketFactory(createPinnedSocketFactory(certificateThumbprint, logger));
    }

    /**
     * Builds a socket factory from a new TLS context whose only trust manager is the thumbprint-pinning one.
     */
    private static SSLSocketFactory createPinnedSocketFactory(String certificateThumbprint, ClientLogger logger) {
        TrustManager[] certificateTrust = new TrustManager[] {
            createThumbprintTrustManager(certificateThumbprint, logger)
        };
        try {
            SSLContext sslContext = SSLContext.getInstance("TLS");
            // No key managers (no client certificate) and the default source of randomness.
            sslContext.init(null, certificateTrust, null);
            return sslContext.getSocketFactory();
        } catch (NoSuchAlgorithmException | KeyManagementException e) {
            throw logger.logExceptionAsError(new RuntimeException("Error Creating SSL Context", e));
        }
    }

    /**
     * Creates a trust manager that trusts only a server presenting a certificate with the given thumbprint.
     */
    private static X509TrustManager createThumbprintTrustManager(final String certificateThumbprint,
        final ClientLogger logger) {
        return new X509TrustManager() {
            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // Trust is decided by thumbprint alone, so no issuers are advertised.
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] certificates, String authenticationType)
                throws CertificateException {
                throw logger.logExceptionAsError(new RuntimeException("No client side certificate configured."));
            }

            @Override
            public void checkServerTrusted(X509Certificate[] certificates, String authenticationType)
                throws CertificateException {
                if (certificates == null || certificates.length == 0) {
                    throw logger.logExceptionAsError(
                        new RuntimeException("Did not receive any certificate from the server."));
                }

                for (X509Certificate x509Certificate : certificates) {
                    String sslCertificateThumbprint = extractCertificateThumbprint(x509Certificate, logger);
                    // Call equalsIgnoreCase on the computed thumbprint (never null) so that a null expected
                    // thumbprint results in a mismatch rather than a NullPointerException mid-handshake.
                    if (sslCertificateThumbprint.equalsIgnoreCase(certificateThumbprint)) {
                        return;
                    }
                }

                throw logger.logExceptionAsError(new RuntimeException(
                    "Thumbprint of certificates receieved did not match the "
                        + "expected thumbprint."));
            }
        };
    }

    /**
     * Computes the thumbprint of a certificate: the SHA-1 digest of its encoded form, rendered as lower-case
     * hexadecimal with two characters per byte.
     *
     * @param certificate The certificate to fingerprint.
     * @param logger The logger used to report failures.
     * @return The hexadecimal thumbprint.
     * @throws RuntimeException If SHA-1 is unavailable or the certificate cannot be encoded.
     */
    private static String extractCertificateThumbprint(Certificate certificate, ClientLogger logger) {
        try {
            MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
            byte[] digest = messageDigest.digest(certificate.getEncoded());

            StringBuilder thumbprint = new StringBuilder(digest.length * 2);
            for (byte digestByte : digest) {
                int unsignedByte = digestByte & 0xff;
                if (unsignedByte < 16) {
                    // Integer.toHexString drops the leading zero; add it back to keep two characters per byte.
                    thumbprint.append('0');
                }
                thumbprint.append(Integer.toHexString(unsignedByte));
            }
            return thumbprint.toString();
        } catch (NoSuchAlgorithmException | CertificateEncodingException e) {
            throw logger.logExceptionAsError(new RuntimeException(e));
        }
    }
}