| | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | 2 | | // Licensed under the MIT License. |
| | 3 | |
|
| | 4 | | using Azure.Core; |
| | 5 | | using System; |
| | 6 | | using System.Collections.Generic; |
| | 7 | | using System.Linq; |
| | 8 | | using System.Threading; |
| | 9 | | using System.Threading.Tasks; |
| | 10 | |
|
| | 11 | | namespace Azure.Identity |
| | 12 | | { |
| | 13 | | /// <summary> |
| | 14 | | /// Provides a <see cref="TokenCredential"/> implementation which chains multiple <see cref="TokenCredential"/> impl |
| | 15 | | /// until one of the getToken methods returns a non-default <see cref="AccessToken"/>. |
| | 16 | | /// </summary> |
| | 17 | | public class ChainedTokenCredential : TokenCredential |
| | 18 | | { |
| | 19 | | private const string AggregateAllUnavailableErrorMessage = "The ChainedTokenCredential failed to retrieve a toke |
| | 20 | |
|
| | 21 | | private const string AggregateCredentialFailedErrorMessage = "The ChainedTokenCredential failed due to an unhand |
| | 22 | |
|
| | 23 | | private readonly TokenCredential[] _sources; |
| | 24 | |
|
| | 25 | | /// <summary> |
| | 26 | | /// Creates an instance with the specified <see cref="TokenCredential"/> sources. |
| | 27 | | /// </summary> |
| | 28 | | /// <param name="sources">The ordered chain of <see cref="TokenCredential"/> implementations to tried when calli |
| 14 | 29 | | public ChainedTokenCredential(params TokenCredential[] sources) |
| | 30 | | { |
| 16 | 31 | | if (sources is null) throw new ArgumentNullException(nameof(sources)); |
| | 32 | |
|
| 12 | 33 | | if (sources.Length == 0) |
| | 34 | | { |
| 4 | 35 | | throw new ArgumentException("sources must not be empty", nameof(sources)); |
| | 36 | | } |
| | 37 | |
|
| 52 | 38 | | for (int i = 0; i < sources.Length; i++) |
| | 39 | | { |
| 20 | 40 | | if (sources[i] == null) |
| | 41 | | { |
| 2 | 42 | | throw new ArgumentException("sources must not contain null", nameof(sources)); |
| | 43 | | } |
| | 44 | |
|
| | 45 | | } |
| 6 | 46 | | _sources = sources; |
| 6 | 47 | | } |
| | 48 | |
|
| | 49 | | /// <summary> |
| | 50 | | /// Sequentially calls <see cref="TokenCredential.GetToken"/> on all the specified sources, returning the first |
| | 51 | | /// </summary> |
| | 52 | | /// <param name="requestContext">The details of the authentication request.</param> |
| | 53 | | /// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param> |
| | 54 | | /// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises |
| | 55 | | public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = d |
| | 56 | | { |
| 0 | 57 | | List<Exception> exceptions = new List<Exception>(); |
| | 58 | |
|
| 0 | 59 | | for (int i = 0; i < _sources.Length; i++) |
| | 60 | | { |
| | 61 | | try |
| | 62 | | { |
| 0 | 63 | | return _sources[i].GetToken(requestContext, cancellationToken); |
| | 64 | | } |
| 0 | 65 | | catch (CredentialUnavailableException e) |
| | 66 | | { |
| 0 | 67 | | exceptions.Add(e); |
| 0 | 68 | | } |
| 0 | 69 | | catch (Exception e) when (!(e is OperationCanceledException)) |
| | 70 | | { |
| 0 | 71 | | exceptions.Add(e); |
| | 72 | |
|
| 0 | 73 | | throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + |
| | 74 | | } |
| | 75 | | } |
| | 76 | |
|
| 0 | 77 | | throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions |
| 0 | 78 | | } |
| | 79 | |
|
| | 80 | | /// <summary> |
| | 81 | | /// Sequentially calls <see cref="TokenCredential.GetToken"/> on all the specified sources, returning the first |
| | 82 | | /// </summary> |
| | 83 | | /// <param name="requestContext">The details of the authentication request.</param> |
| | 84 | | /// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param> |
| | 85 | | /// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises |
| | 86 | | public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken |
| | 87 | | { |
| 16 | 88 | | List<Exception> exceptions = new List<Exception>(); |
| | 89 | |
|
| 80 | 90 | | for (int i = 0; i < _sources.Length; i++) |
| | 91 | | { |
| | 92 | | try |
| | 93 | | { |
| 36 | 94 | | return await _sources[i].GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false); |
| | 95 | | } |
| 24 | 96 | | catch (CredentialUnavailableException e) |
| | 97 | | { |
| 24 | 98 | | exceptions.Add(e); |
| 24 | 99 | | } |
| 4 | 100 | | catch (Exception e) when (!(e is OperationCanceledException)) |
| | 101 | | { |
| 4 | 102 | | exceptions.Add(e); |
| | 103 | |
|
| 4 | 104 | | throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + |
| | 105 | | } |
| | 106 | | } |
| | 107 | |
|
| 4 | 108 | | throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions |
| 8 | 109 | | } |
| | 110 | | } |
| | 111 | | } |