| | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | 2 | | // Licensed under the MIT License. |
| | 3 | |
|
| | 4 | | using System; |
| | 5 | | using System.Collections.Generic; |
| | 6 | | using System.Linq; |
| | 7 | | using System.Threading.Tasks; |
| | 8 | | using Azure.Core.Pipeline; |
| | 9 | |
|
| | 10 | | namespace Azure.Core.TestFramework |
| | 11 | | { |
| | 12 | | public class MockTransport : HttpPipelineTransport |
| | 13 | | { |
| 1700 | 14 | | private readonly object _syncObj = new object(); |
| | 15 | | private readonly Func<MockRequest, MockResponse> _responseFactory; |
| | 16 | |
|
| 2836 | 17 | | public AsyncGate<MockRequest, MockResponse> RequestGate { get; } |
| | 18 | |
|
| 3636 | 19 | | public List<MockRequest> Requests { get; } = new List<MockRequest>(); |
| | 20 | |
|
| 2070 | 21 | | public bool? ExpectSyncPipeline { get; set; } |
| | 22 | |
|
| 15129 | 23 | | public MockTransport() |
| | 24 | | { |
| 15129 | 25 | | RequestGate = new AsyncGate<MockRequest, MockResponse>(); |
| 15129 | 26 | | } |
| | 27 | |
|
| 906 | 28 | | public MockTransport(params MockResponse[] responses) |
| | 29 | | { |
| 906 | 30 | | var requestIndex = 0; |
| 906 | 31 | | _responseFactory = req => |
| 906 | 32 | | { |
| 1952 | 33 | | lock (_syncObj) |
| 906 | 34 | | { |
| 1952 | 35 | | return responses[requestIndex++]; |
| 906 | 36 | | } |
| 1944 | 37 | | }; |
| 906 | 38 | | } |
| | 39 | |
|
| 84 | 40 | | public MockTransport(Func<MockRequest, MockResponse> responseFactory) |
| | 41 | | { |
| 84 | 42 | | _responseFactory = responseFactory; |
| 84 | 43 | | } |
| | 44 | |
|
| | 45 | | public override Request CreateRequest() |
| 334138 | 46 | | => new MockRequest(); |
| | 47 | |
|
| | 48 | | public override void Process(HttpMessage message) |
| | 49 | | { |
| 907 | 50 | | if (ExpectSyncPipeline == false) |
| | 51 | | { |
| 0 | 52 | | throw new InvalidOperationException("Sync pipeline invocation not expected"); |
| | 53 | | } |
| | 54 | |
|
| 907 | 55 | | ProcessCore(message).GetAwaiter().GetResult(); |
| 869 | 56 | | } |
| | 57 | |
|
| | 58 | | public override async ValueTask ProcessAsync(HttpMessage message) |
| | 59 | | { |
| 1173 | 60 | | if (ExpectSyncPipeline == true) |
| | 61 | | { |
| 0 | 62 | | throw new InvalidOperationException("Async pipeline invocation not expected"); |
| | 63 | | } |
| | 64 | |
|
| 1173 | 65 | | await ProcessCore(message); |
| 1131 | 66 | | } |
| | 67 | |
|
| | 68 | | private async Task ProcessCore(HttpMessage message) |
| | 69 | | { |
| 2080 | 70 | | if (!(message.Request is MockRequest request)) |
| 0 | 71 | | throw new InvalidOperationException("the request is not compatible with the transport"); |
| | 72 | |
|
| 2080 | 73 | | lock (_syncObj) |
| | 74 | | { |
| 2080 | 75 | | Requests.Add(request); |
| 2080 | 76 | | } |
| | 77 | |
|
| 2080 | 78 | | if (RequestGate != null) |
| | 79 | | { |
| 378 | 80 | | message.Response = await RequestGate.WaitForRelease(request); |
| | 81 | | } |
| | 82 | | else |
| | 83 | | { |
| 1702 | 84 | | message.Response = _responseFactory(request); |
| | 85 | | } |
| | 86 | |
|
| 2000 | 87 | | message.Response.ClientRequestId = request.ClientRequestId; |
| | 88 | |
|
| 2000 | 89 | | if (message.Response.ContentStream != null && ExpectSyncPipeline != null) |
| | 90 | | { |
| 140 | 91 | | message.Response.ContentStream = new AsyncValidatingStream(!ExpectSyncPipeline.Value, message.Response.C |
| | 92 | | } |
| 2000 | 93 | | } |
| | 94 | |
|
| | 95 | | public MockRequest SingleRequest |
| | 96 | | { |
| | 97 | | get |
| | 98 | | { |
| 140 | 99 | | lock (_syncObj) |
| | 100 | | { |
| 140 | 101 | | return Requests.Single(); |
| | 102 | | } |
| 140 | 103 | | } |
| | 104 | | } |
| | 105 | | } |
| | 106 | | } |