Reputation: 67
I am working on Project Euler #14 and have a solution that produces the answer, but I get a stack space overflow error when I try to run the code. The algorithm works OK in interactive GHCi (on low numbers), but won't work when I throw a really big number at it and try to compile it.
Here is a rough idea of what it does in interactive GHCi. It takes about 10 seconds to calculate "answer 50000" on my computer.
*Euler System.IO> answer 10
(20,9)
*Euler System.IO> answer 100
(119,97)
*Euler System.IO> answer 1000
(179,871)
*Euler System.IO> answer 10000
(262,6171)
*Euler System.IO> answer 50000
(324,35655)
After letting GHCi run the problem for a few minutes, it spits out the correct answer.
*Euler System.IO> answer 1000000
(525,837799)
But that doesn't solve the stack overflow error when compiling the program to run natively.
What should I do to get the answer for "answer 1000000"? I imagine my algorithm needs to be fine-tuned a bit, but I have no idea how to go about doing that.
Code:
module Main
    where

import System.IO
import Control.Monad

main = print (answer 1000000)

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total
-- length of the chain
count' n = (cSeq n, n)
cSeq n = length $ game n

-- Find the maximum chain value of the game
answer n = maximum $ map count' [1..n]

-- Working game.
-- game 13 = [13,40,20,10,5,16,8,4,2,1]
game n = n : play n
play x
    | x <= 0 = [] -- is negative or 0
    | x == 1 = [] -- is 1
    | even x = doEven x : play ((doEven x)) -- even
    | otherwise = doOdd x : play ((doOdd x)) -- odd
    where doOdd x = (3 * x) + 1
          doEven x = (x `div` 2)
Upvotes: 2
Views: 190
Reputation: 183858
@hammar already pointed out the problem that maximum is too lazy, and how to resolve that (using foldl1', the strict version of foldl1).
But there are further inefficiencies in the code.
cSeq n = length $ game n
cSeq lets game construct a list, only to calculate its length. Unfortunately, length is not a "good consumer", so the construction of the intermediate list is not fused away. That's quite a bit of unnecessary allocation and costs time. Eliminating these lists
cSeq n = coll (1 :: Int) n
    where
      coll acc 1 = acc
      coll acc m
          | even m    = coll (acc + 1) (m `div` 2)
          | otherwise = coll (acc + 1) (3*m+1)
cuts down the allocation by something like 65% and the running time by about 20% (still slow). Next point, you're using div, which performs a sign check in addition to the normal division. Since all numbers involved are positive, using quot instead does speed it up a bit more (not much here, but it will become important later).
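To make the difference between the two concrete, here is a small sketch (the names roundingDemo and agreeOnPositives are made up for illustration). div rounds toward negative infinity, quot truncates toward zero, and the two agree whenever both operands are positive, which is why the swap is safe for this problem.

```haskell
-- div rounds toward negative infinity; quot truncates toward zero.
-- For 7 the pair is (3, 3); for -7 it is (-4, -3).
roundingDemo :: [(Int, Int)]
roundingDemo = [ (n `div` 2, n `quot` 2) | n <- [7, -7] ]

-- For positive operands both functions give the same result.
agreeOnPositives :: Bool
agreeOnPositives = and [ n `div` 2 == n `quot` 2 | n <- [1 .. 1000 :: Int] ]
```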
The next big point is that, since you haven't given type signatures, the type of the numbers (except where it was determined by the use of length or by the expression type signature (1 :: Int) in my rewrite) is Integer. The operations on Integer are considerably slower than the corresponding operations on Int, so if possible, you should use Int (or Word) rather than Integer when speed matters. If you have a 64-bit GHC, Int is sufficient for these computations. With the native code generator, that reduces the running time by about half when using div and by about 70% when using quot; with the LLVM backend, the running time is reduced by about 70% when using div and by about 95% when using quot.
The difference between the native code generator and the LLVM backend is mostly due to some elementary low-level optimisations.
even and odd are defined as
even, odd :: (Integral a) => a -> Bool
even n = n `rem` 2 == 0
odd = not . even
in GHC.Real. When the type is Int, LLVM knows to replace the division by 2 used to determine the modulus with a bitwise and (n .&. 1 == 0). The native code generator does not (yet) do many of these low-level optimisations. If you do that by hand, the code produced by the NCG and the LLVM backend performs nearly identically.
When using div, neither the NCG nor LLVM is able to replace the division with a short shift-and-add sequence, so you get the relatively slow machine division instruction with the sign test. With quot, both are able to do that for Int, so you get much faster code.
The knowledge that all occurring numbers are positive allows us to replace the division by 2 with a simple right shift, without any code to correct for negative arguments. That speeds up the code produced by the LLVM backend by another ~33%; oddly, it doesn't make a difference for the NCG.
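As a quick sanity check of the hand-written replacements, the sketch below compares them with the library functions. The names half, agreesOnNonNegatives and differsOnNegative are hypothetical helpers for this illustration; even' matches the definition used in the program further down.

```haskell
import Data.Bits (shiftR, (.&.))

-- Bit test for evenness; agrees with `even` on Int.
even' :: Int -> Bool
even' m = m .&. 1 == 0

-- Halving by shifting; equals `quot` 2 only for non-negative values.
half :: Int -> Int
half m = m `shiftR` 1

-- Both replacements match the library functions on non-negative input.
agreesOnNonNegatives :: Bool
agreesOnNonNegatives =
  and [ even' m == even m && half m == m `quot` 2 | m <- [0 .. 1000] ]

-- For a negative odd value the arithmetic shift rounds down (-4) while
-- quot truncates toward zero (-3), so a compiler that cannot assume a
-- non-negative argument has to emit correction code.
differsOnNegative :: Bool
differsOnNegative = half (-7) /= (-7) `quot` 2
```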
So from the original that took eight seconds, plus or minus a bit (a little less with the NCG, a little more with the LLVM backend), we've gone to
{-# LANGUAGE BangPatterns #-}
module Main (main)
    where

import Data.List
import Data.Bits

main = print (answer (1000000 :: Int))

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total
-- length of the chain
count' n = (cSeq n, n)

cSeq n = go (1 :: Int) n
    where
      go !acc 1 = acc
      go acc m
          | even' m   = go (acc+1) (m `shiftR` 1)
          | otherwise = go (acc+1) (3*m+1)
      even' :: Int -> Bool
      even' m = m .&. 1 == 0

-- Find the maximum chain value of the game
answer n = foldl1' max $ map count' [1..n]
which takes 0.37 seconds with the NCG, and 0.27 seconds with the LLVM backend on my setup.
A minute improvement in running time, but a huge reduction of allocation, can be obtained by replacing the foldl1' max with a manual recursion,
answer n = go 1 1 2
    where
      go ml mi i
          | n < i     = (ml,mi)
          | l > ml    = go l i (i+1)
          | otherwise = go ml mi (i+1)
          where
            l = cSeq i
that makes it 0.35 and 0.25 seconds respectively (and produces a tiny 52,936 bytes allocated in the heap).
Now if that is still too slow, you can worry about a good memoisation strategy. The best I know(1) is to use an unboxed array to store the chain lengths for the numbers not exceeding the limit,
{-# LANGUAGE BangPatterns #-}
module Main (main) where

import System.Environment (getArgs)
import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits

main :: IO ()
main = do
    args <- getArgs
    let bd = case args of
                a:_ -> read a
                _   -> 100000
    print $ mxColl bd

mxColl :: Int -> (Int,Int)
mxColl bd = runST $ do
    arr <- newArray (0,bd) 0
    unsafeWrite arr 1 1
    goColl arr bd 1 1 2

goColl :: STUArray s Int Int -> Int -> Int -> Int -> Int -> ST s (Int,Int)
goColl arr bd ms ml i
    | bd < i    = return (ms,ml)
    | otherwise = do
        nln <- collatzLength arr bd i
        if ml < nln
          then goColl arr bd i nln (i+1)
          else goColl arr bd ms ml (i+1)

collatzLength :: STUArray s Int Int -> Int -> Int -> ST s Int
collatzLength arr bd n = go 1 n
  where
    go !l 1 = return l
    go l m
        | bd < m    = go (l+1) $ case m .&. 1 of
                                   0 -> m `shiftR` 1
                                   _ -> 3*m+1
        | otherwise = do
            l' <- unsafeRead arr m
            case l' of
              0 -> do
                  l'' <- go 1 $ case m .&. 1 of
                                  0 -> m `shiftR` 1
                                  _ -> 3*m+1
                  unsafeWrite arr m (l''+1)
                  return (l + l'')
              _ -> return (l+l'-1)
which does the job for a limit of 1000000 in 0.04 seconds when compiled with the NCG, 0.05 with the LLVM backend (apparently, that is not as good at optimising STUArray code as the NCG is).
If you don't have a 64-bit GHC, you can't simply use Int, since that would then overflow for some inputs.
But the overwhelming part of the computation is still performed in Int range, so you should use that where possible and only move to Integer where required.
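To see which inputs are affected, one can compute the largest value each chain reaches. The following is only a sketch, assuming starting values of at least 1; peak, maxInt32 and overflows32 are names invented for this illustration, and the arithmetic is done in Integer so the check itself cannot overflow.

```haskell
-- Largest value reached by the chain starting at n (n >= 1 assumed).
peak :: Integer -> Integer
peak n = maximum (takeWhile (/= 1) (iterate step n) ++ [1])
  where
    step m
      | even m    = m `quot` 2
      | otherwise = 3 * m + 1

-- Upper bound of a signed 32-bit Int.
maxInt32 :: Integer
maxInt32 = 2 ^ (31 :: Int) - 1

-- True if the chain starting at n leaves the signed 32-bit range.
overflows32 :: Integer -> Bool
overflows32 n = peak n > maxInt32
```

Evaluating filter overflows32 [1 .. 1000000] then lists the starting values that a 32-bit Int could not handle.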
switch :: Int
switch = (maxBound - 1) `quot` 3

back :: Integer
back = 2 * fromIntegral (maxBound :: Int)

cSeq :: Int -> Int
cSeq n = goInt 1 n
  where
    goInt acc 1 = acc
    goInt acc m
        | m .&. 1 == 0 = goInt (acc+1) (m `shiftR` 1)
        | m > switch   = goInteger (acc+1) (3*toInteger m + 1)
        | otherwise    = goInt (acc+1) (3*m+1)
    goInteger acc m
        | fromInteger m .&. (1 :: Int) == 1 = goInteger (acc+1) (3*m+1)
        | m > back  = goInteger (acc+1) (m `quot` 2) -- yup, quot is faster than shift for Integer here
        | otherwise = goInt (acc + 1) (fromInteger $ m `quot` 2)
This makes it harder to optimise the loop(s), so it is slower than the single loop using Int, but still decent. Here (where the Integer loop is never run), it takes 0.42 seconds with the NCG and 0.37 with the LLVM backend (which is pretty much the same as using quot in the pure Int version).
Using a similar trick for the memoised version has similar consequences: it's considerably slower than the pure Int version, but still blazingly fast compared to the unmemoised versions.
(1) For this special (type of) problem, where you need to memoise the results for a contiguous range of arguments. For other problems, a Map or some other data structure will be the better choice.
Upvotes: 4
Reputation: 9891
It seems that the maximum function is the culprit, as already pointed out, but you shouldn't have to worry about it if you compile your program with the -O2 flag.
The program is still quite slow; this is because the problem is supposed to teach you about memoization. One good way of doing this in Haskell is by using Data.MemoCombinators:
import Data.MemoCombinators
import Control.Arrow
import Data.List
import Data.Ord
import System.Environment

play m = maximumBy (comparing snd) . map (second threeNPuzzle) $ zip [1..] [1..m]
    where
      threeNPuzzle = arrayRange (1,m) memoized
      memoized n
          | n == 1 = 1
          | odd n  = 1 + threeNPuzzle (3*n + 1)
          | even n = 1 + threeNPuzzle (n `div` 2)

main = getArgs >>= print . play . read . head
The above program runs in under a second when compiled with -O2 on my machine.
Note that in this case it is not a good idea to memoize all values found by threeNPuzzle; the program above memoizes only the ones up to the limit (1000000 in the problem).
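If you would rather not depend on an extra package, the same idea of memoizing only up to the limit can be sketched with a lazily filled boxed array from the array package that ships with GHC. This is an illustrative variant, not the code above: the names step, chainLengths and longestChain are made up here, and it assumes a 64-bit Int and a limit of at least 1.

```haskell
import Data.Array (Array, assocs, listArray, (!))
import Data.List (foldl1')

-- One Collatz step.
step :: Int -> Int
step n
  | even n    = n `quot` 2
  | otherwise = 3 * n + 1

-- Lazily filled table of chain lengths for the starts 1..limit.
chainLengths :: Int -> Array Int Int
chainLengths limit = table
  where
    table = listArray (1, limit) (map chain [1 .. limit])
    chain 1 = 1
    chain n = 1 + lengthFrom (step n)
    -- Values inside the table are looked up; values above the limit
    -- are followed step by step without being stored.
    lengthFrom m
      | m <= limit = table ! m
      | otherwise  = 1 + lengthFrom (step m)

-- (chain length, starting number) of the longest chain up to limit.
longestChain :: Int -> (Int, Int)
longestChain limit =
  foldl1' max [ (len, start) | (start, len) <- assocs (chainLengths limit) ]
```

The result has the same (length, start) shape as the answer function in the question, so longestChain 1000 can be compared directly with the GHCi output shown there.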
Upvotes: 0
Reputation: 139840
The problem here is that maximum is too lazy. Instead of keeping track of the largest element as it goes along, it builds up a huge tree of max thunks. This is because maximum is defined in terms of foldl, so the evaluation goes as follows:
maximum [1, 2, 3, 4, 5]
foldl max 1 [2, 3, 4, 5]
foldl max (max 1 2) [3, 4, 5]
foldl max (max (max 1 2) 3) [4, 5]
foldl max (max (max (max 1 2) 3) 4) [5]
foldl max (max (max (max (max 1 2) 3) 4) 5) []
max (max (max (max 1 2) 3) 4) 5 -- this expression will be huge for large lists
Trying to evaluate too many of these nested max calls causes a stack overflow.
The solution is to force it to evaluate these as it goes along by using the strict version foldl' (or, in this case, its cousin foldl1'). This prevents the max calls from building up by reducing them at each step:
foldl1' max [1, 2, 3, 4, 5]
foldl' max 1 [2, 3, 4, 5]
foldl' max 2 [3, 4, 5]
foldl' max 3 [4, 5]
foldl' max 4 [5]
foldl' max 5 []
5
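Applied to the program in the question, the change might look like the sketch below. It is a condensed rewrite rather than the original code: game is rebuilt with iterate instead of the play helper, and the Int signatures are an added assumption (they require a 64-bit Int for a limit of 1000000).

```haskell
import Data.List (foldl1')

-- Chain starting at n, e.g. game 13 = [13,40,20,10,5,16,8,4,2,1].
game :: Int -> [Int]
game n = takeWhile (/= 1) (iterate step n) ++ [1]
  where
    step x
      | even x    = x `div` 2
      | otherwise = 3 * x + 1

-- (chain length, starting number)
count' :: Int -> (Int, Int)
count' n = (length (game n), n)

-- Like the question's answer, but with the strict fold instead of maximum,
-- so the running maximum is reduced at every step.
answer :: Int -> (Int, Int)
answer n = foldl1' max (map count' [1 .. n])
```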
GHC can often solve these kinds of problems on its own if you compile with -O2, which (among other things) runs a strictness analysis of your program. However, I think it's good practice to write programs that don't need to rely on optimizations to work.
Note: After fixing this, the resulting program is still very slow. You might want to look into using memoization for this problem.
Upvotes: 4