| | 1 | | // Copyright (c) Microsoft Corporation. All rights reserved. |
| | 2 | | // Licensed under the MIT License. |
| | 3 | |
|
| | 4 | | using System; |
| | 5 | | using System.Collections.Generic; |
| | 6 | | using System.Diagnostics; |
| | 7 | | using System.Linq; |
| | 8 | | using System.Reflection; |
| | 9 | | using System.Runtime.ExceptionServices; |
| | 10 | | using System.Threading.Tasks; |
| | 11 | | using Castle.DynamicProxy; |
| | 12 | |
|
| | 13 | | namespace Azure.Core.TestFramework |
| | 14 | | { |
| | 15 | | /// <summary> |
| | 16 | | /// This interceptor forwards the async call to a sync method call with the same arguments |
| | 17 | | /// </summary> |
| | 18 | | public class UseSyncMethodsInterceptor : IInterceptor |
| | 19 | | { |
| | 20 | | private readonly bool _forceSync; |
| | 21 | |
|
| 132 | 22 | | public UseSyncMethodsInterceptor(bool forceSync) |
| | 23 | | { |
| 132 | 24 | | _forceSync = forceSync; |
| 132 | 25 | | } |
| | 26 | |
|
| | 27 | | private const string AsyncSuffix = "Async"; |
| | 28 | |
|
| 132 | 29 | | private readonly MethodInfo _taskFromResultMethod = typeof(Task) |
| 132 | 30 | | .GetMethod("FromResult", BindingFlags.Static | BindingFlags.Public); |
| | 31 | |
|
| 132 | 32 | | private readonly MethodInfo _taskFromExceptionMethod = typeof(Task) |
| 132 | 33 | | .GetMethods(BindingFlags.Static | BindingFlags.Public) |
| 5280 | 34 | | .Single(m => m.Name == "FromException" && m.IsGenericMethod); |
| | 35 | |
|
| | 36 | | [DebuggerStepThrough] |
| | 37 | | public void Intercept(IInvocation invocation) |
| | 38 | | { |
| 420105 | 39 | | Type[] parameterTypes = invocation.Method.GetParameters().Select(p => p.ParameterType).ToArray(); |
| | 40 | |
|
| 164275 | 41 | | var methodName = invocation.Method.Name; |
| 164275 | 42 | | if (!methodName.EndsWith(AsyncSuffix)) |
| | 43 | | { |
| 88426 | 44 | | MethodInfo asyncAlternative = GetMethod(invocation, methodName + AsyncSuffix, parameterTypes); |
| | 45 | |
|
| | 46 | | // Check if there is an async alternative to sync call |
| 88426 | 47 | | if (asyncAlternative != null) |
| | 48 | | { |
| 2 | 49 | | throw new InvalidOperationException($"Async method call expected for {methodName}"); |
| | 50 | | } |
| | 51 | | else |
| | 52 | | { |
| 88424 | 53 | | invocation.Proceed(); |
| 88420 | 54 | | return; |
| | 55 | | } |
| | 56 | | } |
| | 57 | |
|
| 75849 | 58 | | if (!_forceSync) |
| | 59 | | { |
| 38267 | 60 | | invocation.Proceed(); |
| 38255 | 61 | | return; |
| | 62 | | } |
| | 63 | |
|
| 37582 | 64 | | var nonAsyncMethodName = methodName.Substring(0, methodName.Length - AsyncSuffix.Length); |
| | 65 | |
|
| 37582 | 66 | | MethodInfo methodInfo = GetMethod(invocation, nonAsyncMethodName, parameterTypes); |
| 37582 | 67 | | if (methodInfo == null) |
| | 68 | | { |
| 2 | 69 | | throw new InvalidOperationException($"Unable to find a method with name {nonAsyncMethodName} and {string |
| 2 | 70 | | + "Make sure both methods have the same signature including the canc |
| | 71 | | } |
| | 72 | |
|
| 37580 | 73 | | Type returnType = methodInfo.ReturnType; |
| 37580 | 74 | | bool returnsSyncCollection = returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Pag |
| | 75 | |
|
| | 76 | | try |
| | 77 | | { |
| | 78 | | // If we've got GetAsync<Model>() and just found Get<T>(), we |
| | 79 | | // need to change the method to Get<Model>(). |
| 37580 | 80 | | if (methodInfo.ContainsGenericParameters) |
| | 81 | | { |
| 462 | 82 | | methodInfo = methodInfo.MakeGenericMethod(invocation.Method.GetGenericArguments()); |
| 462 | 83 | | returnType = methodInfo.ReturnType; |
| | 84 | | } |
| 37580 | 85 | | object result = methodInfo.Invoke(invocation.InvocationTarget, invocation.Arguments); |
| | 86 | |
|
| | 87 | | // Map IEnumerable to IAsyncEnumerable |
| 35709 | 88 | | if (returnsSyncCollection) |
| | 89 | | { |
| 1623 | 90 | | Type[] modelType = returnType.GenericTypeArguments; |
| 1623 | 91 | | Type wrapperType = typeof(SyncPageableWrapper<>).MakeGenericType(modelType); |
| | 92 | |
|
| 1623 | 93 | | invocation.ReturnValue = Activator.CreateInstance(wrapperType, new[] { result }); |
| | 94 | | } |
| | 95 | | else |
| | 96 | | { |
| 34086 | 97 | | SetAsyncResult(invocation, returnType, result); |
| | 98 | | } |
| 35709 | 99 | | } |
| 1871 | 100 | | catch (TargetInvocationException exception) |
| | 101 | | { |
| 1871 | 102 | | if (returnsSyncCollection) |
| | 103 | | { |
| 8 | 104 | | ExceptionDispatchInfo.Capture(exception.InnerException).Throw(); |
| | 105 | | } |
| | 106 | | else |
| | 107 | | { |
| 1863 | 108 | | SetAsyncException(invocation, returnType, exception.InnerException); |
| | 109 | | } |
| 1863 | 110 | | } |
| 37572 | 111 | | } |
| | 112 | |
|
| | 113 | | private void SetAsyncResult(IInvocation invocation, Type returnType, object result) |
| | 114 | | { |
| 34086 | 115 | | Type methodReturnType = invocation.Method.ReturnType; |
| 34086 | 116 | | if (methodReturnType.IsGenericType) |
| | 117 | | { |
| 34086 | 118 | | returnType = CloseResponseType(returnType, methodReturnType); |
| 34086 | 119 | | if (methodReturnType.GetGenericTypeDefinition() == typeof(Task<>)) |
| | 120 | | { |
| 33978 | 121 | | invocation.ReturnValue = _taskFromResultMethod.MakeGenericMethod(returnType).Invoke(null, new[] { re |
| 33978 | 122 | | return; |
| | 123 | | } |
| 108 | 124 | | if (methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) |
| | 125 | | { |
| 108 | 126 | | invocation.ReturnValue = Activator.CreateInstance(typeof(ValueTask<>).MakeGenericType(returnType), r |
| 108 | 127 | | return; |
| | 128 | | } |
| | 129 | | } |
| | 130 | |
|
| 0 | 131 | | throw new NotSupportedException(); |
| | 132 | | } |
| | 133 | |
|
| | 134 | | private void SetAsyncException(IInvocation invocation, Type returnType, Exception result) |
| | 135 | | { |
| 1863 | 136 | | Type methodReturnType = invocation.Method.ReturnType; |
| 1863 | 137 | | if (methodReturnType.IsGenericType) |
| | 138 | | { |
| 1863 | 139 | | returnType = CloseResponseType(returnType, methodReturnType); |
| 1863 | 140 | | if (methodReturnType.GetGenericTypeDefinition() == typeof(Task<>)) |
| | 141 | | { |
| 1733 | 142 | | invocation.ReturnValue = _taskFromExceptionMethod.MakeGenericMethod(returnType).Invoke(null, new[] { |
| 1733 | 143 | | return; |
| | 144 | | } |
| | 145 | |
|
| 130 | 146 | | if (methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) |
| | 147 | | { |
| 130 | 148 | | var task = _taskFromExceptionMethod.MakeGenericMethod(returnType).Invoke(null, new[] { result }); |
| 130 | 149 | | invocation.ReturnValue = Activator.CreateInstance(typeof(ValueTask<>).MakeGenericType(returnType), t |
| 130 | 150 | | return; |
| | 151 | | } |
| | 152 | | } |
| | 153 | |
|
| 0 | 154 | | throw new NotSupportedException(); |
| | 155 | | } |
| | 156 | |
|
| | 157 | | /// <summary> |
| | 158 | | /// If the sync method returned Response{Model} and the async method's |
| | 159 | | /// return type is still an open Response{U}, we need to close it to |
| | 160 | | /// Response{Model} as well. We don't care about this if Response{} |
| | 161 | | /// has already been closed. |
| | 162 | | /// </summary> |
| | 163 | | /// <param name="returnType">The sync method's return type.</param> |
| | 164 | | /// <param name="methodReturnType">The async method's return type.</param> |
| | 165 | | /// <returns>A return type with Response{U} closed.</returns> |
| | 166 | | private static Type CloseResponseType(Type returnType, Type methodReturnType) |
| | 167 | | { |
| 35949 | 168 | | if (returnType.IsGenericType && |
| 35949 | 169 | | returnType.GetGenericTypeDefinition() == typeof(Response<>) && |
| 35949 | 170 | | returnType.ContainsGenericParameters) |
| | 171 | | { |
| 0 | 172 | | Type modelType = methodReturnType.GetGenericArguments()[0].GetGenericArguments()[0]; |
| 0 | 173 | | returnType = returnType.GetGenericTypeDefinition().MakeGenericType(modelType); |
| | 174 | | } |
| 35949 | 175 | | return returnType; |
| | 176 | | } |
| | 177 | |
|
| | 178 | | private static MethodInfo GetMethod(IInvocation invocation, string nonAsyncMethodName, Type[] types) |
| | 179 | | { |
| 126008 | 180 | | BindingFlags flags = IsInternal(invocation.Method) ? |
| 126008 | 181 | | BindingFlags.Instance | BindingFlags.NonPublic : |
| 126008 | 182 | | BindingFlags.Instance | BindingFlags.Public; |
| | 183 | |
|
| | 184 | | // Do our own slow "lightweight binding" in situations where we |
| | 185 | | // have generic arguments that aren't factored into the binder for |
| | 186 | | // the regular GetMethod call. We're taking lots of shortcuts like |
| | 187 | | // only comparing the generic type or count and it's only enough |
| | 188 | | // for the cases we have today. |
| 5970 | 189 | | static Type GenericDef(Type t) => t.IsGenericType ? t.GetGenericTypeDefinition() : t; |
| | 190 | | MethodInfo GetMethodSlow() |
| | 191 | | { |
| 89646 | 192 | | var methods = invocation.TargetType |
| 89646 | 193 | | // Start with all methods that have the right name |
| 6002284 | 194 | | .GetMethods(flags).Where(m => m.Name == nonAsyncMethodName); |
| | 195 | |
|
| | 196 | | // Check if their type parameters have the same generic |
| | 197 | | // type definitions (i.e., if I invoked |
| | 198 | | // GetAsync<Model>(Wrapper<Model>) we want that to match |
| | 199 | | // with Get<T>(Wrapper<T>) |
| 89646 | 200 | | var genericDefs = methods.Where(m => |
| | 201 | | m.GetParameters().Select(p => GenericDef(p.ParameterType)) |
| 3458 | 202 | | .SequenceEqual(types.Select(GenericDef))); |
| | 203 | |
|
| | 204 | | // If the previous check has any results, check if they have the same number of type arguments |
| | 205 | | // (all of our cases today either specialize on 0 or 1 type |
| | 206 | | // argument for the static or dynamic user schema approach) |
| | 207 | | // Else, close each GenericMethodDefinition and compare its paramter types. |
| 89646 | 208 | | var withSimilarGenericArguments = genericDefs.Any() ? |
| 89646 | 209 | | genericDefs.Where(m => |
| 2033 | 210 | | m.GetGenericArguments().Length == |
| 2033 | 211 | | invocation.Method.GetGenericArguments().Length) : |
| 89646 | 212 | | methods |
| | 213 | | .Where(m => m.IsGenericMethodDefinition) |
| 2632 | 214 | | .Select(m => m.MakeGenericMethod(invocation.GenericArguments)) |
| | 215 | | .Where(gm => gm.GetParameters().Select(p => p.ParameterType) |
| | 216 | | .SequenceEqual(invocation.Method.GetParameters().Select(p => p.ParameterType))); |
| | 217 | |
|
| | 218 | | // Hopefully we're down to 1. If you arrive here in the |
| | 219 | | // future because SingleOrDefault threw, we need to make |
| | 220 | | // the comparison logic more specific. If you arrive here |
| | 221 | | // because we're returning null, then we need to search |
| | 222 | | // a little more broadly. Either way, congratulations on |
| | 223 | | // blazing new API patterns and taking us boldly into the |
| | 224 | | // future. |
| 89646 | 225 | | return withSimilarGenericArguments.SingleOrDefault(); |
| | 226 | | } |
| | 227 | |
|
| | 228 | | try |
| | 229 | | { |
| 126008 | 230 | | return invocation.TargetType.GetMethod( |
| 126008 | 231 | | nonAsyncMethodName, |
| 126008 | 232 | | flags, |
| 126008 | 233 | | null, |
| 126008 | 234 | | types, |
| 126008 | 235 | | null) ?? |
| 126008 | 236 | | // Search a little more broadly if the regular binder |
| 126008 | 237 | | // couldn't find a match |
| 126008 | 238 | | GetMethodSlow(); |
| | 239 | | } |
| 8 | 240 | | catch (AmbiguousMatchException) |
| | 241 | | { |
| | 242 | | // Use our own binder to pick the best method if the regular |
| | 243 | | // binder couldn't decide between multiple choices |
| 8 | 244 | | return GetMethodSlow(); |
| | 245 | | } |
| 126008 | 246 | | } |
| | 247 | |
|
| 126008 | 248 | | private static bool IsInternal(MethodBase method) => method.IsAssembly || method.IsFamilyAndAssembly && !method. |
| | 249 | |
|
| | 250 | | private class SyncPageableWrapper<T> : AsyncPageable<T> |
| | 251 | | { |
| | 252 | | private readonly Pageable<T> _enumerable; |
| | 253 | |
|
| 1623 | 254 | | public SyncPageableWrapper(Pageable<T> enumerable) |
| | 255 | | { |
| 1623 | 256 | | _enumerable = enumerable; |
| 1623 | 257 | | } |
| | 258 | |
|
| | 259 | | #pragma warning disable 1998 |
| | 260 | | public override async IAsyncEnumerable<Page<T>> AsPages(string continuationToken = default, int? pageSizeHin |
| | 261 | | #pragma warning restore 1998 |
| | 262 | | { |
| 6564 | 263 | | foreach (Page<T> page in _enumerable.AsPages()) |
| | 264 | | { |
| 1682 | 265 | | yield return page; |
| | 266 | | } |
| 1598 | 267 | | } |
| | 268 | | } |
| | 269 | | } |
| | 270 | | } |