Zdeněk

Adding a `MethodCallExpression` on top of a provided `Expression`

I have an expression tree originating in LINQ, e.g. `leCollection.Where(...).OrderBy(...).Skip(n).Take(m)`. The expression looks like this:

Take(Skip(OrderBy(Where(...), ...), n), m) // you get the idea
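To see that shape for a concrete query, a small helper can walk the chain from the outermost call inward. This is a minimal sketch; `QueryShape` and `MethodChain` are hypothetical names, and it assumes every operator is a static `Queryable` method whose first argument is the source sequence:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class QueryShape
{
    // Starts at the root of the tree and follows each call's first argument
    // (the source sequence for Queryable operators), collecting method names.
    // The result is ordered outermost first, i.e. root first.
    public static List<string> MethodChain(Expression expression)
    {
        var names = new List<string>();
        var current = expression as MethodCallExpression;
        while (current != null)
        {
            names.Add(current.Method.Name);
            current = current.Arguments.Count > 0
                ? current.Arguments[0] as MethodCallExpression
                : null;
        }
        return names;
    }
}
```

For `new[] { 5, 3, 8, 1 }.AsQueryable().Where(x => x > 1).OrderBy(x => x).Skip(1).Take(2)` the helper returns `Take`, `Skip`, `OrderBy`, `Where`: the last operator in the fluent chain is the root of the tree.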

This is my ideal state, with both `Take` and `Skip` present, but that is not always the case. I would like to add `Take`/`Skip` programmatically when they are missing.

I found a way to change the `Take`/`Skip` arguments, and I can even add `Skip` under `Take` when I detect that it is missing, but I'm struggling to add `Take` at the top of the expression: I don't know how to recognize that I'm visiting the top-level expression. My methods run for every method call in the tree, so I have to check the method name before doing anything with the expression.

Below are the methods I use to alter `Take`/`Skip` and to add `Skip` under `Take`. They work; now I would also like to place `Take` at the top of the tree when it is not present yet. Could anyone point me to a place where I can learn more about this?

using System;
using System.Linq.Expressions;
using System.Reflection;

public class LeVisitor<TEntity> : ExpressionVisitor
    where TEntity : class
{
    private readonly int? _take;
    private readonly int? _skip;
    private readonly MethodInfo _queryableSkip;

    public LeVisitor(int? take, int? skip)
    {
        // ...
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        return base.VisitMethodCall(AlterTake(AlterSkip(node)));
    }

    private MethodCallExpression AlterTake(MethodCallExpression node)
    {
        if (!_take.HasValue || !node.Method.Name.Equals("Take", StringComparison.Ordinal))
        {
            return node;
        }

        Expression innerCall = node.Arguments[0];
        if (_skip != null)
        {
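            // Insert a Skip between Take and its source when the source is a
            // method call other than Skip (e.g. OrderBy).
            var innerMethod = innerCall as MethodCallExpression;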
            if (innerMethod != null && !innerMethod.Method.Name.Equals("Skip", StringComparison.Ordinal))
            {
                ConstantExpression skipConstant = Expression.Constant(_skip, typeof(int));
                innerCall = Expression.Call(_queryableSkip, new[] { innerCall, skipConstant });
            }
        }

        return node.Update(
            node.Object,
            new[]
            {
                innerCall,
                Expression.Constant(_take, typeof(int))
            });
    }

    private MethodCallExpression AlterSkip(MethodCallExpression node)
    {
        if (!_skip.HasValue || !node.Method.Name.Equals("Skip", StringComparison.Ordinal))
        {
            return node;
        }

        return node.Update(
            node.Object,
            new[]
            {
                node.Arguments[0],
                Expression.Constant(_skip, typeof(int))
            });
    }
}
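Since the expression handed to `Visit` is itself the root of the tree, another option (a different technique from the visitor approach, not the asker's code) is to check the root before visiting and append `Take` at the `IQueryable` level. A minimal sketch; `TakeAtRoot` and `EnsureTake` are hypothetical names:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class TakeAtRoot
{
    // query.Expression is the root of the tree, so no visitor state is needed
    // to decide whether the top-level call is already a Queryable.Take.
    public static IQueryable<T> EnsureTake<T>(IQueryable<T> query, int take)
    {
        var root = query.Expression as MethodCallExpression;
        bool hasTopTake = root != null
            && root.Method.DeclaringType == typeof(Queryable)
            && root.Method.Name == "Take";

        // Queryable.Take builds a new Take call whose source is query.Expression,
        // i.e. it places Take on top of the existing tree.
        return hasTopTake ? query : query.Take(take);
    }
}
```

An existing top-level `Take` is left untouched here; rewriting its argument is still the visitor's job.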

Answers (1)

Aleks Andreev

You can override the `Visit` method and use a flag to detect the very first call to it; that call receives the root of the tree.
The following code checks the top-level method call and, if it is not `Take`, wraps it in a call to `Queryable.Take`:

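using System.Linq;
using System.Linq.Expressions;

public class AddTakeVisitor : ExpressionVisitor
{
    private readonly int takeAmount;
    // True until Visit is called for the first time, i.e. for the root of the tree.
    // Because of this flag, an instance should be used for a single tree only.
    private bool firstEntry = true;

    public AddTakeVisitor(int takeAmount)
    {
        this.takeAmount = takeAmount;
    }

    public override Expression Visit(Expression node)
    {
        if (!firstEntry)
            return base.Visit(node);

        firstEntry = false;
        var methodCallExpression = node as MethodCallExpression;
        if (methodCallExpression == null)
            return base.Visit(node);

        // The top-level call is already a Take, so leave the tree as it is.
        if (methodCallExpression.Method.Name == "Take")
            return base.Visit(node);

        // For IQueryable<T> / IOrderedQueryable<T> the single generic argument is T.
        var elementType = node.Type.GetGenericArguments().First();

        // Expression.Call picks the Take<T>(IQueryable<T>, int) overload from the
        // argument types; typeof(Queryable).GetMethod("Take", ...) throws an
        // AmbiguousMatchException on .NET 6+, which adds a Take(IQueryable<T>, Range) overload.
        return Expression.Call(
            typeof(Queryable),
            nameof(Queryable.Take),
            new[] { elementType },
            node,
            Expression.Constant(takeAmount));
    }
}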

I've tested it with this code:

var exp = (new[] {1, 2, 3}).AsQueryable().Skip(1);
var visitor = new AddTakeVisitor(1);
var modified = visitor.Visit(exp.Expression);

In the debugger, the `DebugView` of `modified` looks like this:

.Call System.Linq.Queryable.Take(
    .Call System.Linq.Queryable.Skip(
        .Constant<System.Linq.EnumerableQuery`1[System.Int32]>(System.Int32[]),
        1),
    1)
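To actually run the rewritten tree, the query's own provider can turn it back into an `IQueryable<T>` through `IQueryProvider.CreateQuery<TElement>`. A minimal sketch; `QueryRewriting.Rewrite` is a hypothetical helper that accepts any rewrite function, such as a visitor's `Visit` method:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class QueryRewriting
{
    // Applies a rewrite (e.g. someVisitor.Visit) to the query's expression tree
    // and asks the same provider to build a new query from the result.
    public static IQueryable<T> Rewrite<T>(IQueryable<T> query, Func<Expression, Expression> rewrite)
    {
        Expression rewritten = rewrite(query.Expression);
        return query.Provider.CreateQuery<T>(rewritten);
    }
}
```

With the test code above this could be called as `QueryRewriting.Rewrite(exp, new AddTakeVisitor(1).Visit)`. A fresh visitor is created per call because `AddTakeVisitor` keeps its `firstEntry` state.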
