lukas kiss

Reputation: 392

Haskell State monad vs state as parameter performance test

I have started learning the State monad, and one idea bothers me: instead of passing an accumulator as a parameter, we can wrap everything in the State monad.

So I wanted to compare the performance of using the State monad against passing the accumulator as a parameter.

So I created two functions:

sum1 :: Int -> [Int] -> Int
sum1 x [] = x
sum1 x (y:xs) = sum1 (x + y) xs

and

import Control.Monad.State (execState, modify)

sumState :: [Int] -> Int
sumState xs = execState (traverse f xs) 0
    where f n = modify (n+)

I compared them on the input list [1..1000000000].

There is a clear winner, but then I realised that sumState can be optimised in two ways:

  1. We can use the strict version of modify (modify')
  2. We do not need the list that traverse returns, so we can use traverse_ instead

So the new optimised state function is:

import Control.Monad.State (execState, modify')
import Data.Foldable (traverse_)

sumState :: [Int] -> Int
sumState xs = execState (traverse_ f xs) 0
    where f n = modify' (n+)

which has a running time of around 350 ms. This is a huge improvement, and it surprised me.

Why does the modified sumState perform better than sum1? Can sum1 be optimised to match or even beat sumState?

I also tried several other implementations of sum.

Does this mean that it is better to use the foldl function or the State monad to carry an accumulator than to pass it as an argument to the function?

Thank you for help.

EDIT:

Each function was in a separate file with its own main function and was compiled with the "-O2" flag.

import System.Environment (getArgs)

main :: IO ()
main = do
    x <- (read . head) <$> getArgs
    print $ <particular sum function> [1..x]

Runtime was measured with the time command on Linux.

Upvotes: 0

Views: 501

Answers (2)

Mark Saving

Reputation: 1787

Ironically, the same problem that plagued your implementation of sumState is also the problem with sum1. You don't have strict accumulation, so you build up thunks like so:

sum1 0 [1, 2, 3]
sum1 (0 + 1) [2, 3]
sum1 ((0 + 1) + 2) [3]
sum1 (((0 + 1) + 2) + 3) []
(((0 + 1) + 2) + 3)
((1 + 2) + 3)
(3 + 3)
6

If you add strictness to sum1, you should see a dramatic improvement in efficiency because you eliminate the non-tail-recursive evaluation of the thunk (((0 + 1) + 2) + 3), which is the costly part of sum1. Using strict accumulation makes this much more efficient:

sum1 :: Int -> [Int] -> Int
sum1 x [] = x
-- seq forces the accumulator before each recursive call
sum1 x (y : xs) = x `seq` sum1 (x + y) xs

This should give you performance comparable to sum (although, as noted in another answer, GHC may not be able to use fusion properly to give you the truly magical performance of sum on the list [1..x]).
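Two other common ways to get the same strict accumulation are sketched below. The names sumBang and sumFoldl' are made up for illustration; the first assumes the BangPatterns extension, the second uses foldl' from Data.List. This is a sketch, not a benchmarked result.

```haskell
{-# LANGUAGE BangPatterns #-}

import Data.List (foldl')

-- Bang pattern on the accumulator: it is evaluated to weak head normal
-- form on every call, so no chain of (+) thunks is built up.
sumBang :: Int -> [Int] -> Int
sumBang !acc []       = acc
sumBang !acc (y : ys) = sumBang (acc + y) ys

-- The same loop expressed with the strict left fold from Data.List.
sumFoldl' :: [Int] -> Int
sumFoldl' = foldl' (+) 0
```

Both return the same result as the seq version; which one compiles to the fastest code may depend on the GHC version and flags.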

Upvotes: 0

Noughtmare

Reputation: 10645

To give a bit more explanation as to why traverse is slower: traverse f xs has type State Int [()], and that [()] (a list of unit values) is built up during the summation. This prevents further optimizations and would cause a memory leak if you were not using lazy state.

Update: I think GHC should have been able to notice that that list of unit tuples is never used, so I opened a GHC issue.
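The difference in result types can be made visible with a small sketch that runs both traversals and keeps the result alongside the final state. The helper names step, withTraverse and withTraverse_ are hypothetical, and the imports assume the mtl package's Control.Monad.State.

```haskell
import Control.Monad.State (State, modify', runState)
import Data.Foldable (traverse_)

-- Adds n to the running total held in the state.
step :: Int -> State Int ()
step n = modify' (n +)

-- traverse keeps one () per element, so the result is a list of units.
withTraverse :: [Int] -> ([()], Int)
withTraverse xs = runState (traverse step xs) 0

-- traverse_ discards the per-element results and returns a single ().
withTraverse_ :: [Int] -> ((), Int)
withTraverse_ xs = runState (traverse_ step xs) 0
```

For the input [1, 2, 3], withTraverse yields ([(), (), ()], 6) while withTraverse_ yields ((), 6): the sum is the same, but only the first has to produce a list as long as the input.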

In both cases, to get the best performance we want to combine (or fuse) the summation with the enumeration [1..x] into a tight recursive loop which simply increments and adds until it reaches x. The resulting code would look something like this:

sumFromTo :: Int -> Int -> Int -> Int
sumFromTo s x y
  | x == y = s + x
  | otherwise = sumFromTo (s + x) (x + 1) y

This avoids allocations for the list [1..x].
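One caveat: the loop above only terminates when x <= y, whereas [1..x] is simply empty for x < 1. A variant that also handles the empty range might look like the sketch below; the name sumRange and the bang pattern on the accumulator are my additions, not something GHC is claimed to generate.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Sums the integers from lo to hi inclusive on top of the accumulator acc.
-- An empty range (lo > hi) returns acc unchanged.
sumRange :: Int -> Int -> Int -> Int
sumRange !acc lo hi
  | lo > hi   = acc
  | otherwise = sumRange (acc + lo) (lo + 1) hi
```

Calling sumRange 0 1 x then plays the role of summing [1..x] without ever building the list.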

The base library achieves this optimization using foldr/build fusion, also known as short cut fusion. The sum, foldl' and traverse (for lists) functions are implemented using the foldr function, and [1..x] is implemented using the build function. The foldr and build functions have special optimization rules so that they can be fused. Your custom sum1 function doesn't use foldr, so it can never be fused with [1..x] in this way.
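To illustrate how a left-to-right accumulating sum can be written with foldr at all, here is a minimal sketch. The name sumViaFoldr is hypothetical, and this is not a copy of the base library's definition. Because the list is consumed by foldr, it is at least a candidate for fusion with a build-based producer such as [1..x]; whether GHC actually fuses it depends on the version and optimization flags.

```haskell
-- A strict left fold written in terms of foldr: each element becomes a
-- function that takes the accumulator, adds to it, and passes it on.
sumViaFoldr :: [Int] -> Int
sumViaFoldr xs = foldr step id xs 0
  where
    -- ($!) forces the new accumulator before handing it to the rest.
    step x k acc = k $! acc + x
```

The accumulator is threaded through the continuation k, so the additions still happen from left to right even though the list is consumed by foldr.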

Upvotes: 2
