|   |  | 1 |  | // Copyright (c) Microsoft Corporation. All rights reserved. | 
|   |  | 2 |  | // Licensed under the MIT License. | 
|   |  | 3 |  |  | 
|   |  | 4 |  | using System; | 
|   |  | 5 |  | using System.Collections.Generic; | 
|   |  | 6 |  | using System.Linq; | 
|   |  | 7 |  | using System.Threading; | 
|   |  | 8 |  | using System.Threading.Tasks; | 
|   |  | 9 |  | using Azure.Core.Diagnostics; | 
|   |  | 10 |  |  | 
|   |  | 11 |  | namespace Azure.Core.Pipeline | 
|   |  | 12 |  | { | 
|   |  | 13 |  |     /// <summary> | 
|   |  | 14 |  |     /// A policy that sends an <see cref="AccessToken"/> provided by a <see cref="TokenCredential"/> as an Authenticatio | 
|   |  | 15 |  |     /// </summary> | 
|   |  | 16 |  |     public class BearerTokenAuthenticationPolicy : HttpPipelinePolicy | 
|   |  | 17 |  |     { | 
|   |  | 18 |  |         private readonly AccessTokenCache _accessTokenCache; | 
|   |  | 19 |  |  | 
|   |  | 20 |  |         /// <summary> | 
|   |  | 21 |  |         /// Creates a new instance of <see cref="BearerTokenAuthenticationPolicy"/> using provided token credential and  | 
|   |  | 22 |  |         /// </summary> | 
|   |  | 23 |  |         /// <param name="credential">The token credential to use for authentication.</param> | 
|   |  | 24 |  |         /// <param name="scope">The scope to authenticate for.</param> | 
|   | 88 | 25 |  |         public BearerTokenAuthenticationPolicy(TokenCredential credential, string scope) : this(credential, new[] { scop | 
|   |  | 26 |  |  | 
|   |  | 27 |  |         /// <summary> | 
|   |  | 28 |  |         /// Creates a new instance of <see cref="BearerTokenAuthenticationPolicy"/> using provided token credential and  | 
|   |  | 29 |  |         /// </summary> | 
|   |  | 30 |  |         /// <param name="credential">The token credential to use for authentication.</param> | 
|   |  | 31 |  |         /// <param name="scopes">Scopes to authenticate for.</param> | 
|   |  | 32 |  |         public BearerTokenAuthenticationPolicy(TokenCredential credential, IEnumerable<string> scopes) | 
|   | 104 | 33 |  |             : this(credential, scopes, TimeSpan.FromMinutes(5), TimeSpan.FromSeconds(30)) { } | 
|   |  | 34 |  |  | 
|   | 68 | 35 |  |         internal BearerTokenAuthenticationPolicy(TokenCredential credential, IEnumerable<string> scopes, TimeSpan tokenR | 
|   | 68 | 36 |  |             Argument.AssertNotNull(credential, nameof(credential)); | 
|   | 68 | 37 |  |             Argument.AssertNotNull(scopes, nameof(scopes)); | 
|   |  | 38 |  |  | 
|   | 68 | 39 |  |             _accessTokenCache = new AccessTokenCache(credential, scopes.ToArray(), tokenRefreshOffset, tokenRefreshRetry | 
|   | 68 | 40 |  |         } | 
|   |  | 41 |  |  | 
|   |  | 42 |  |         /// <inheritdoc /> | 
|   |  | 43 |  |         public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) | 
|   |  | 44 |  |         { | 
|   | 478 | 45 |  |             return ProcessAsync(message, pipeline, true); | 
|   |  | 46 |  |         } | 
|   |  | 47 |  |  | 
|   |  | 48 |  |         /// <inheritdoc /> | 
|   |  | 49 |  |         public override void Process(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline) | 
|   |  | 50 |  |         { | 
|   | 478 | 51 |  |             ProcessAsync(message, pipeline, false).EnsureCompleted(); | 
|   | 260 | 52 |  |         } | 
|   |  | 53 |  |  | 
|   |  | 54 |  |         private async ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory<HttpPipelinePolicy> pipeline, bool asyn | 
|   |  | 55 |  |         { | 
|   | 956 | 56 |  |             if (message.Request.Uri.Scheme != Uri.UriSchemeHttps) | 
|   |  | 57 |  |             { | 
|   | 8 | 58 |  |                 throw new InvalidOperationException("Bearer token authentication is not permitted for non TLS protected  | 
|   |  | 59 |  |             } | 
|   |  | 60 |  |  | 
|   | 948 | 61 |  |             string headerValue = await _accessTokenCache.GetHeaderValueAsync(message, async); | 
|   | 520 | 62 |  |             message.Request.SetHeader(HttpHeader.Names.Authorization, headerValue); | 
|   |  | 63 |  |  | 
|   | 520 | 64 |  |             if (async) | 
|   |  | 65 |  |             { | 
|   | 260 | 66 |  |                 await ProcessNextAsync(message, pipeline).ConfigureAwait(false); | 
|   |  | 67 |  |             } | 
|   |  | 68 |  |             else | 
|   |  | 69 |  |             { | 
|   | 260 | 70 |  |                 ProcessNext(message, pipeline); | 
|   |  | 71 |  |             } | 
|   | 520 | 72 |  |         } | 
|   |  | 73 |  |  | 
|   |  | 74 |  |         private class AccessTokenCache | 
|   |  | 75 |  |         { | 
|   | 68 | 76 |  |             private readonly object _syncObj = new object(); | 
|   |  | 77 |  |             private readonly TokenCredential _credential; | 
|   |  | 78 |  |             private readonly string[] _scopes; | 
|   |  | 79 |  |             private readonly TimeSpan _tokenRefreshOffset; | 
|   |  | 80 |  |             private readonly TimeSpan _tokenRefreshRetryDelay; | 
|   |  | 81 |  |  | 
|   |  | 82 |  |             private TaskCompletionSource<HeaderValueInfo>? _infoTcs; | 
|   |  | 83 |  |             private TaskCompletionSource<HeaderValueInfo>? _backgroundUpdateTcs; | 
|   | 68 | 84 |  |             public AccessTokenCache(TokenCredential credential, string[] scopes, TimeSpan tokenRefreshOffset, TimeSpan t | 
|   |  | 85 |  |             { | 
|   | 68 | 86 |  |                 _credential = credential; | 
|   | 68 | 87 |  |                 _scopes = scopes; | 
|   | 68 | 88 |  |                 _tokenRefreshOffset = tokenRefreshOffset; | 
|   | 68 | 89 |  |                 _tokenRefreshRetryDelay = tokenRefreshRetryDelay; | 
|   | 68 | 90 |  |             } | 
|   |  | 91 |  |  | 
|   |  | 92 |  |             public async ValueTask<string> GetHeaderValueAsync(HttpMessage message, bool async) | 
|   |  | 93 |  |             { | 
|   |  | 94 |  |                 bool getTokenFromCredential; | 
|   |  | 95 |  |                 TaskCompletionSource<HeaderValueInfo> headerValueTcs; | 
|   |  | 96 |  |                 TaskCompletionSource<HeaderValueInfo>? backgroundUpdateTcs; | 
|   | 948 | 97 |  |                 (headerValueTcs, backgroundUpdateTcs, getTokenFromCredential) = GetTaskCompletionSources(); | 
|   |  | 98 |  |  | 
|   | 948 | 99 |  |                 if (getTokenFromCredential) | 
|   |  | 100 |  |                 { | 
|   | 124 | 101 |  |                     if (backgroundUpdateTcs != null) | 
|   |  | 102 |  |                     { | 
|   | 20 | 103 |  |                         HeaderValueInfo info = headerValueTcs.Task.EnsureCompleted(); | 
|   | 40 | 104 |  |                         _ = Task.Run(() => GetHeaderValueFromCredentialInBackgroundAsync(backgroundUpdateTcs, info, mess | 
|   | 20 | 105 |  |                         return info.HeaderValue; | 
|   |  | 106 |  |                     } | 
|   |  | 107 |  |  | 
|   |  | 108 |  |                     try | 
|   |  | 109 |  |                     { | 
|   | 104 | 110 |  |                         HeaderValueInfo info = await GetHeaderValueFromCredentialAsync(message, async, message.Cancellat | 
|   | 60 | 111 |  |                         headerValueTcs.SetResult(info); | 
|   | 60 | 112 |  |                     } | 
|   | 0 | 113 |  |                     catch (OperationCanceledException) | 
|   |  | 114 |  |                     { | 
|   | 0 | 115 |  |                         headerValueTcs.SetCanceled(); | 
|   | 0 | 116 |  |                         throw; | 
|   |  | 117 |  |                     } | 
|   | 44 | 118 |  |                     catch (Exception exception) | 
|   |  | 119 |  |                     { | 
|   | 44 | 120 |  |                         headerValueTcs.SetException(exception); | 
|   | 44 | 121 |  |                         throw; | 
|   |  | 122 |  |                     } | 
|   |  | 123 |  |                 } | 
|   |  | 124 |  |  | 
|   | 884 | 125 |  |                 var headerValueTask = headerValueTcs.Task; | 
|   | 884 | 126 |  |                 if (!headerValueTask.IsCompleted) | 
|   |  | 127 |  |                 { | 
|   | 602 | 128 |  |                     if (async) | 
|   |  | 129 |  |                     { | 
|   | 404 | 130 |  |                         await headerValueTask.AwaitWithCancellation(message.CancellationToken); | 
|   |  | 131 |  |                     } | 
|   |  | 132 |  |                     else | 
|   |  | 133 |  |                     { | 
|   |  | 134 |  |                         try | 
|   |  | 135 |  |                         { | 
|   | 198 | 136 |  |                             headerValueTask.Wait(message.CancellationToken); | 
|   | 18 | 137 |  |                         } | 
|   | 356 | 138 |  |                         catch (AggregateException) { } // ignore exception here to rethrow it with EnsureCompleted | 
|   |  | 139 |  |                     } | 
|   |  | 140 |  |                 } | 
|   |  | 141 |  |  | 
|   | 678 | 142 |  |                 return headerValueTcs.Task.EnsureCompleted().HeaderValue; | 
|   | 520 | 143 |  |             } | 
|   |  | 144 |  |  | 
|   |  | 145 |  |             private (TaskCompletionSource<HeaderValueInfo> tcs, TaskCompletionSource<HeaderValueInfo>? backgroundUpdateT | 
|   |  | 146 |  |             { | 
|   | 948 | 147 |  |                 lock (_syncObj) | 
|   |  | 148 |  |                 { | 
|   |  | 149 |  |                     // Initial state. GetTaskCompletionSources has been called for the first time | 
|   | 948 | 150 |  |                     if (_infoTcs == null) | 
|   |  | 151 |  |                     { | 
|   | 60 | 152 |  |                         _infoTcs = new TaskCompletionSource<HeaderValueInfo>(TaskCreationOptions.RunContinuationsAsynchr | 
|   | 60 | 153 |  |                         return (_infoTcs, default, true); | 
|   |  | 154 |  |                     } | 
|   |  | 155 |  |  | 
|   |  | 156 |  |                     // Getting new access token is in progress, wait for it | 
|   | 888 | 157 |  |                     if (!_infoTcs.Task.IsCompleted) | 
|   |  | 158 |  |                     { | 
|   | 602 | 159 |  |                         _backgroundUpdateTcs = default; | 
|   | 602 | 160 |  |                         return (_infoTcs, _backgroundUpdateTcs, false); | 
|   |  | 161 |  |                     } | 
|   |  | 162 |  |  | 
|   | 286 | 163 |  |                     DateTimeOffset now = DateTimeOffset.UtcNow; | 
|   |  | 164 |  |                     // Access token has been successfully acquired in background and it is not expired yet, use it inste | 
|   | 286 | 165 |  |                     if (_backgroundUpdateTcs != null && _backgroundUpdateTcs.Task.Status == TaskStatus.RanToCompletion & | 
|   |  | 166 |  |                     { | 
|   | 8 | 167 |  |                         _infoTcs = _backgroundUpdateTcs; | 
|   | 8 | 168 |  |                         _backgroundUpdateTcs = default; | 
|   |  | 169 |  |                     } | 
|   |  | 170 |  |  | 
|   |  | 171 |  |                     // Attempt to get access token has failed or it has already expired. Need to get a new one | 
|   | 286 | 172 |  |                     if (_infoTcs.Task.Status != TaskStatus.RanToCompletion || now >= _infoTcs.Task.Result.ExpiresOn) | 
|   |  | 173 |  |                     { | 
|   | 44 | 174 |  |                         _infoTcs = new TaskCompletionSource<HeaderValueInfo>(TaskCreationOptions.RunContinuationsAsynchr | 
|   | 44 | 175 |  |                         return (_infoTcs, default, true); | 
|   |  | 176 |  |                     } | 
|   |  | 177 |  |  | 
|   |  | 178 |  |                     // Access token is still valid but is about to expire, try to get it in background | 
|   | 242 | 179 |  |                     if (now >= _infoTcs.Task.Result.RefreshOn && _backgroundUpdateTcs == null) | 
|   |  | 180 |  |                     { | 
|   | 20 | 181 |  |                         _backgroundUpdateTcs = new TaskCompletionSource<HeaderValueInfo>(TaskCreationOptions.RunContinua | 
|   | 20 | 182 |  |                         return (_infoTcs, _backgroundUpdateTcs, true); | 
|   |  | 183 |  |                     } | 
|   |  | 184 |  |  | 
|   |  | 185 |  |                     // Access token is valid, use it | 
|   | 222 | 186 |  |                     return (_infoTcs, default, false); | 
|   |  | 187 |  |                 } | 
|   | 948 | 188 |  |             } | 
|   |  | 189 |  |  | 
|   |  | 190 |  |             private async ValueTask GetHeaderValueFromCredentialInBackgroundAsync(TaskCompletionSource<HeaderValueInfo>  | 
|   |  | 191 |  |             { | 
|   | 20 | 192 |  |                 var cts = new CancellationTokenSource(_tokenRefreshRetryDelay); | 
|   |  | 193 |  |                 try | 
|   |  | 194 |  |                 { | 
|   | 20 | 195 |  |                     HeaderValueInfo newInfo = await GetHeaderValueFromCredentialAsync(httpMessage, async, cts.Token); | 
|   | 8 | 196 |  |                     backgroundUpdateTcs.SetResult(newInfo); | 
|   | 8 | 197 |  |                 } | 
|   | 0 | 198 |  |                 catch (OperationCanceledException oce) when (cts.IsCancellationRequested) | 
|   |  | 199 |  |                 { | 
|   | 0 | 200 |  |                     backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.U | 
|   | 0 | 201 |  |                     AzureCoreEventSource.Singleton.BackgroundRefreshFailed(httpMessage.Request.ClientRequestId, oce.ToSt | 
|   | 0 | 202 |  |                 } | 
|   | 12 | 203 |  |                 catch (Exception e) | 
|   |  | 204 |  |                 { | 
|   | 12 | 205 |  |                     backgroundUpdateTcs.SetResult(new HeaderValueInfo(info.HeaderValue, info.ExpiresOn, DateTimeOffset.U | 
|   | 12 | 206 |  |                     AzureCoreEventSource.Singleton.BackgroundRefreshFailed(httpMessage.Request.ClientRequestId, e.ToStri | 
|   | 12 | 207 |  |                 } | 
|   |  | 208 |  |                 finally | 
|   |  | 209 |  |                 { | 
|   | 20 | 210 |  |                     cts.Dispose(); | 
|   |  | 211 |  |                 } | 
|   | 20 | 212 |  |             } | 
|   |  | 213 |  |  | 
|   |  | 214 |  |             private async ValueTask<HeaderValueInfo> GetHeaderValueFromCredentialAsync(HttpMessage message, bool async,  | 
|   |  | 215 |  |             { | 
|   | 124 | 216 |  |                 var requestContext = new TokenRequestContext(_scopes, message.Request.ClientRequestId); | 
|   | 124 | 217 |  |                 AccessToken token = async | 
|   | 124 | 218 |  |                     ? await _credential.GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false) | 
|   | 124 | 219 |  |                     : _credential.GetToken(requestContext, cancellationToken); | 
|   |  | 220 |  |  | 
|   | 68 | 221 |  |                 return new HeaderValueInfo("Bearer " + token.Token, token.ExpiresOn, token.ExpiresOn - _tokenRefreshOffs | 
|   | 68 | 222 |  |             } | 
|   |  | 223 |  |  | 
|   |  | 224 |  |             private readonly struct HeaderValueInfo | 
|   |  | 225 |  |             { | 
|   | 532 | 226 |  |                 public string HeaderValue { get; } | 
|   | 278 | 227 |  |                 public DateTimeOffset ExpiresOn { get; } | 
|   | 242 | 228 |  |                 public DateTimeOffset RefreshOn { get; } | 
|   |  | 229 |  |  | 
|   |  | 230 |  |                 public HeaderValueInfo(string headerValue, DateTimeOffset expiresOn, DateTimeOffset refreshOn) | 
|   |  | 231 |  |                 { | 
|   | 80 | 232 |  |                     HeaderValue = headerValue; | 
|   | 80 | 233 |  |                     ExpiresOn = expiresOn; | 
|   | 80 | 234 |  |                     RefreshOn = refreshOn; | 
|   | 80 | 235 |  |                 } | 
|   |  | 236 |  |             } | 
|   |  | 237 |  |         } | 
|   |  | 238 |  |     } | 
|   |  | 239 |  | } |