| | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | 2 | | // Licensed under the MIT License. See License.txt in the project root for |
| | 3 | | // license information. |
| | 4 | |
|
| | 5 | | using System; |
| | 6 | | using System.Threading; |
| | 7 | | using System.Threading.Tasks; |
| | 8 | | using Microsoft.Azure.KeyVault.Core; |
| | 9 | |
|
| | 10 | | namespace Microsoft.Azure.KeyVault |
| | 11 | | { |
| | 12 | | /// <summary> |
| | 13 | | /// Azure Key Vault KeyResolver. This class resolves Key Vault Key Identifiers and |
| | 14 | | /// Secret Identifiers to implementations of IKey. Secret Identifiers can only |
| | 15 | | /// be resolved if the Secret is a byte array with a length matching one of the AES |
| | 16 | | /// key lengths (128, 192, 256) and the content-type of the secret is application/octet-stream. |
| | 17 | | /// </summary> |
| | 18 | | public class KeyVaultKeyResolver : IKeyResolver |
| | 19 | | { |
| | 20 | | private readonly IKeyVaultClient _client; |
| | 21 | | private readonly string _name; |
| | 22 | |
|
| | 23 | | /// <summary> |
| | 24 | | /// Creates a new Key Vault KeyResolver that uses a KeyVaultClient constructed |
| | 25 | | /// with the provided authentication callback. |
| | 26 | | /// </summary> |
| | 27 | | /// <param name="authenticationCallback">Key Vault authentication callback</param> |
| | 28 | | public KeyVaultKeyResolver( KeyVaultClient.AuthenticationCallback authenticationCallback ) |
| 0 | 29 | | : this( new KeyVaultClient( authenticationCallback ) ) |
| | 30 | | { |
| 0 | 31 | | } |
| | 32 | |
|
| | 33 | | /// <summary> |
| | 34 | | /// Create a new Key Vault KeyResolver that uses the specified KeyVaultClient |
| | 35 | | /// </summary> |
| | 36 | | /// <param name="client">Key Vault client</param> |
| 8 | 37 | | public KeyVaultKeyResolver( IKeyVaultClient client ) |
| | 38 | | { |
| 8 | 39 | | _name = null; |
| 8 | 40 | | _client = client ?? throw new ArgumentNullException( "client" ); |
| 8 | 41 | | } |
| | 42 | |
|
| | 43 | | /// <summary> |
| | 44 | | /// Creates a new Key Vault KeyResolver that uses a KeyVaultClient constructed |
| | 45 | | /// with the provided authentication callback and only resolves keys for the |
| | 46 | | /// specified key vault |
| | 47 | | /// </summary> |
| | 48 | | /// <param name="vaultName">The URL for the Key Vault, e.g. https://myvault.vault.azure.net/ </param> |
| | 49 | | /// <param name="authenticationCallback">Key Vault authentication callback</param> |
| | 50 | | public KeyVaultKeyResolver( string vaultName, KeyVaultClient.AuthenticationCallback authenticationCallback ) |
| 0 | 51 | | : this( vaultName, new KeyVaultClient( authenticationCallback ) ) |
| | 52 | | { |
| 0 | 53 | | } |
| | 54 | |
|
| | 55 | | /// <summary> |
| | 56 | | /// Creates a new Key Vault KeyResolver that uses the specified KeyVaultClient |
| | 57 | | /// and only resolves keys for the specified key vault |
| | 58 | | /// </summary> |
| | 59 | | /// <param name="vaultName">The URL for the Key Vault, e.g. https://myvault.vault.azure.net/ </param> |
| | 60 | | /// <param name="client">Key Vault client</param> |
| 8 | 61 | | public KeyVaultKeyResolver( string vaultName, IKeyVaultClient client ) |
| | 62 | | { |
| 8 | 63 | | if ( string.IsNullOrWhiteSpace( vaultName ) ) |
| 0 | 64 | | throw new ArgumentNullException( "vaultName" ); |
| | 65 | |
|
| 8 | 66 | | if ( client == null ) |
| 0 | 67 | | throw new ArgumentNullException( "client" ); |
| | 68 | |
|
| 8 | 69 | | _name = NormalizeVaultName( vaultName ); |
| 8 | 70 | | _client = client; |
| 8 | 71 | | } |
| | 72 | |
|
| | 73 | | #region IKeyResolver |
| | 74 | |
|
| | 75 | | /// <summary> |
| | 76 | | /// Provides an IKey implementation for the specified key or secret identifier. |
| | 77 | | /// </summary> |
| | 78 | | /// <param name="kid">The key or secret identifier to resolve</param> |
| | 79 | | /// <param name="token">Cancellation token</param> |
| | 80 | | /// <returns>The resolved IKey implementation or null</returns> |
| | 81 | | public async Task<IKey> ResolveKeyAsync( string kid, CancellationToken token ) |
| | 82 | | { |
| 32 | 83 | | if ( string.IsNullOrWhiteSpace( kid ) ) |
| 0 | 84 | | throw new ArgumentNullException( "kid" ); |
| | 85 | |
|
| | 86 | | // If the resolver has a name prefix, only handle kid that have that prefix. |
| 32 | 87 | | if ( _name != null ) |
| | 88 | | { |
| 16 | 89 | | var vaultUrl = new Uri( _name ); |
| 16 | 90 | | var keyUrl = new Uri( kid ); |
| | 91 | |
|
| 16 | 92 | | if ( string.Compare( vaultUrl.Scheme, keyUrl.Scheme, true ) != 0 || string.Compare( vaultUrl.Authority, |
| 0 | 93 | | return null; |
| | 94 | | } |
| | 95 | |
|
| 32 | 96 | | if ( KeyIdentifier.IsKeyIdentifier( kid ) ) |
| 8 | 97 | | return await ResolveKeyFromKeyAsync( kid, token ).ConfigureAwait( false ); |
| | 98 | |
|
| 24 | 99 | | if ( SecretIdentifier.IsSecretIdentifier( kid ) ) |
| 24 | 100 | | return await ResolveKeyFromSecretAsync( kid, token ).ConfigureAwait( false ); |
| | 101 | |
|
| | 102 | | // Return null rather than throw an exception here |
| 0 | 103 | | return null; |
| 32 | 104 | | } |
| | 105 | |
|
| | 106 | | #endregion |
| | 107 | |
|
| | 108 | | private string NormalizeVaultName( string vaultName ) |
| | 109 | | { |
| 8 | 110 | | Uri vaultUri = new Uri( vaultName, UriKind.Absolute ); |
| | 111 | |
|
| 8 | 112 | | if ( string.Compare(vaultUri.Scheme, "https", true) != 0 ) |
| 0 | 113 | | throw new ArgumentException( "The vaultName must use the https scheme" ); |
| | 114 | |
|
| 8 | 115 | | if ( string.CompareOrdinal( vaultUri.PathAndQuery, "/" ) != 0 ) |
| 0 | 116 | | throw new ArgumentException( "The vaultName cannot contain a path or query string" ); |
| | 117 | |
|
| 8 | 118 | | return vaultUri.AbsoluteUri; |
| | 119 | | } |
| | 120 | |
|
| | 121 | |
|
| | 122 | | private Task<IKey> ResolveKeyFromKeyAsync( string kid, CancellationToken token ) |
| | 123 | | { |
| | 124 | | // KeyVaultClient is thread-safe |
| 8 | 125 | | return _client.GetKeyAsync( kid, token ) |
| 8 | 126 | | .ContinueWith<IKey>( task => |
| 8 | 127 | | { |
| 16 | 128 | | var keyBundle = task.Result; |
| 8 | 129 | |
|
| 16 | 130 | | if ( keyBundle != null ) |
| 8 | 131 | | { |
| 16 | 132 | | return new KeyVaultKey( _client, keyBundle ); |
| 8 | 133 | | } |
| 8 | 134 | |
|
| 0 | 135 | | return null; |
| 8 | 136 | | }, token ); |
| | 137 | | } |
| | 138 | |
|
| | 139 | | private Task<IKey> ResolveKeyFromSecretAsync( string sid, CancellationToken token ) |
| | 140 | | { |
| | 141 | | // KeyVaultClient is thread-safe |
| 24 | 142 | | return _client.GetSecretAsync( sid, token ) |
| 24 | 143 | | .ContinueWith<IKey>( task => |
| 24 | 144 | | { |
| 48 | 145 | | var secret = task.Result; |
| 24 | 146 | |
|
| 48 | 147 | | if ( secret != null && string.Equals( secret.ContentType, "application/octet-stream", StringComparis |
| 24 | 148 | | { |
| 48 | 149 | | var keyBytes = FromBase64UrlString( secret.Value ); |
| 24 | 150 | |
|
| 48 | 151 | | if ( keyBytes != null ) |
| 24 | 152 | | { |
| 48 | 153 | | return new SymmetricKey( secret.Id, keyBytes ); |
| 24 | 154 | | } |
| 24 | 155 | | } |
| 24 | 156 | |
|
| 0 | 157 | | return null; |
| 24 | 158 | | }, token ); |
| | 159 | | } |
| | 160 | |
|
| | 161 | | /// <summary> |
| | 162 | | /// Converts a Base64 or Base64Url encoded string to a byte array |
| | 163 | | /// </summary> |
| | 164 | | /// <param name="input">The Base64Url encoded string</param> |
| | 165 | | /// <returns>The byte array represented by the enconded string</returns> |
| | 166 | | private static byte[] FromBase64UrlString( string input ) |
| | 167 | | { |
| 24 | 168 | | if ( string.IsNullOrEmpty( input ) ) |
| 0 | 169 | | throw new ArgumentNullException( "input" ); |
| | 170 | |
|
| 24 | 171 | | return Convert.FromBase64String( Pad( input.Replace( '-', '+' ).Replace( '_', '/' ) ) ); |
| | 172 | | } |
| | 173 | |
|
| | 174 | | /// <summary> |
| | 175 | | /// Adds padding to the input |
| | 176 | | /// </summary> |
| | 177 | | /// <param name="input"> the input string </param> |
| | 178 | | /// <returns> the padded string </returns> |
| | 179 | | private static string Pad( string input ) |
| | 180 | | { |
| 24 | 181 | | var count = 3 - ( ( input.Length + 3 ) % 4 ); |
| | 182 | |
|
| 24 | 183 | | if ( count == 0 ) |
| | 184 | | { |
| 24 | 185 | | return input; |
| | 186 | | } |
| | 187 | |
|
| 0 | 188 | | return input + new string( '=', count ); |
| | 189 | | } |
| | 190 | | } |
| | 191 | | } |