user28080356

Reputation: 61

Why does this Haskell program have incorrect time complexity?

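{-# LANGUAGE DeriveFunctor #-}
import Control.Applicative (liftA2)
import Data.List (nub)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving (Show,Eq,Functor)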
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
    where multAll (Prob innerxs,p) = map (\(x,r) -> (x,p*r)) innerxs
instance Applicative Prob where 
    liftA2 fn (Prob x) (Prob y) = Prob [(fn a b,prob1 *prob2) |(a,prob1) <- x, (b,prob2) <- y]
    pure x = Prob [(x,1%1)]
instance Monad Prob where
    m >>= f = flatten (fmap f m)
makedie x = Prob (zip [1..x] (repeat (1%x)))
proboperate2 op a b= sumprobs (liftA2 op a b)
rerolldie op = proboperate2 op <*> id
sumprobs (Prob a) = Prob [(v,sum (map snd (filter ((==v) . fst) a))) |v <- indivalues]
    where indivalues = nub (map fst a)
survivedeath :: Integer -> Prob Integer -> Prob Bool 
survivedeath dc die = sumprobs (survivegiven (0,0) =<< die) 
    where 
        survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
        survivegiven (a,_) _ | a >= 3 = return False 
        survivegiven (_,a) _ | a >= 3 = return True 
        survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
        survivegiven (a,b) 20 = return True 
        survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
        survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)

survivedeath quickly starts to take a long time. As far as I can tell, sumprobs is O(N^2) and =<< is O(N). So if survivedeath recurses six times on a d20, it should take about 20^2 * 6 = 400 * 6 = 2400 operations, which should finish quickly. Why is that not the case?

Upvotes: 0

Views: 92

Answers (3)

Daniel Wagner

Reputation: 153172

One thing that decreases the runtime significantly is doing a little bit of work up front to collapse your 20 possible d20 outcomes to just four:

survivedeath dc die = sumprobs (go (0,0)) where
    go (a, b) | a >= 3 = return False
              | b >= 3 = return True
              | otherwise = do
      (da, db) <- outcome
      go (a+da, b+db)
    outcome = sumprobs $ do
      n <- die
      pure $ case n of
        1 -> (2, 0)
        20 -> (0, 3)
        _ | n >= dc -> (0, 1)
        _ -> (1, 0)

In quick ghci-based testing, this appears to give a ~10,000x speedup when computing results for every DC from 1 to 20 (250s for your version, 0.02s for mine).

A back-of-the-envelope calculation suggests this speedup jibes with the explanations given by other answers about why your version is slow. If each probabilistic event has 5x (=20/4) as many possible outcomes, then a depth-6 search through outcomes would have 5^6 = 15,625x as many outcomes to look through.
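To make the "collapse 20 faces to four" step concrete, here is a small standalone sketch (not part of the answer) that builds the four-outcome distribution for a single d20 roll. It assumes the same face-to-effect mapping as `outcome` above, uses `Data.Map.Strict.fromListWith` to add up weights in place of `sumprobs`, and the names `faceEffect` and `collapsedD20` are hypothetical.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

-- Hypothetical helper: the effect of one d20 face on the
-- (failures, successes) counters, same mapping as `outcome` above.
faceEffect :: Integer -> Integer -> (Integer, Integer)
faceEffect dc n = case n of
  1 -> (2, 0)             -- natural 1: two failures
  20 -> (0, 3)            -- natural 20: enough successes to survive
  _ | n >= dc -> (0, 1)   -- meets the DC: one success
    | otherwise -> (1, 0) -- below the DC: one failure

-- Hypothetical helper: collapse the 20 equally likely faces into at most
-- four weighted outcomes; fromListWith (+) adds the weights of equal keys.
collapsedD20 :: Integer -> Map.Map (Integer, Integer) Rational
collapsedD20 dc =
  Map.fromListWith (+) [(faceEffect dc n, 1 % 20) | n <- [1 .. 20]]
```

For dc = 10, faces 10 to 19 all map to (0,1), so that outcome gets weight 1/2; faces 2 to 9 give (1,0) with weight 2/5; faces 1 and 20 keep 1/20 each. Every later roll then branches 4 ways instead of 20.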

Upvotes: 3

Li-yao Xia

Reputation: 33519

Although sumprobs lets you keep your lists short, you are calling survivegiven for every possible die roll value.

In other words, if die is a D20,

survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)

makes 20 recursive calls. So you are still enumerating all dice roll sequences, of which there are somewhat fewer than 20^6.
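One way to see this growth without running the slow program is to mirror survivegiven's recursion with plain counts instead of distributions. This is a sketch assuming a d20 and the question's rules; `survivegivenCalls` is a hypothetical name, and it groups faces that lead to the same next state so that computing the count is itself cheap.

```haskell
-- Hypothetical helper: how many times the question's `survivegiven` is
-- entered while evaluating `survivedeath dc (makedie 20)`.
survivegivenCalls :: Integer -> Integer
survivegivenCalls dc = callsFrom (0, 0)
  where
    -- How many of the faces 2..19 succeed (n >= dc) or fail (n < dc).
    successes, failures :: Integer
    successes = fromIntegral (length [n | n <- [2 .. 19 :: Integer], n >= dc])
    failures = 18 - successes

    -- Calls made by `survivegiven s =<< die`: one per face, plus recursion.
    callsFrom :: (Integer, Integer) -> Integer
    callsFrom (a, b)
      | a >= 3 || b >= 3 = 20 -- each of the 20 calls hits a base case at once
      | otherwise =
          -- face 20 returns True without recursing, so it only adds its own call
          20                                   -- the 20 calls themselves
            + callsFrom (a + 2, b)             -- face 1
            + successes * callsFrom (a, b + 1) -- faces meeting the DC
            + failures * callsFrom (a + 1, b)  -- faces below the DC
```

Worked out by hand, survivegivenCalls 6 is 8664020, which agrees with the Debug.Trace measurement in this answer.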

You can measure the exact number with Debug.Trace: make every call to survivegiven print a line, then count the lines in the output. Here is your survivedeath with only one line changed:

import Debug.Trace

...

survivedeath :: Integer -> Prob Integer -> Prob Bool 
survivedeath dc die = sumprobs (survivegiven (0,0) =<< die) 
    where 
        survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
        survivegiven (a,_) _ | trace "X" $ a >= 3 = return False 
        survivegiven (_,a) _ | a >= 3 = return True 
        survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
        survivegiven (a,b) 20 = return True 
        survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
        survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)

You can make your code compilable by adding a main function:

d20 = makedie 20

main :: IO ()
main = print (survivedeath 6 d20)
$ ghc -O A.hs
$ ./A 2> log      # store the stderr output in log
$ wc -l log       # count lines in log
8664020 log

There are 8.6 million recursive calls, each working with a rather inefficient representation of probability distributions, so it's expected that this takes at least a few seconds.
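Part of that per-call cost is sumprobs itself: nub plus one filter per distinct value makes it quadratic in the list length. A possible drop-in replacement, assuming the outcome type has an Ord instance, merges duplicates with Data.Map.Strict.fromListWith in O(n log n). `sumprobsOrd` and `example` are hypothetical names; the result lists outcomes in key order rather than first-appearance order.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show, Eq)

-- Hypothetical alternative to sumprobs: merge equal outcomes by adding
-- their probabilities, using Ord instead of nub.
sumprobsOrd :: Ord a => Prob a -> Prob a
sumprobsOrd (Prob xs) = Prob (Map.toList (Map.fromListWith (+) xs))

-- Example: the two True entries are merged into one.
example :: Prob Bool
example = sumprobsOrd (Prob [(True, 1 % 2), (False, 1 % 4), (True, 1 % 4)])
```

This only makes each merge cheaper; it does not reduce the number of recursive calls, which is the dominant cost here.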


A faster solution is to change the purpose of the inner function so that it computes the state space after n rolls. Each step makes only one recursive call, to get the state space after n-1 rolls, then mixes in the next die roll and wraps the result in sumprobs, so the distribution stays small at every step.

survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = (\(Left b) -> b) <$> surviveafter 6
    where 
        -- Survival states after (up to) n rolls:
        -- - Left True (survived)
        -- - Left False (died)
        -- - Right (a, b) ('a' failed throws, 'b' successful throws)
        surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
        surviveafter 0 = pure (Right (0, 0))
        surviveafter n = sumprobs (do
          state <- surviveafter (n-1)
          case state of
            Left _ -> pure state
            Right (a, b) -> do
              roll <- die
              case roll of
                1 -> pure (step (a+2, b))
                20 -> pure (Left True)
                n | n >= dc -> pure (step (a, 1+b))
                  | otherwise -> pure (step (1+a, b)))
        step (a, b) | a >= 3 = Left False
                    | b >= 3 = Left True
                    | otherwise = Right (a, b)

Compilable file:

{-# LANGUAGE DeriveFunctor #-}
import Control.Applicative (liftA2)
import Data.Ratio ((%))
import Data.List (nub)

newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving (Show,Eq,Functor)
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
    where multAll (Prob innerxs,p) = map (\(x,r) -> (x,p*r)) innerxs
instance Applicative Prob where 
    liftA2 fn (Prob x) (Prob y) = Prob [(fn a b,prob1 *prob2) |(a,prob1) <- x, (b,prob2) <- y]
    pure x = Prob [(x,1%1)]
instance Monad Prob where
    m >>= f = flatten (fmap f m)
makedie x = Prob (zip [1..x] (repeat (1%x)))
proboperate2 op a b= sumprobs (liftA2 op a b)
rerolldie op = proboperate2 op <*> id
sumprobs (Prob a) = Prob [(v,sum (map snd (filter ((==v) . fst) a))) |v <- indivalues]
    where indivalues = nub (map fst a)
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = (\(Left b) -> b) <$> surviveafter 6
    where 
        -- Survival state after (up to) n rolls:
        -- - Left True (survived)
        -- - Left False (died)
        -- - Right (a, b) ('a' failed throws, 'b' successful throws)
        surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
        surviveafter 0 = pure (Right (0, 0))
        surviveafter n = sumprobs (do
          state <- surviveafter (n-1)
          case state of
            Left _ -> pure state
            Right (a, b) -> do
              roll <- die
              case roll of
                1 -> pure (step (a+2, b))
                20 -> pure (Left True)
                n | n >= dc -> pure (step (a, 1+b))
                  | otherwise -> pure (step (1+a, b)))
        step (a, b) | a >= 3 = Left False
                    | b >= 3 = Left True
                    | otherwise = Right (a, b)
d20 = makedie 20

main :: IO ()
main = print (survivedeath 6 d20)
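The same forward, one-roll-at-a-time idea can also be sketched without the Prob type, using a Data.Map.Strict map from states to probabilities so that merging equal states costs O(n log n). This is an alternative sketch, not the code above; `State`, `rollOnce`, and `afterRolls` are hypothetical names, and the state encoding follows surviveafter.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

-- Same encoding as surviveafter: Left True = survived, Left False = died,
-- Right (a, b) = a failures and b successes so far.
type State = Either Bool (Integer, Integer)

-- Hypothetical helper: apply one d20 roll against dc to every live state,
-- merging equal states with fromListWith (+) instead of sumprobs.
rollOnce :: Integer -> Map.Map State Rational -> Map.Map State Rational
rollOnce dc dist =
  Map.fromListWith (+) (concatMap transitions (Map.toList dist))
  where
    transitions :: (State, Rational) -> [(State, Rational)]
    transitions (s@(Left _), p) = [(s, p)] -- finished states stay put
    transitions (Right (a, b), p) =
      [(afterRoll a b n, p * (1 % 20)) | n <- [1 .. 20]]

    afterRoll :: Integer -> Integer -> Integer -> State
    afterRoll a b n = case n of
      1 -> step (a + 2, b)
      20 -> Left True
      _ | n >= dc -> step (a, b + 1)
        | otherwise -> step (a + 1, b)

    step :: (Integer, Integer) -> State
    step (a, b)
      | a >= 3 = Left False
      | b >= 3 = Left True
      | otherwise = Right (a, b)

-- Hypothetical helper: distribution over states after the given number of rolls.
afterRolls :: Integer -> Int -> Map.Map State Rational
afterRolls dc rolls =
  iterate (rollOnce dc) (Map.singleton (Right (0, 0)) 1) !! rolls
```

Evaluating afterRolls 6 6 in GHCi should leave only the two finished states, with Left True carrying probability 34101/40000 (about 0.85). Every roll that doesn't end the check raises a + b by at least 1, and no live state has a + b above 4, so five rolls would already suffice; six matches the answer.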

Upvotes: 2

amalloy

Reputation: 92117

When you roll 6d20, there aren't 20*6 possible results, but 20^6, or 64 million. I'm not certain, but it looks to me like you are handling each of those individually via the recursion in survivegiven, except for special cases around 1 and 20.

Almost all of this work is wasted effort, because you evaluate the same function call many, many times. There's no difference between (a) rolling a 3 and then a 4, (b) rolling a 4 and then a 3, or (c) rolling a 5 and then a 5, assuming all of these are below your dc. But you compute them all separately, and each requires 20^4 follow-up rolls.

Some basic memoization would help. Alternatively, instead of a fair d20 you could use a die with 4 sides (1, 20, >=dc, and <dc), weighted according to the likelihood of each. That still wastes a little effort, because there's no difference between passing then failing or failing then passing, but 4^6 is a small enough number of possible outcomes that it should hardly matter.
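Here is one way the memoization idea could look, as a standalone sketch rather than a change to the question's code: the survival chance from each of the nine live states (a, b) is solved once and cached in a lazily built Data.Map. The name `survivalChance` and the choice of a lazy map as the memo table are assumptions for illustration.

```haskell
import qualified Data.Map as Map -- lazy Map, so table entries can refer to each other
import Data.Ratio ((%))

-- Hypothetical sketch: probability of surviving from (0, 0) against dc,
-- solving each live state (a, b) only once.
survivalChance :: Integer -> Rational
survivalChance dc = chanceFrom (0, 0)
  where
    -- Memo table: one lazily evaluated entry per live state.
    table :: Map.Map (Integer, Integer) Rational
    table = Map.fromList [((a, b), solve (a, b)) | a <- [0 .. 2], b <- [0 .. 2]]

    -- Finished states are answered directly; live states come from the table.
    chanceFrom :: (Integer, Integer) -> Rational
    chanceFrom (a, b)
      | a >= 3 = 0
      | b >= 3 = 1
      | otherwise = table Map.! (a, b)

    -- Average over the 20 equally likely faces of the next roll.
    solve :: (Integer, Integer) -> Rational
    solve (a, b) = sum [afterFace n | n <- [1 .. 20]] * (1 % 20)
      where
        afterFace :: Integer -> Rational
        afterFace n = case n of
          1 -> chanceFrom (a + 2, b)
          20 -> 1
          _ | n >= dc -> chanceFrom (a, b + 1)
            | otherwise -> chanceFrom (a + 1, b)
```

With only nine live states, this evaluates fewer than two hundred faces in total instead of millions of calls. For dc = 6 it should give 34101/40000.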

Upvotes: 2
