| | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | 2 | | // Licensed under the MIT License. |
| | 3 | |
|
| | 4 | | using Azure.Core; |
| | 5 | | using Azure.Core.Pipeline; |
| | 6 | | using System; |
| | 7 | | using System.Collections.Generic; |
| | 8 | | using System.Globalization; |
| | 9 | | using System.Threading.Tasks; |
| | 10 | |
|
| | 11 | | namespace Azure.Security.KeyVault |
| | 12 | | { |
| | 13 | | internal class ChallengeBasedAuthenticationPolicy : HttpPipelinePolicy |
| | 14 | | { |
| | 15 | | private const string BearerChallengePrefix = "Bearer "; |
| | 16 | |
|
| | 17 | | private readonly TokenCredential _credential; |
| | 18 | |
|
| | 19 | | private AuthenticationChallenge _challenge = null; |
| | 20 | | private string _headerValue; |
| | 21 | | private DateTimeOffset _refreshOn; |
| | 22 | |
|
| 172 | 23 | | public ChallengeBasedAuthenticationPolicy(TokenCredential credential) |
| | 24 | | { |
| 172 | 25 | | _credential = credential; |
| 172 | 26 | | } |
| | 27 | |
|
| | 28 | | public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) |
| | 29 | | { |
| 414 | 30 | | ProcessCoreAsync(message, pipeline, false).EnsureCompleted(); |
| 412 | 31 | | } |
| | 32 | |
|
| | 33 | | public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) |
| | 34 | | { |
| 476 | 35 | | return ProcessCoreAsync(message, pipeline, true); |
| | 36 | | } |
| | 37 | |
|
| | 38 | | private async ValueTask ProcessCoreAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline, bool |
| | 39 | | { |
| 890 | 40 | | if (message.Request.Uri.Scheme != Uri.UriSchemeHttps) |
| | 41 | | { |
| 4 | 42 | | throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected |
| | 43 | | } |
| | 44 | |
|
| 886 | 45 | | RequestContent originalContent = message.Request.Content; |
| | 46 | |
|
| | 47 | | // if this policy doesn't have _challenge cached try to get it from the static challenge cache |
| 886 | 48 | | AuthenticationChallenge challenge = _challenge ?? AuthenticationChallenge.GetChallenge(message); |
| | 49 | |
|
| | 50 | | // if we still don't have the challenge for the endpoint |
| | 51 | | // remove the content from the request and send without authentication to get the challenge |
| 886 | 52 | | if (challenge == null) |
| | 53 | | { |
| 84 | 54 | | message.Request.Content = null; |
| | 55 | | } |
| | 56 | | // otherwise if we already know the challenge authenticate the request |
| | 57 | | else |
| | 58 | | { |
| 802 | 59 | | await AuthenticateRequestAsync(message, async, challenge).ConfigureAwait(false); |
| | 60 | | } |
| | 61 | |
|
| 886 | 62 | | if (async) |
| | 63 | | { |
| 474 | 64 | | await ProcessNextAsync(message, pipeline).ConfigureAwait(false); |
| | 65 | | } |
| | 66 | | else |
| | 67 | | { |
| 412 | 68 | | ProcessNext(message, pipeline); |
| | 69 | | } |
| | 70 | |
|
| | 71 | | // if we get a 401 |
| 886 | 72 | | if (message.Response.Status == 401) |
| | 73 | | { |
| | 74 | | // set the content to the original content in case it was cleared |
| 84 | 75 | | message.Request.Content = originalContent; |
| | 76 | |
|
| | 77 | | // update the cached challenge |
| 84 | 78 | | challenge = AuthenticationChallenge.GetChallenge(message); |
| | 79 | |
|
| 84 | 80 | | if (challenge != null) |
| | 81 | | { |
| | 82 | | // update the cached challenge if not yet set or different from the current challenge (e.g. moved te |
| 84 | 83 | | if (_challenge == null || !challenge.Equals(_challenge)) |
| | 84 | | { |
| 84 | 85 | | _challenge = challenge; |
| | 86 | | } |
| | 87 | |
|
| | 88 | | // authenticate the request and resend |
| 84 | 89 | | await AuthenticateRequestAsync(message, async, challenge).ConfigureAwait(false); |
| | 90 | |
|
| 84 | 91 | | if (async) |
| | 92 | | { |
| 42 | 93 | | await ProcessNextAsync(message, pipeline).ConfigureAwait(false); |
| | 94 | | } |
| | 95 | | else |
| | 96 | | { |
| 42 | 97 | | ProcessNext(message, pipeline); |
| | 98 | | } |
| | 99 | | } |
| | 100 | | } |
| 886 | 101 | | } |
| | 102 | |
|
| | 103 | | private async Task AuthenticateRequestAsync(HttpMessage message, bool async, AuthenticationChallenge challenge) |
| | 104 | | { |
| 886 | 105 | | if (_headerValue is null || DateTimeOffset.UtcNow >= _refreshOn) |
| | 106 | | { |
| 84 | 107 | | AccessToken token = async ? |
| 84 | 108 | | await _credential.GetTokenAsync(new TokenRequestContext(challenge.Scopes, message.Request.Client |
| 84 | 109 | | _credential.GetToken(new TokenRequestContext(challenge.Scopes, message.Request.ClientRequestId), |
| | 110 | |
|
| 84 | 111 | | _headerValue = BearerChallengePrefix + token.Token; |
| 84 | 112 | | _refreshOn = token.ExpiresOn - TimeSpan.FromMinutes(2); |
| | 113 | | } |
| | 114 | |
|
| 886 | 115 | | message.Request.Headers.SetValue(HttpHeader.Names.Authorization, _headerValue); |
| 886 | 116 | | } |
| | 117 | |
|
| | 118 | | internal class AuthenticationChallenge |
| | 119 | | { |
| 2 | 120 | | private static readonly Dictionary<string, AuthenticationChallenge> s_cache = new Dictionary<string, Authent |
| 2 | 121 | | private static readonly object s_cacheLock = new object(); |
| 2 | 122 | | private static readonly string[] s_challengeDelimiters = new string[] { "," }; |
| | 123 | |
|
| 84 | 124 | | private AuthenticationChallenge(string authority, string scope) |
| | 125 | | { |
| 84 | 126 | | Authority = authority; |
| 84 | 127 | | Scopes = new string[] { scope }; |
| 84 | 128 | | } |
| | 129 | |
|
| 0 | 130 | | public string Authority { get; } |
| | 131 | |
|
| 84 | 132 | | public string[] Scopes { get; } |
| | 133 | |
|
| | 134 | | public override bool Equals(object obj) |
| | 135 | | { |
| 0 | 136 | | if (ReferenceEquals(this, obj)) |
| | 137 | | { |
| 0 | 138 | | return true; |
| | 139 | | } |
| | 140 | |
|
| | 141 | | // This assumes that Authority Scopes are always non-null and Scopes has a length of one. |
| | 142 | | // This is guaranteed by the way the AuthenticationChallenge cache is constructed. |
| 0 | 143 | | if (obj is AuthenticationChallenge other) |
| | 144 | | { |
| 0 | 145 | | return string.Equals(Authority, other.Authority, StringComparison.OrdinalIgnoreCase) |
| 0 | 146 | | && string.Equals(Scopes[0], other.Scopes[0], StringComparison.OrdinalIgnoreCase); |
| | 147 | | } |
| | 148 | |
|
| 0 | 149 | | return false; |
| | 150 | | } |
| | 151 | |
|
| | 152 | | public override int GetHashCode() |
| | 153 | | { |
| | 154 | | // Currently the hash code is simply the hash of the authority and first scope as this is what is used t |
| | 155 | | // This assumes that Authority Scopes are always non-null and Scopes has a length of one. |
| | 156 | | // This is guaranteed by the way the AuthenticationChallenge cache is constructed. |
| 0 | 157 | | return HashCodeBuilder.Combine(Authority, Scopes[0]); |
| | 158 | | } |
| | 159 | |
|
| | 160 | | public static AuthenticationChallenge GetChallenge(HttpMessage message) |
| | 161 | | { |
| 168 | 162 | | AuthenticationChallenge challenge = null; |
| | 163 | |
|
| 168 | 164 | | if (message.HasResponse) |
| | 165 | | { |
| 84 | 166 | | challenge = GetChallengeFromResponse(message.Response); |
| | 167 | |
|
| | 168 | | // if the challenge is non-null cache it |
| 84 | 169 | | if (challenge != null) |
| | 170 | | { |
| 84 | 171 | | string authority = GetRequestAuthority(message.Request); |
| 84 | 172 | | lock (s_cacheLock) |
| | 173 | | { |
| 84 | 174 | | s_cache[authority] = challenge; |
| 84 | 175 | | } |
| | 176 | | } |
| | 177 | | } |
| | 178 | | else |
| | 179 | | { |
| | 180 | | // try to get the challenge from the cache |
| 84 | 181 | | string authority = GetRequestAuthority(message.Request); |
| 84 | 182 | | lock (s_cacheLock) |
| | 183 | | { |
| 84 | 184 | | s_cache.TryGetValue(authority, out challenge); |
| 84 | 185 | | } |
| | 186 | | } |
| | 187 | |
|
| 168 | 188 | | return challenge; |
| | 189 | | } |
| | 190 | |
|
| | 191 | | internal static void ClearCache() |
| | 192 | | { |
| | 193 | | // try to get the challenge from the cache |
| 84 | 194 | | lock (s_cacheLock) |
| | 195 | | { |
| 84 | 196 | | s_cache.Clear(); |
| 84 | 197 | | } |
| 84 | 198 | | } |
| | 199 | |
|
| | 200 | | private static AuthenticationChallenge GetChallengeFromResponse(Response response) |
| | 201 | | { |
| 84 | 202 | | AuthenticationChallenge challenge = null; |
| | 203 | |
|
| 84 | 204 | | if (response.Headers.TryGetValue("WWW-Authenticate", out string challengeValue) && challengeValue.Starts |
| | 205 | | { |
| 84 | 206 | | challenge = ParseBearerChallengeHeaderValue(challengeValue); |
| | 207 | | } |
| | 208 | |
|
| 84 | 209 | | return challenge; |
| | 210 | | } |
| | 211 | |
|
| | 212 | | private static AuthenticationChallenge ParseBearerChallengeHeaderValue(string challengeValue) |
| | 213 | | { |
| 84 | 214 | | string authority = null; |
| 84 | 215 | | string scope = null; |
| | 216 | |
|
| | 217 | | // remove the bearer challenge prefix |
| 84 | 218 | | var trimmedChallenge = challengeValue.Substring(BearerChallengePrefix.Length); |
| | 219 | |
|
| | 220 | | // Split the trimmed challenge into a set of name=value strings that |
| | 221 | | // are comma separated. The value fields are expected to be within |
| | 222 | | // quotation characters that are stripped here. |
| 84 | 223 | | string[] pairs = trimmedChallenge.Split(s_challengeDelimiters, StringSplitOptions.RemoveEmptyEntries); |
| | 224 | |
|
| 84 | 225 | | if (pairs.Length > 0) |
| | 226 | | { |
| | 227 | | // Process the name=value string |
| 504 | 228 | | for (int i = 0; i < pairs.Length; i++) |
| | 229 | | { |
| 168 | 230 | | string[] pair = pairs[i].Split('='); |
| | 231 | |
|
| 168 | 232 | | if (pair.Length == 2) |
| | 233 | | { |
| | 234 | | // We have a key and a value, now need to trim and decode |
| 168 | 235 | | string key = pair[0].AsSpan().Trim().Trim('\"').ToString(); |
| 168 | 236 | | string value = pair[1].AsSpan().Trim().Trim('\"').ToString(); |
| | 237 | |
|
| 168 | 238 | | if (!string.IsNullOrEmpty(key)) |
| | 239 | | { |
| | 240 | | // Ordered by current likelihood. |
| 168 | 241 | | if (string.Equals(key, "authorization", StringComparison.OrdinalIgnoreCase)) |
| | 242 | | { |
| 84 | 243 | | authority = value; |
| | 244 | | } |
| 84 | 245 | | else if (string.Equals(key, "resource", StringComparison.OrdinalIgnoreCase)) |
| | 246 | | { |
| 84 | 247 | | scope = value + "/.default"; |
| | 248 | | } |
| 0 | 249 | | else if (string.Equals(key, "scope", StringComparison.OrdinalIgnoreCase)) |
| | 250 | | { |
| 0 | 251 | | scope = value; |
| | 252 | | } |
| 0 | 253 | | else if (string.Equals(key, "authorization_uri", StringComparison.OrdinalIgnoreCase)) |
| | 254 | | { |
| 0 | 255 | | authority = value; |
| | 256 | | } |
| | 257 | | } |
| | 258 | | } |
| | 259 | | } |
| | 260 | | } |
| | 261 | |
|
| 84 | 262 | | if (authority != null && scope != null) |
| | 263 | | { |
| 84 | 264 | | return new AuthenticationChallenge(authority, scope); |
| | 265 | | } |
| | 266 | |
|
| 0 | 267 | | return null; |
| | 268 | | } |
| | 269 | |
|
| | 270 | | private static string GetRequestAuthority(Request request) |
| | 271 | | { |
| 168 | 272 | | Uri uri = request.Uri.ToUri(); |
| | 273 | |
|
| 168 | 274 | | string authority = uri.Authority; |
| | 275 | |
|
| 168 | 276 | | if (!authority.Contains(":") && uri.Port > 0) |
| | 277 | | { |
| | 278 | | // Append port for complete authority |
| 168 | 279 | | authority = uri.Authority + ":" + uri.Port.ToString(CultureInfo.InvariantCulture); |
| | 280 | | } |
| | 281 | |
|
| 168 | 282 | | return authority; |
| | 283 | | } |
| | 284 | | } |
| | 285 | | } |
| | 286 | | } |