CharlesNRice
CharlesNRice

Reputation: 3259

Call base method from unity interception

I'm trying to integrate the Transient Fault Handling Application Block in my application via Unity Interception using the VirtualMethodInterceptor.

I created a call handler that wraps the intercepted method in an Action, a Func, or a Task-returning Func and passes it to the Transient Fault Handler, but now I just get a StackOverflowException. That makes sense: the fault handler calls the method, Unity intercepts that call and passes it back into the fault handler, and so on forever.

What I need is to either skip the interception when the call comes back from the Fault Handler, or call the base method directly. Somehow Unity is able to do this, since it eventually calls my class's method.

My interception code

public class RetryHandler : ICallHandler
{
    // Compiled wrappers keyed by intercepted method, so each expression tree is built and compiled only once.
    private static readonly ConcurrentDictionary<MethodInfo, Func<object, object[], object>> CacheCall =
        new ConcurrentDictionary<MethodInfo, Func<object, object[], object>>();

    // These delegates exist only to capture the MethodInfo of each RetryPolicy overload the compiled
    // wrapper calls; RetryPolicy.NoRetry is just an instance to bind the method groups to. The generic
    // overloads are closed over int here and re-closed over the real result type in BuildLambda.
    private static readonly Action<Action> RetryActionMethod = RetryPolicy.NoRetry.ExecuteAction;
    private static readonly Func<Func<int>, int> RetryFuncMethod = RetryPolicy.NoRetry.ExecuteAction;
    private static readonly Func<Func<Task<int>>, Task<int>> RetryAsyncFuncMethod = RetryPolicy.NoRetry.ExecuteAsync;
    private static readonly Func<Func<Task>, Task> RetryAsyncActionMethod = RetryPolicy.NoRetry.ExecuteAsync;
    private static readonly ConstructorInfo RetryPolicyConstructor;

    static RetryHandler()
    {
        // RetryPolicy(ITransientErrorDetectionStrategy, RetryStrategy)
        RetryPolicyConstructor =
            typeof (RetryPolicy).GetConstructor(new[]
            {typeof (ITransientErrorDetectionStrategy), typeof (RetryStrategy)});
    }

    /// <summary>
    /// Position of this handler in the interception pipeline.
    /// </summary>
    public int Order { get; set; }

    /// <summary>
    /// Runs the intercepted call through a RetryPolicy built from the method's RetryAttribute.
    /// </summary>
    /// <remarks>
    /// Methods with out/ref parameters (Arguments.Count != Inputs.Count) and non-MethodInfo members are
    /// passed on to the next handler unchanged. When the call is handled here, getNext is not invoked.
    /// </remarks>
    /// <param name="input">The intercepted invocation.</param>
    /// <param name="getNext">Delegate that yields the next handler in the pipeline.</param>
    /// <returns>The method's return value, or an exception return if the call ultimately failed.</returns>
    public IMethodReturn Invoke(IMethodInvocation input, GetNextHandlerDelegate getNext)
    {
        var methodInfo = input.MethodBase as MethodInfo;
        if (methodInfo == null || input.Arguments.Count != input.Inputs.Count)
        {
            return getNext()(input, getNext);
        }

        var func = CacheCall.GetOrAdd(methodInfo, BuildLambda);

        // Cast, not OfType: OfType<object>() silently drops null arguments, which would shift the
        // positional indices the compiled wrapper reads from the array.
        var arguments = input.Inputs.Cast<object>().ToArray();

        try
        {
            // A method may legitimately return an Exception instance, so the result is never
            // inspected for "exception-ness"; failures arrive as thrown exceptions instead.
            return input.CreateMethodReturn(func(input.Target, arguments));
        }
        catch (Exception ex)
        {
            // Surface any exception that escapes the retry policy through the method return
            // rather than letting it propagate out of the handler.
            return input.CreateExceptionMethodReturn(ex);
        }
    }

    /// <summary>
    /// Uses expression trees to wrap a call to <paramref name="methodInfo"/> in the RetryPolicy
    /// described by its RetryAttribute.
    /// </summary>
    /// <param name="methodInfo">The intercepted method.</param>
    /// <returns>
    /// A delegate taking (target, arguments) that returns the method's result boxed to object,
    /// or null for void methods.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// No RetryAttribute is found on the method or its declaring type.
    /// </exception>
    private static Func<object, object[], object> BuildLambda(MethodInfo methodInfo)
    {
        var retryAttribute = FindRetryAttribute(methodInfo);

        // Convert parameters and source object to be able to call method
        var target = Expression.Parameter(typeof (object), "target");
        var parameters = Expression.Parameter(typeof (object[]), "parameters");
        Expression source;
        if (methodInfo.DeclaringType == null)
        {
            source = target;
        }
        else
        {
            source = Expression.Convert(target, methodInfo.DeclaringType);
        }

        // (T0)parameters[0], (T1)parameters[1], ...
        var convertedParams =
            methodInfo.GetParameters()
                .Select(
                    (p, i) =>
                        Expression.Convert(Expression.ArrayIndex(parameters, Expression.Constant(i)),
                            p.ParameterType)).Cast<Expression>()
                .ToArray();

        // NOTE: Expression.Call on a virtual instance method compiles to a virtual call, and the target
        // is the intercepting proxy, so this goes back through interception. That re-entry is what
        // causes the StackOverflowException.
        var innerExpression = Expression.Call(source, methodInfo, convertedParams);

        // Pick the delegate type the retry policy needs and the matching RetryPolicy overload.
        Type returnType;
        MethodInfo retryMethod;
        if (methodInfo.ReturnType == typeof (void))
        {
            returnType = typeof (Action);
            retryMethod = RetryActionMethod.Method;
        }
        else if (methodInfo.ReturnType == typeof (Task))
        {
            returnType = typeof (Func<Task>);
            retryMethod = RetryAsyncActionMethod.Method;
        }
        else if (methodInfo.ReturnType.IsGenericType &&
                 methodInfo.ReturnType.GetGenericTypeDefinition() == typeof (Task<>))
        {
            var genericType = methodInfo.ReturnType.GetGenericArguments()[0];
            returnType =
                typeof (Func<>).MakeGenericType(
                    typeof (Task<>).MakeGenericType(genericType));
            retryMethod = RetryAsyncFuncMethod.Method.GetGenericMethodDefinition().MakeGenericMethod(genericType);
        }
        else
        {
            returnType = typeof (Func<>).MakeGenericType(methodInfo.ReturnType);
            retryMethod =
                RetryFuncMethod.Method.GetGenericMethodDefinition().MakeGenericMethod(methodInfo.ReturnType);
        }

        // () => target.Method(args...)
        var innerLambda = Expression.Lambda(returnType, innerExpression);

        // (target, parameters) => () => target.Method(args...)
        var outerLambda = Expression.Lambda(
            typeof (Func<,,>).MakeGenericType(typeof (object), typeof (object[]), returnType),
            innerLambda,
            target, parameters);

        // new RetryPolicy(<detection strategy>, <retry strategy>), evaluated on every call
        var retryPolicy = Expression.New(RetryPolicyConstructor,
            Expression.Invoke(retryAttribute.TransientErrorDetectionStrategy),
            Expression.Invoke(retryAttribute.RetryStrategy));

        var passedInTarget = Expression.Parameter(typeof (object), "wrapperTarget");
        var passedInParameters = Expression.Parameter(typeof (object[]), "wrapperParamters");

        var retryCall = Expression.Call(retryPolicy, retryMethod,
            Expression.Invoke(outerLambda, passedInTarget, passedInParameters));

        Expression resultExpression;
        if (methodInfo.ReturnType != typeof (void))
        {
            // Box to object so every wrapper has the same Func<object, object[], object> signature.
            resultExpression = Expression.Convert(retryCall, typeof (object));
        }
        else
        {
            // Void methods yield null, which is what CreateMethodReturn expects for them.
            resultExpression = Expression.Block(typeof (object), retryCall, Expression.Constant(null, typeof (object)));
        }

        return Expression.Lambda<Func<object, object[], object>>(resultExpression, passedInTarget, passedInParameters)
            .Compile();
    }

    /// <summary>
    /// Finds the RetryAttribute on the method, falling back to one applied to its declaring type.
    /// </summary>
    private static RetryAttribute FindRetryAttribute(MethodInfo methodInfo)
    {
        var attribute = methodInfo.GetCustomAttributes<RetryAttribute>(true).FirstOrDefault();
        if (attribute == null && methodInfo.DeclaringType != null)
        {
            attribute = methodInfo.DeclaringType.GetCustomAttributes<RetryAttribute>(true).FirstOrDefault();
        }

        if (attribute == null)
        {
            throw new InvalidOperationException(
                "No RetryAttribute found on " + methodInfo.DeclaringType + "." + methodInfo.Name + ".");
        }

        return attribute;
    }
}

My Attributes

/// <summary>
/// Base attribute that attaches a <see cref="RetryHandler"/> to the decorated member. Derived
/// attributes choose the retry and transient-error-detection strategies.
/// </summary>
public abstract class RetryAttribute : HandlerAttribute
{
    /// <summary>
    /// Initializes the strategies to no-retry defaults; derived attributes overwrite them in their
    /// own constructors.
    /// </summary>
    protected RetryAttribute()
    {
        RetryStrategy = () =>
            Microsoft.Practices.EnterpriseLibrary.TransientFaultHandling.RetryStrategy.NoRetry;
        TransientErrorDetectionStrategy = () => RetryPolicy.NoRetry.ErrorDetectionStrategy;
    }

    /// <summary>
    /// Expression producing the retry strategy. It is kept as an expression so it can be embedded
    /// (via Expression.Invoke) in the handler's compiled wrapper.
    /// </summary>
    public Expression<Func<RetryStrategy>> RetryStrategy { get; protected set; }

    /// <summary>
    /// Expression producing the strategy that decides which exceptions are transient.
    /// </summary>
    public Expression<Func<ITransientErrorDetectionStrategy>> TransientErrorDetectionStrategy { get; protected set; }

    /// <summary>
    /// Creates the retry call handler for the decorated member.
    /// </summary>
    /// <param name="container">The Unity container (unused).</param>
    /// <returns>A new <see cref="RetryHandler"/> carrying this attribute's pipeline order.</returns>
    public override ICallHandler CreateHandler(IUnityContainer container)
    {
        // Propagate Order so the handler is positioned where the attribute says in the pipeline;
        // otherwise it is always 0.
        return new RetryHandler { Order = Order };
    }
}

/// <summary>
/// Retry attribute preconfigured for SQL Database calls: SQL-specific transient error detection
/// combined with the library's default exponential retry strategy.
/// </summary>
public class SqlDatabaseeRetryAttribute : RetryAttribute
{
    public SqlDatabaseeRetryAttribute()
    {
        // Which exceptions are considered transient and therefore worth retrying.
        TransientErrorDetectionStrategy = () => new SqlDatabaseTransientErrorDetectionStrategy();

        // Fully qualified because the inherited RetryStrategy property hides the RetryStrategy type.
        RetryStrategy = () => Microsoft.Practices.EnterpriseLibrary.TransientFaultHandling.RetryStrategy.DefaultExponential;
    }
}

Upvotes: 0

Views: 629

Answers (1)

CharlesNRice
CharlesNRice

Reputation: 3259

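I ended up using Reflection.Emit to call the base method non-virtually (the IL `call` instruction instead of `callvirt`), which skips the override that Unity generates and therefore bypasses the interception.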

/// <summary>
/// Creates a delegate that calls <paramref name="methodInfo"/> non-virtually (IL <c>call</c> instead of
/// <c>callvirt</c>). An override of the method, such as the one Unity's VirtualMethodInterceptor
/// generates, is therefore skipped and the implementation declared by
/// <c>methodInfo.DeclaringType</c> runs.
/// </summary>
/// <param name="methodInfo">
/// The method to call. For instance methods the delegate is "open": its first parameter is the
/// target instance, typed as <c>methodInfo.DeclaringType</c>. Value-type declaring types (whose
/// <c>this</c> would need to be passed by reference) are not handled.
/// </param>
/// <returns>
/// A delegate whose type comes from <see cref="Expression.GetDelegateType"/>: (target?, args...) -> return type.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="methodInfo"/> is null.</exception>
private static Delegate BaseMethodCall(MethodInfo methodInfo)
{
    if (methodInfo == null)
    {
        throw new ArgumentNullException("methodInfo");
    }

    IEnumerable<Type> signature = methodInfo.GetParameters().Select(pi => pi.ParameterType);
    if (!methodInfo.IsStatic)
    {
        // The instance is passed as the first argument. It must be typed as the type that declares the
        // method (not its base type) so the non-virtual call below receives a correctly typed 'this'.
        signature = new[] { methodInfo.DeclaringType }.Concat(signature);
    }
    var paramTypes = signature.ToArray();

    var baseCall = new DynamicMethod(string.Empty, methodInfo.ReturnType, paramTypes, methodInfo.Module);

    var il = baseCall.GetILGenerator();
    // Push every argument (including the instance, if any) onto the evaluation stack.
    for (var i = 0; i < paramTypes.Length; i++)
    {
        EmitLoadArgument(il, i);
    }
    // 'call' rather than 'callvirt' is the key: it binds to exactly this method, so the
    // intercepting override is not dispatched to.
    il.Emit(OpCodes.Call, methodInfo);
    il.Emit(OpCodes.Ret);

    // Func<...>/Action<...> (or a custom delegate type) matching the dynamic method's signature.
    var delegateType = Expression.GetDelegateType(paramTypes.Concat(new[] { methodInfo.ReturnType }).ToArray());

    return baseCall.CreateDelegate(delegateType);
}

/// <summary>
/// Emits the shortest correctly-encoded ldarg instruction for the given argument index.
/// OpCodes.Ldarg takes an unsigned int16 operand, so it must not be emitted with an int operand.
/// </summary>
private static void EmitLoadArgument(ILGenerator il, int index)
{
    switch (index)
    {
        case 0:
            il.Emit(OpCodes.Ldarg_0);
            return;
        case 1:
            il.Emit(OpCodes.Ldarg_1);
            return;
        case 2:
            il.Emit(OpCodes.Ldarg_2);
            return;
        case 3:
            il.Emit(OpCodes.Ldarg_3);
            return;
    }

    if (index <= byte.MaxValue)
    {
        il.Emit(OpCodes.Ldarg_S, (byte)index);
    }
    else
    {
        il.Emit(OpCodes.Ldarg, (short)index);
    }
}

Then change the line in BuildLambda that was causing the issue to:

// Invoke a delegate that calls the method non-virtually, so the call does not re-enter
// Unity's interception (the re-entry is what caused the StackOverflowException).
var baseDelegate = BaseMethodCall(methodInfo);
var baseArguments = new[] { source }.Concat(convertedParams);
var innerExpression = Expression.Invoke(Expression.Constant(baseDelegate), baseArguments);

Upvotes: 0

Related Questions