| | | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | | 2 | | // Licensed under the MIT License. |
| | | 3 | | |
| | | 4 | | using Azure.Core; |
| | | 5 | | using Azure.Core.Pipeline; |
| | | 6 | | using System; |
| | | 7 | | using System.Collections.Generic; |
| | | 8 | | using System.Globalization; |
| | | 9 | | using System.Threading.Tasks; |
| | | 10 | | |
| | | 11 | | namespace Azure.Security.KeyVault |
| | | 12 | | { |
| | | 13 | | internal class ChallengeBasedAuthenticationPolicy : HttpPipelinePolicy |
| | | 14 | | { |
| | | 15 | | private const string BearerChallengePrefix = "Bearer "; |
| | | 16 | | |
| | | 17 | | private readonly TokenCredential _credential; |
| | | 18 | | |
| | | 19 | | private AuthenticationChallenge _challenge = null; |
| | | 20 | | private string _headerValue; |
| | | 21 | | private DateTimeOffset _refreshOn; |
| | | 22 | | |
| | 172 | 23 | | public ChallengeBasedAuthenticationPolicy(TokenCredential credential) |
| | | 24 | | { |
| | 172 | 25 | | _credential = credential; |
| | 172 | 26 | | } |
| | | 27 | | |
| | | 28 | | public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) |
| | | 29 | | { |
| | 414 | 30 | | ProcessCoreAsync(message, pipeline, false).EnsureCompleted(); |
| | 412 | 31 | | } |
| | | 32 | | |
| | | 33 | | public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) |
| | | 34 | | { |
| | 476 | 35 | | return ProcessCoreAsync(message, pipeline, true); |
| | | 36 | | } |
| | | 37 | | |
| | | 38 | | private async ValueTask ProcessCoreAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline, bool |
| | | 39 | | { |
| | 890 | 40 | | if (message.Request.Uri.Scheme != Uri.UriSchemeHttps) |
| | | 41 | | { |
| | 4 | 42 | | throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected |
| | | 43 | | } |
| | | 44 | | |
| | 886 | 45 | | RequestContent originalContent = message.Request.Content; |
| | | 46 | | |
| | | 47 | | // if this policy doesn't have _challenge cached try to get it from the static challenge cache |
| | 886 | 48 | | AuthenticationChallenge challenge = _challenge ?? AuthenticationChallenge.GetChallenge(message); |
| | | 49 | | |
| | | 50 | | // if we still don't have the challenge for the endpoint |
| | | 51 | | // remove the content from the request and send without authentication to get the challenge |
| | 886 | 52 | | if (challenge == null) |
| | | 53 | | { |
| | 84 | 54 | | message.Request.Content = null; |
| | | 55 | | } |
| | | 56 | | // otherwise if we already know the challenge authenticate the request |
| | | 57 | | else |
| | | 58 | | { |
| | 802 | 59 | | await AuthenticateRequestAsync(message, async, challenge).ConfigureAwait(false); |
| | | 60 | | } |
| | | 61 | | |
| | 886 | 62 | | if (async) |
| | | 63 | | { |
| | 474 | 64 | | await ProcessNextAsync(message, pipeline).ConfigureAwait(false); |
| | | 65 | | } |
| | | 66 | | else |
| | | 67 | | { |
| | 412 | 68 | | ProcessNext(message, pipeline); |
| | | 69 | | } |
| | | 70 | | |
| | | 71 | | // if we get a 401 |
| | 886 | 72 | | if (message.Response.Status == 401) |
| | | 73 | | { |
| | | 74 | | // set the content to the original content in case it was cleared |
| | 84 | 75 | | message.Request.Content = originalContent; |
| | | 76 | | |
| | | 77 | | // update the cached challenge |
| | 84 | 78 | | challenge = AuthenticationChallenge.GetChallenge(message); |
| | | 79 | | |
| | 84 | 80 | | if (challenge != null) |
| | | 81 | | { |
| | | 82 | | // update the cached challenge if not yet set or different from the current challenge (e.g. moved te |
| | 84 | 83 | | if (_challenge == null || !challenge.Equals(_challenge)) |
| | | 84 | | { |
| | 84 | 85 | | _challenge = challenge; |
| | | 86 | | } |
| | | 87 | | |
| | | 88 | | // authenticate the request and resend |
| | 84 | 89 | | await AuthenticateRequestAsync(message, async, challenge).ConfigureAwait(false); |
| | | 90 | | |
| | 84 | 91 | | if (async) |
| | | 92 | | { |
| | 42 | 93 | | await ProcessNextAsync(message, pipeline).ConfigureAwait(false); |
| | | 94 | | } |
| | | 95 | | else |
| | | 96 | | { |
| | 42 | 97 | | ProcessNext(message, pipeline); |
| | | 98 | | } |
| | | 99 | | } |
| | | 100 | | } |
| | 886 | 101 | | } |
| | | 102 | | |
| | | 103 | | private async Task AuthenticateRequestAsync(HttpMessage message, bool async, AuthenticationChallenge challenge) |
| | | 104 | | { |
| | 886 | 105 | | if (_headerValue is null || DateTimeOffset.UtcNow >= _refreshOn) |
| | | 106 | | { |
| | 84 | 107 | | AccessToken token = async ? |
| | 84 | 108 | | await _credential.GetTokenAsync(new TokenRequestContext(challenge.Scopes, message.Request.Client |
| | 84 | 109 | | _credential.GetToken(new TokenRequestContext(challenge.Scopes, message.Request.ClientRequestId), |
| | | 110 | | |
| | 84 | 111 | | _headerValue = BearerChallengePrefix + token.Token; |
| | 84 | 112 | | _refreshOn = token.ExpiresOn - TimeSpan.FromMinutes(2); |
| | | 113 | | } |
| | | 114 | | |
| | 886 | 115 | | message.Request.Headers.SetValue(HttpHeader.Names.Authorization, _headerValue); |
| | 886 | 116 | | } |
| | | 117 | | |
| | | 118 | | internal class AuthenticationChallenge |
| | | 119 | | { |
| | 2 | 120 | | private static readonly Dictionary<string, AuthenticationChallenge> s_cache = new Dictionary<string, Authent |
| | 2 | 121 | | private static readonly object s_cacheLock = new object(); |
| | 2 | 122 | | private static readonly string[] s_challengeDelimiters = new string[] { "," }; |
| | | 123 | | |
| | 84 | 124 | | private AuthenticationChallenge(string authority, string scope) |
| | | 125 | | { |
| | 84 | 126 | | Authority = authority; |
| | 84 | 127 | | Scopes = new string[] { scope }; |
| | 84 | 128 | | } |
| | | 129 | | |
| | 0 | 130 | | public string Authority { get; } |
| | | 131 | | |
| | 84 | 132 | | public string[] Scopes { get; } |
| | | 133 | | |
| | | 134 | | public override bool Equals(object obj) |
| | | 135 | | { |
| | 0 | 136 | | if (ReferenceEquals(this, obj)) |
| | | 137 | | { |
| | 0 | 138 | | return true; |
| | | 139 | | } |
| | | 140 | | |
| | | 141 | | // This assumes that Authority Scopes are always non-null and Scopes has a length of one. |
| | | 142 | | // This is guaranteed by the way the AuthenticationChallenge cache is constructed. |
| | 0 | 143 | | if (obj is AuthenticationChallenge other) |
| | | 144 | | { |
| | 0 | 145 | | return string.Equals(Authority, other.Authority, StringComparison.OrdinalIgnoreCase) |
| | 0 | 146 | | && string.Equals(Scopes[0], other.Scopes[0], StringComparison.OrdinalIgnoreCase); |
| | | 147 | | } |
| | | 148 | | |
| | 0 | 149 | | return false; |
| | | 150 | | } |
| | | 151 | | |
| | | 152 | | public override int GetHashCode() |
| | | 153 | | { |
| | | 154 | | // Currently the hash code is simply the hash of the authority and first scope as this is what is used t |
| | | 155 | | // This assumes that Authority Scopes are always non-null and Scopes has a length of one. |
| | | 156 | | // This is guaranteed by the way the AuthenticationChallenge cache is constructed. |
| | 0 | 157 | | return HashCodeBuilder.Combine(Authority, Scopes[0]); |
| | | 158 | | } |
| | | 159 | | |
| | | 160 | | public static AuthenticationChallenge GetChallenge(HttpMessage message) |
| | | 161 | | { |
| | 168 | 162 | | AuthenticationChallenge challenge = null; |
| | | 163 | | |
| | 168 | 164 | | if (message.HasResponse) |
| | | 165 | | { |
| | 84 | 166 | | challenge = GetChallengeFromResponse(message.Response); |
| | | 167 | | |
| | | 168 | | // if the challenge is non-null cache it |
| | 84 | 169 | | if (challenge != null) |
| | | 170 | | { |
| | 84 | 171 | | string authority = GetRequestAuthority(message.Request); |
| | 84 | 172 | | lock (s_cacheLock) |
| | | 173 | | { |
| | 84 | 174 | | s_cache[authority] = challenge; |
| | 84 | 175 | | } |
| | | 176 | | } |
| | | 177 | | } |
| | | 178 | | else |
| | | 179 | | { |
| | | 180 | | // try to get the challenge from the cache |
| | 84 | 181 | | string authority = GetRequestAuthority(message.Request); |
| | 84 | 182 | | lock (s_cacheLock) |
| | | 183 | | { |
| | 84 | 184 | | s_cache.TryGetValue(authority, out challenge); |
| | 84 | 185 | | } |
| | | 186 | | } |
| | | 187 | | |
| | 168 | 188 | | return challenge; |
| | | 189 | | } |
| | | 190 | | |
| | | 191 | | internal static void ClearCache() |
| | | 192 | | { |
| | | 193 | | // try to get the challenge from the cache |
| | 84 | 194 | | lock (s_cacheLock) |
| | | 195 | | { |
| | 84 | 196 | | s_cache.Clear(); |
| | 84 | 197 | | } |
| | 84 | 198 | | } |
| | | 199 | | |
| | | 200 | | private static AuthenticationChallenge GetChallengeFromResponse(Response response) |
| | | 201 | | { |
| | 84 | 202 | | AuthenticationChallenge challenge = null; |
| | | 203 | | |
| | 84 | 204 | | if (response.Headers.TryGetValue("WWW-Authenticate", out string challengeValue) && challengeValue.Starts |
| | | 205 | | { |
| | 84 | 206 | | challenge = ParseBearerChallengeHeaderValue(challengeValue); |
| | | 207 | | } |
| | | 208 | | |
| | 84 | 209 | | return challenge; |
| | | 210 | | } |
| | | 211 | | |
| | | 212 | | private static AuthenticationChallenge ParseBearerChallengeHeaderValue(string challengeValue) |
| | | 213 | | { |
| | 84 | 214 | | string authority = null; |
| | 84 | 215 | | string scope = null; |
| | | 216 | | |
| | | 217 | | // remove the bearer challenge prefix |
| | 84 | 218 | | var trimmedChallenge = challengeValue.Substring(BearerChallengePrefix.Length); |
| | | 219 | | |
| | | 220 | | // Split the trimmed challenge into a set of name=value strings that |
| | | 221 | | // are comma separated. The value fields are expected to be within |
| | | 222 | | // quotation characters that are stripped here. |
| | 84 | 223 | | string[] pairs = trimmedChallenge.Split(s_challengeDelimiters, StringSplitOptions.RemoveEmptyEntries); |
| | | 224 | | |
| | 84 | 225 | | if (pairs.Length > 0) |
| | | 226 | | { |
| | | 227 | | // Process the name=value string |
| | 504 | 228 | | for (int i = 0; i < pairs.Length; i++) |
| | | 229 | | { |
| | 168 | 230 | | string[] pair = pairs[i].Split('='); |
| | | 231 | | |
| | 168 | 232 | | if (pair.Length == 2) |
| | | 233 | | { |
| | | 234 | | // We have a key and a value, now need to trim and decode |
| | 168 | 235 | | string key = pair[0].AsSpan().Trim().Trim('\"').ToString(); |
| | 168 | 236 | | string value = pair[1].AsSpan().Trim().Trim('\"').ToString(); |
| | | 237 | | |
| | 168 | 238 | | if (!string.IsNullOrEmpty(key)) |
| | | 239 | | { |
| | | 240 | | // Ordered by current likelihood. |
| | 168 | 241 | | if (string.Equals(key, "authorization", StringComparison.OrdinalIgnoreCase)) |
| | | 242 | | { |
| | 84 | 243 | | authority = value; |
| | | 244 | | } |
| | 84 | 245 | | else if (string.Equals(key, "resource", StringComparison.OrdinalIgnoreCase)) |
| | | 246 | | { |
| | 84 | 247 | | scope = value + "/.default"; |
| | | 248 | | } |
| | 0 | 249 | | else if (string.Equals(key, "scope", StringComparison.OrdinalIgnoreCase)) |
| | | 250 | | { |
| | 0 | 251 | | scope = value; |
| | | 252 | | } |
| | 0 | 253 | | else if (string.Equals(key, "authorization_uri", StringComparison.OrdinalIgnoreCase)) |
| | | 254 | | { |
| | 0 | 255 | | authority = value; |
| | | 256 | | } |
| | | 257 | | } |
| | | 258 | | } |
| | | 259 | | } |
| | | 260 | | } |
| | | 261 | | |
| | 84 | 262 | | if (authority != null && scope != null) |
| | | 263 | | { |
| | 84 | 264 | | return new AuthenticationChallenge(authority, scope); |
| | | 265 | | } |
| | | 266 | | |
| | 0 | 267 | | return null; |
| | | 268 | | } |
| | | 269 | | |
| | | 270 | | private static string GetRequestAuthority(Request request) |
| | | 271 | | { |
| | 168 | 272 | | Uri uri = request.Uri.ToUri(); |
| | | 273 | | |
| | 168 | 274 | | string authority = uri.Authority; |
| | | 275 | | |
| | 168 | 276 | | if (!authority.Contains(":") && uri.Port > 0) |
| | | 277 | | { |
| | | 278 | | // Append port for complete authority |
| | 168 | 279 | | authority = uri.Authority + ":" + uri.Port.ToString(CultureInfo.InvariantCulture); |
| | | 280 | | } |
| | | 281 | | |
| | 168 | 282 | | return authority; |
| | | 283 | | } |
| | | 284 | | } |
| | | 285 | | } |
| | | 286 | | } |