lrente

Reputation: 1130

Linq Recursive Sum

I have the following data structure:

        List<Item> Items = new List<Item>
        {
            new Item{ Id = 1, Name = "Machine" },
            new Item{ Id = 3, Id_Parent = 1,  Name = "Machine1"},
            new Item{ Id = 5, Id_Parent = 3,  Name = "Machine1-A", Number = 2, Price = 10 },
            new Item{ Id = 9, Id_Parent = 3,  Name = "Machine1-B", Number = 4, Price = 11 },
            new Item{ Id = 100,  Name = "Item" } ,
            new Item{ Id = 112,  Id_Parent = 100, Name = "Item1", Number = 5, Price = 55 }
        };

I want to build a query that gets, for each parent, the sum of its children's prices (items are related by Id_Parent). For example, for Item Id = 100 I get 55, because that's the price of its only child.

For Item Id = 3 I get 21, because the prices of Item Id = 5 and Id = 9 add up to that. So far so good.

What I am struggling with is Item Id = 1: its sum should also be 21, because Id = 3 is a child of Id = 1 and has a sum of 21.
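Put another way, the total wanted for an item p is defined recursively over its direct children c, treating an unset Price as 0:

$$S(p) = \sum_{c \,:\, c.\mathrm{Id\_Parent} = p} \bigl( c.\mathrm{Price} + S(c.\mathrm{Id}) \bigr)$$

With the data above, $S(5) = S(9) = 0$, so $S(3) = (10 + 0) + (11 + 0) = 21$ and $S(1) = (0 + S(3)) = 21$; likewise $S(100) = 55 + S(112) = 55 + 0 = 55$.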

Here is my code:

        var result = from i in Items
                     join item in Items on i.Id_Parent equals item.Id
                     select new
                     {
                         Name = item.Name,
                         Sum =
                           (from it in Items
                            where it.Id_Parent == item.Id
                            group it by new
                            {
                                it.Id_Parent
                            }
                            into g
                            select new
                            {
                                Sum = g.Sum(x => x.Price)
                            }
                           ).First()
                     };

Help appreciated.

Upvotes: 3

Views: 1892

Answers (4)

Erik Philips

Reputation: 54618

For future readers who may run into a StackOverflowException with the recursive approaches, here is the alternative I use (dotnetfiddle example):

using System;
using System.Collections.Generic;
using System.Linq;
                    
public class Program
{
    public static void Main()
    {
        var items = new List<Item>
        {
            new Item{ Id = 1, Name = "Machine" },
            new Item{ Id = 3, Id_Parent = 1,  Name = "Machine1"},
            new Item{ Id = 5, Id_Parent = 3,  Name = "Machine1-A", Number = 2, Price = 10 },
            new Item{ Id = 9, Id_Parent = 3,  Name = "Machine1-B", Number = 4, Price = 11 },
            new Item{ Id = 100,  Name = "Item" } ,
            new Item{ Id = 112,  Id_Parent = 100, Name = "Item1", Number = 5, Price = 55 }
        };
        
        foreach(var item in items)
        {
            Console.WriteLine("{0} {1} ${2}", item.Name, item.Id, GetSum(items, item.Id));
        }
        
    }
    
    public static int GetSum(IEnumerable<Item> items, int id)
    {
        // add all matching items
        var itemsToSum = items.Where(i => i.Id == id).ToList();
        var oldCount = 0;
        var currentCount = itemsToSum.Count;
        // if nothing was added, we skip the while loop
        while (currentCount != oldCount)
        {
            oldCount = currentCount;
            // find all matching items except the ones already in the list
            var matchedItems = items
                .Join(itemsToSum, item => item.Id_Parent, sum => sum.Id, (item, sum) => item)
                .Except(itemsToSum)
                .ToList();
            itemsToSum.AddRange(matchedItems);
            currentCount = itemsToSum.Count;
        }
        
        return itemsToSum.Sum(i => i.Price);
    }
    
    public class Item
    {
        public int Id { get; set; }
        public int Id_Parent { get; set; }
        public int Number { get; set; }
        public int Price { get; set; }
        public string Name { get; set; }
    
    }
}

Result:

Machine 1 $21
Machine1 3 $21
Machine1-A 5 $10
Machine1-B 9 $11
Item 100 $55
Item1 112 $55

Basically, we create a list with the initial item matching the id passed in. If the id doesn't match anything, the list is empty and we skip the while loop. If we do have items, we join to find all items whose parent id belongs to an item we already have, exclude the ones already in the list, and append what we've found. The loop ends once a pass adds nothing new, i.e. no remaining item has its parent in the list.
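A minimal sketch of the same level-by-level expansion, assuming ids are unique and that Id_Parent = 0 means "no parent" as in the example; the SubtreeSum class name is just for illustration. Instead of re-joining against the whole collected list and calling Except on every pass, it only looks for children of the items added in the previous pass and keeps the collected ids in a HashSet<int>:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Item
{
    public int Id { get; set; }
    public int Id_Parent { get; set; } // 0 = no parent, as in the example data
    public int Number { get; set; }
    public int Price { get; set; }
    public string Name { get; set; }
}

public static class SubtreeSum
{
    // Sums the Price of the item with the given id plus all of its descendants.
    // Each pass only searches for children of the items added in the previous
    // pass (the "frontier"); a HashSet of ids replaces the List-based Except.
    public static int GetSum(IEnumerable<Item> items, int id)
    {
        var all = items.ToList();
        var collected = new HashSet<int>();                 // ids already summed
        var frontier = all.Where(i => i.Id == id).ToList(); // start: the item itself
        var total = 0;

        while (frontier.Count > 0)
        {
            foreach (var item in frontier)
                if (collected.Add(item.Id)) // Add returns false for an id seen before
                    total += item.Price;

            var frontierIds = new HashSet<int>(frontier.Select(f => f.Id));
            // next level: children of the current frontier that were not collected yet
            frontier = all
                .Where(i => frontierIds.Contains(i.Id_Parent) && !collected.Contains(i.Id))
                .ToList();
        }

        return total;
    }
}

public static class Program
{
    public static void Main()
    {
        var items = new List<Item>
        {
            new Item { Id = 1, Name = "Machine" },
            new Item { Id = 3, Id_Parent = 1, Name = "Machine1" },
            new Item { Id = 5, Id_Parent = 3, Name = "Machine1-A", Number = 2, Price = 10 },
            new Item { Id = 9, Id_Parent = 3, Name = "Machine1-B", Number = 4, Price = 11 },
            new Item { Id = 100, Name = "Item" },
            new Item { Id = 112, Id_Parent = 100, Name = "Item1", Number = 5, Price = 55 }
        };

        foreach (var item in items)
            Console.WriteLine("{0} {1} ${2}", item.Name, item.Id, SubtreeSum.GetSum(items, item.Id));
    }
}
```

With the sample data it should print the same six lines as the result above.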

Upvotes: 0

Antonín Lejsek

Reputation: 6103

There are so many solutions that it is worth making a benchmark. I added my own solution to the mix too; it is the last function (GetDescendants). Some functions include the root node and some do not, but apart from that they return the same result. I tested a wide tree with 2 children per parent and a narrow tree with just one child per parent (so the depth equals the number of items). Each header line gives the tree shape, the item count and the key being summed. The results are:

---------- Wide 100000 3 ----------
ItemDescendents: 9592ms
ItemDescendentsFlat: 9544ms
ItemDescendentsFlat2: 45826ms
ItemDescendentsFlat3: 30ms
ItemDescendentsFlat4: 11ms
CalculatePrice: 23849ms
Y: 24265ms
GetSum: 62ms
GetDescendants: 19ms

---------- Narrow 3000 3 ----------
ItemDescendents: 100ms
ItemDescendentsFlat: 24ms
ItemDescendentsFlat2: 75948ms
ItemDescendentsFlat3: 1004ms
ItemDescendentsFlat4: 1ms
CalculatePrice: 69ms
Y: 69ms
GetSum: 915ms
GetDescendants: 0ms

While premature optimization is bad, it is important to know an algorithm's asymptotic behaviour, because that determines whether it will scale or die.

Here is the code:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApp3
{
    class Program
    {
        public class Test
        {
            public static IEnumerable<Item> ItemDescendents(IEnumerable<Item> src, int parent_id)
            {
                foreach (var item in src.Where(i => i.Id_Parent == parent_id))
                {
                    yield return item;
                    foreach (var itemd in ItemDescendents(src, item.Id))
                        yield return itemd;
                }
            }

            public static IEnumerable<Item> ItemDescendentsFlat(IEnumerable<Item> src, int parent_id)
            {
                void PushRange<T>(Stack<T> s, IEnumerable<T> Ts)
                {
                    foreach (var aT in Ts)
                        s.Push(aT);
                }

                var itemStack = new Stack<Item>(src.Where(i => i.Id_Parent == parent_id));

                while (itemStack.Count > 0)
                {
                    var item = itemStack.Pop();
                    PushRange(itemStack, src.Where(i => i.Id_Parent == item.Id));
                    yield return item;
                }
            }

            public IEnumerable<Item> ItemDescendantsFlat2(IEnumerable<Item> src, int parent_id)
            {
                var children = src.Where(s => s.Id_Parent == parent_id);
                do
                {
                    foreach (var c in children)
                        yield return c;
                    children = children.SelectMany(c => src.Where(i => i.Id_Parent == c.Id));
                } while (children.Count() > 0);
            }

            public IEnumerable<Item> ItemDescendantsFlat3(IEnumerable<Item> src, int parent_id)
            {
                var childItems = src.ToLookup(i => i.Id_Parent);

                var children = childItems[parent_id];
                do
                {
                    foreach (var c in children)
                        yield return c;
                    children = children.SelectMany(c => childItems[c.Id]);
                } while (children.Count() > 0);
            }

            public IEnumerable<Item> ItemDescendantsFlat4(IEnumerable<Item> src, int parent_id)
            {
                var childItems = src.ToLookup(i => i.Id_Parent);

                var stackOfChildren = new Stack<IEnumerable<Item>>();
                stackOfChildren.Push(childItems[parent_id]);
                do
                    foreach (var c in stackOfChildren.Pop())
                    {
                        yield return c;
                        stackOfChildren.Push(childItems[c.Id]);
                    }
                while (stackOfChildren.Count > 0);
            }

            public static int GetSum(IEnumerable<Item> items, int id)
            {
                // add all matching items
                var itemsToSum = items.Where(i => i.Id == id).ToList();
                var oldCount = 0;
                var currentCount = itemsToSum.Count();
                // if nothing was added, we skip the while loop
                while (currentCount != oldCount)
                {
                    oldCount = currentCount;
                    // find all matching items except the ones already in the list
                    var matchedItems = items
                        .Join(itemsToSum, item => item.Id_Parent, sum => sum.Id, (item, sum) => item)
                        .Except(itemsToSum)
                        .ToList();
                    itemsToSum.AddRange(matchedItems);
                    currentCount = itemsToSum.Count;
                }

                return itemsToSum.Sum(i => i.Price);
            }

            /// <summary>
            /// Implements a recursive function that takes a single parameter
            /// </summary>
            /// <typeparam name="T">The Type of the Func parameter</typeparam>
            /// <typeparam name="TResult">The Type of the value returned by the recursive function</typeparam>
            /// <param name="f">The function that returns the recursive Func to execute</param>
            /// <returns>The recursive Func with the given code</returns>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static Func<T, TResult> Y<T, TResult>(Func<Func<T, TResult>, Func<T, TResult>> f)
            {
                Func<T, TResult> g = null;
                g = f(a => g(a));
                return g;
            }


            public IEnumerable<Item> GetDescendants(IEnumerable<Item> items, int key)
            {
                var lookup = items.ToLookup(i => i.Id_Parent);
                Stack<Item> st = new Stack<Item>(lookup[key]);

                while (st.Count > 0)
                {
                    var item = st.Pop();
                    yield return item;
                    foreach (var i in lookup[item.Id])
                    {
                        st.Push(i);
                    }
                }
            }

            public class Item
            {
                public int Id;
                public int Price;
                public int Id_Parent;
            }

            protected Item[] getItems(int count, bool wide)
            {
                Item[] Items = new Item[count];
                for (int i = 0; i < count; ++i)
                {
                    Item ix = new Item()
                    {
                        Id = i,
                        Id_Parent = wide ? i / 2 : i - 1,
                        Price = i % 17
                    };
                    Items[i] = ix;
                }
                return Items;
            }

            public void test()
            {
                Item[] items = null;

                int CalculatePrice(int id)
                {
                    int price = items.Where(item => item.Id_Parent == id).Sum(child => CalculatePrice(child.Id));
                    return price + items.First(item => item.Id == id).Price;
                }

                var functions = new List<(Func<Item[], int, int>, string)>() {
                ((it, key) => ItemDescendents(it, key).Sum(i => i.Price), "ItemDescendents"),
                ((it, key) => ItemDescendentsFlat(it, key).Sum(i => i.Price), "ItemDescendentsFlat"),
                ((it, key) => ItemDescendantsFlat2(it, key).Sum(i => i.Price), "ItemDescendentsFlat2"),
                ((it, key) => ItemDescendantsFlat3(it, key).Sum(i => i.Price), "ItemDescendentsFlat3"),
                ((it, key) => ItemDescendantsFlat4(it, key).Sum(i => i.Price), "ItemDescendentsFlat4"),
                ((it, key) => CalculatePrice(key), "CalculatePrice"),
                ((it, key) => Y<int, int>(x => y =>
                {
                    int price = it.Where(item => item.Id_Parent == y).Sum(child => x(child.Id));
                    return price + it.First(item => item.Id == y).Price;
                })(key), "Y"),
                ((it, key) => GetSum(it, key), "GetSum"),
                ((it, key) => GetDescendants(it, key).Sum(i => i.Price), "GetDescendants" )                 
                };

                System.Diagnostics.Stopwatch st = new System.Diagnostics.Stopwatch();

                var testSetup = new[]
                {
                    new { Count = 10, Wide = true, Key=3}, //warmup
                    new { Count = 100000, Wide = true, Key=3},
                    new { Count = 3000, Wide = false, Key=3}
                };

                List<int> sums = new List<int>();
                foreach (var setup in testSetup)
                {
                    items = getItems(setup.Count, setup.Wide);
                    Console.WriteLine("---------- " + (setup.Wide ? "Wide" : "Narrow")
                        + " " + setup.Count + " " + setup.Key + " ----------");
                    foreach (var func in functions)
                    {
                        st.Restart();
                        sums.Add(func.Item1(items, setup.Key));
                        st.Stop();
                        Console.WriteLine(func.Item2 + ": " + st.ElapsedMilliseconds + "ms");
                    }
                    Console.WriteLine();
                    Console.WriteLine("checks: " + string.Join(", ", sums));
                    sums.Clear();
                }

                Console.WriteLine("---------- END ----------");

            }
        }

        static void Main(string[] args)
        {
            Test t = new Test();
            t.test();
        }
    }
}

Upvotes: 0

NetMage

Reputation: 26907

Create a recursive function to find all the children of a parent:

public static IEnumerable<Item> ItemDescendents(IEnumerable<Item> src, int parent_id) {
    foreach (var item in src.Where(i => i.Id_Parent == parent_id)) {
        yield return item;
        foreach (var itemd in ItemDescendents(src, item.Id))
            yield return itemd;
    }
}

Now you can get the price for any parent:

var price1 = ItemDescendents(Items, 1).Sum(i => i.Price);

Note that if you know the children of an item always have a greater Id than their parent, you don't need recursion:

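// Collect the descendants of Id 1 in one pass; the OrderBy guarantees parents are visited before their children.
var descendents = Items.OrderBy(i => i.Id).Aggregate(new List<Item>(), (ans, i) => {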
    if (i.Id_Parent == 1 || ans.Select(a => a.Id).Contains(i.Id_Parent))
        ans.Add(i);
    return ans;
});
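A minimal sketch of the same single-pass idea that also sums the prices, under the same assumption (every child has a larger Id than its parent). The OrderedDescendants name is hypothetical; the main change is a HashSet<int> for the "is the parent already in the result?" test, which avoids rescanning the answer list for every item:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Item
{
    public int Id { get; set; }
    public int Id_Parent { get; set; }
    public int Price { get; set; }
}

public static class OrderedDescendants
{
    // Assumption: every child has a larger Id than its parent, so after
    // sorting by Id a parent is always visited before any of its children.
    public static int SumDescendantPrices(IEnumerable<Item> items, int rootId)
    {
        var inSubtree = new HashSet<int> { rootId }; // root plus descendants found so far
        var total = 0;

        foreach (var item in items.OrderBy(i => i.Id))
        {
            // O(1) membership test instead of ans.Select(a => a.Id).Contains(...)
            if (item.Id != rootId && inSubtree.Contains(item.Id_Parent))
            {
                inSubtree.Add(item.Id);
                total += item.Price; // the root's own Price is not included
            }
        }

        return total;
    }
}

public static class Program
{
    public static void Main()
    {
        var items = new List<Item>
        {
            new Item { Id = 1 },
            new Item { Id = 3, Id_Parent = 1 },
            new Item { Id = 5, Id_Parent = 3, Price = 10 },
            new Item { Id = 9, Id_Parent = 3, Price = 11 },
            new Item { Id = 100 },
            new Item { Id = 112, Id_Parent = 100, Price = 55 }
        };

        Console.WriteLine(OrderedDescendants.SumDescendantPrices(items, 1));   // 21
        Console.WriteLine(OrderedDescendants.SumDescendantPrices(items, 100)); // 55
    }
}
```

If the ordering assumption doesn't hold, a child seen before its parent is silently skipped, so this is only safe when ids are assigned parent-first.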

For those that prefer to avoid recursion, you can use an explicit stack instead:

public static IEnumerable<Item> ItemDescendentsFlat(IEnumerable<Item> src, int parent_id) {
    void PushRange<T>(Stack<T> s, IEnumerable<T> Ts) {
        foreach (var aT in Ts)
            s.Push(aT);
    }

    var itemStack = new Stack<Item>(src.Where(i => i.Id_Parent == parent_id));

    while (itemStack.Count > 0) {
        var item = itemStack.Pop();
        PushRange(itemStack, src.Where(i => i.Id_Parent == item.Id));
        yield return item;
    }
}

I included the PushRange helper function since Stack<T> doesn't have one.

Finally, here is a variation that doesn't use any stack, implicit or explicit.

public IEnumerable<Item> ItemDescendantsFlat2(IEnumerable<Item> src, int parent_id) {
    var children = src.Where(s => s.Id_Parent == parent_id);
    do {
        foreach (var c in children)
            yield return c;
        children = children.SelectMany(c => src.Where(i => i.Id_Parent == c.Id)).ToList();
    } while (children.Count() > 0);
}

You can replace the multiple traversals of the source with a Lookup as well:

public IEnumerable<Item> ItemDescendantsFlat3(IEnumerable<Item> src, int parent_id) {
    var childItems = src.ToLookup(i => i.Id_Parent);

    var children = childItems[parent_id];
    do {
        foreach (var c in children)
            yield return c;
        children = children.SelectMany(c => childItems[c.Id]).ToList();
    } while (children.Count() > 0);
}
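What makes the Lookup versions work without any special case for leaves is that indexing an ILookup with a key that has no elements returns an empty sequence rather than throwing. A small demonstration (the LookupDemo and ChildrenByParent names are just for illustration):

```csharp
using System;
using System.Linq;

public class Item
{
    public int Id;
    public int Id_Parent;
    public int Price;
}

public static class LookupDemo
{
    // Groups the items by parent id once; lookup[p] then yields p's direct children.
    public static ILookup<int, Item> ChildrenByParent(Item[] items) =>
        items.ToLookup(i => i.Id_Parent);

    public static void Main()
    {
        var items = new[]
        {
            new Item { Id = 3, Id_Parent = 1 },
            new Item { Id = 5, Id_Parent = 3, Price = 10 },
            new Item { Id = 9, Id_Parent = 3, Price = 11 }
        };

        var children = ChildrenByParent(items);
        Console.WriteLine(children[3].Count());  // 2
        Console.WriteLine(children[5].Count());  // 0: missing key gives an empty sequence
        Console.WriteLine(children.Contains(5)); // False
    }
}
```

Building the lookup costs one pass over the source, after which each childItems[c.Id] access is a hash lookup instead of a full scan.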

Based on the comments about too much nested enumeration, I optimized the two versions above by adding ToList, which improved performance vastly. I was also inspired to try removing SelectMany, which can be slow, and instead to collect the IEnumerables on a stack, as I've seen suggested elsewhere for optimizing Concat:

public IEnumerable<Item> ItemDescendantsFlat4(IEnumerable<Item> src, int parent_id) {
    var childItems = src.ToLookup(i => i.Id_Parent);

    var stackOfChildren = new Stack<IEnumerable<Item>>();
    stackOfChildren.Push(childItems[parent_id]);
    do
        foreach (var c in stackOfChildren.Pop()) {
            yield return c;
            stackOfChildren.Push(childItems[c.Id]);
        }
    while (stackOfChildren.Count > 0);
}

@AntonínLejsek's GetDescendants is still the fastest, though it is very close now; sometimes simpler wins out for performance.
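The functions above answer one id at a time. Since the question asks for the sum under every parent, here is a minimal sketch that computes all of them in one pass with a Lookup and an explicit stack (an iterative post-order walk, so no recursion depth issues). It assumes unique ids and no cycles, and treats any item whose Id_Parent matches no Id as a root; the AllSums and DescendantSums names are hypothetical:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Item
{
    public int Id { get; set; }
    public int Id_Parent { get; set; }
    public int Price { get; set; }
}

public static class AllSums
{
    // For every item reachable from a root, returns the total Price of its
    // descendants (the item's own Price excluded).
    public static Dictionary<int, int> DescendantSums(IReadOnlyCollection<Item> items)
    {
        var children = items.ToLookup(i => i.Id_Parent);
        var ids = new HashSet<int>(items.Select(i => i.Id));
        var sums = new Dictionary<int, int>();
        var stack = new Stack<(Item Node, bool ChildrenDone)>();

        foreach (var root in items.Where(i => !ids.Contains(i.Id_Parent)))
        {
            stack.Push((root, false));
            while (stack.Count > 0)
            {
                var (node, childrenDone) = stack.Pop();
                if (!childrenDone)
                {
                    // revisit this node only after all of its children are finished
                    stack.Push((node, true));
                    foreach (var child in children[node.Id])
                        stack.Push((child, false));
                }
                else
                {
                    // each child's own Price plus everything below it
                    sums[node.Id] = children[node.Id].Sum(c => c.Price + sums[c.Id]);
                }
            }
        }

        return sums;
    }
}

public static class Program
{
    public static void Main()
    {
        var items = new List<Item>
        {
            new Item { Id = 1 },
            new Item { Id = 3, Id_Parent = 1 },
            new Item { Id = 5, Id_Parent = 3, Price = 10 },
            new Item { Id = 9, Id_Parent = 3, Price = 11 },
            new Item { Id = 100 },
            new Item { Id = 112, Id_Parent = 100, Price = 55 }
        };

        var sums = AllSums.DescendantSums(items);
        foreach (var item in items)
            Console.WriteLine($"{item.Id}: {sums[item.Id]}");
    }
}
```

Each item is pushed twice and each child list is enumerated twice, so the work stays linear in the number of items.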

Upvotes: 6

Sergio0694

Reputation: 4567

The easy way would be to use a local function, like this:

int CalculatePrice(int id)
{
    int price = Items.Where(item => item.Id_Parent == id).Sum(child => CalculatePrice(child.Id));
    return price + Items.First(item => item.Id == id).Price;
}
int total = CalculatePrice(3); // 3 is just an example id
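CalculatePrice scans the whole list with Where and First on every call. A minimal sketch of the same recursion (own Price plus all descendants) that builds a lookup and a dictionary once instead; the PriceCalculator name is hypothetical, and it assumes unique ids (ToDictionary throws on duplicates):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Item
{
    public int Id { get; set; }
    public int Id_Parent { get; set; }
    public int Price { get; set; }
}

public static class PriceCalculator
{
    // Same recursion as the local CalculatePrice above, but children come from
    // a lookup built once instead of a full scan per call.
    public static int CalculatePrice(IEnumerable<Item> items, int id)
    {
        var list = items.ToList();
        var childrenOf = list.ToLookup(i => i.Id_Parent);
        var byId = list.ToDictionary(i => i.Id); // throws on duplicate ids

        // byId[current] throws KeyNotFoundException for an unknown id,
        // much like First() throws in the original.
        int Calculate(int current) =>
            byId[current].Price + childrenOf[current].Sum(child => Calculate(child.Id));

        return Calculate(id); // still recursive: very deep trees can overflow the stack
    }
}

public static class Program
{
    public static void Main()
    {
        var items = new List<Item>
        {
            new Item { Id = 1 },
            new Item { Id = 3, Id_Parent = 1 },
            new Item { Id = 5, Id_Parent = 3, Price = 10 },
            new Item { Id = 9, Id_Parent = 3, Price = 11 },
            new Item { Id = 100 },
            new Item { Id = 112, Id_Parent = 100, Price = 55 }
        };

        Console.WriteLine(PriceCalculator.CalculatePrice(items, 3));   // 21
        Console.WriteLine(PriceCalculator.CalculatePrice(items, 1));   // 21
        Console.WriteLine(PriceCalculator.CalculatePrice(items, 100)); // 55
    }
}
```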

Another, cleaner solution would be to use the Y combinator to create a closure that can be called inline. Assuming you have this helper:

using System;
using System.Runtime.CompilerServices;

/// <summary>
/// Implements a recursive function that takes a single parameter
/// </summary>
/// <typeparam name="T">The Type of the Func parameter</typeparam>
/// <typeparam name="TResult">The Type of the value returned by the recursive function</typeparam>
/// <param name="f">The function that returns the recursive Func to execute</param>
/// <returns>The recursive Func with the given code</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Func<T, TResult> Y<T, TResult>(Func<Func<T, TResult>, Func<T, TResult>> f)
{
    Func<T, TResult> g = null;
    g = f(a => g(a));
    return g;
}

Then you can just get your result like so:

int total = Y<int, int>(x => y =>
{
    int price = Items.Where(item => item.Id_Parent == y).Sum(child => x(child.Id));
    return price + Items.First(item => item.Id == y).Price;
})(3);

What's nice about this is that it lets you quickly declare and call a recursive function in a functional fashion, which is especially handy in situations like this one, where you only need a "throwaway" function that you'll use just once. Also, since this function is quite small, using the Y combinator further reduces the boilerplate of declaring a local function and calling it on another line.

Upvotes: 1
