Reputation: 61
newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving (Show,Eq,Functor)
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
  where multAll (Prob innerxs,p) = map (\(x,r) -> (x,p*r)) innerxs
instance Applicative Prob where
  liftA2 fn (Prob x) (Prob y) = Prob [(fn a b,prob1 *prob2) |(a,prob1) <- x, (b,prob2) <- y]
  pure x = Prob [(x,1%1)]
instance Monad Prob where
  m >>= f = flatten (fmap f m)
makedie x = Prob (zip [1..x] (repeat (1%x)))
proboperate2 op a b= sumprobs (liftA2 op a b)
rerolldie op = proboperate2 op <*> id
sumprobs (Prob a) = Prob [(v,sum (map snd (filter ((==v) . fst) a))) |v <- indivalues]
  where indivalues = nub (map fst a)
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = sumprobs (survivegiven (0,0) =<< die)
  where
    survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
    survivegiven (a,_) _ | a >= 3 = return False
    survivegiven (_,a) _ | a >= 3 = return True
    survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
    survivegiven (a,b) 20 = return True
    survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
    survivegiven (a,b) n = sumprobs ((survivegiven (1+a,b)) =<< die)
survivedeath quickly starts to take a long time. As far as I can tell, sumprobs is O(N^2) and =<< is O(N), so if survivedeath recurses six times on a d20, it should take about 20^2 * 6 = 400 * 6 = 2400 operations, which should finish quickly. Why is that not the case?
Upvotes: 0
Views: 92
Reputation: 153172
One thing that decreases the runtime significantly is doing a little bit of work up front to collapse your 20 possible d20 outcomes to just four:
survivedeath dc die = sumprobs (go (0,0)) where
    go (a, b) | a >= 3 = return False
              | b >= 3 = return True
              | otherwise = do
                  (da, db) <- outcome
                  go (a+da, b+db)
    outcome = sumprobs $ do
        n <- die
        pure $ case n of
            1 -> (2, 0)
            20 -> (0, 3)
            _ | n >= dc -> (0, 1)
            _ -> (1, 0)
For quick ghci-based testing, this appears to give a ~10,000x speedup when computing results for all DCs from 1 to 20 (250s for your version, 0.02s for mine).
A back-of-the-envelope calculation suggests this speedup jibes with the explanations given by other answers about why your version is slow. If each probabilistic event has 5x (=20/4) as many possible outcomes, then a depth-6 search through outcomes would have 5^6 = 15,625x as many outcomes to look through.
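To see what the collapsed die looks like, here is a minimal sketch that builds the four-outcome distribution for a fair d20 directly. It swaps Data.Map.fromListWith (from the containers package) in for the nub-based sumprobs, and the name collapsedD20 is hypothetical. Keys are (failures added, successes added), matching the case expression in the code above.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

-- Hypothetical helper: the four-outcome distribution that `outcome` builds,
-- for a fair d20 against the given dc. Map.fromListWith (+) merges faces that
-- map to the same (failures added, successes added) pair.
collapsedD20 :: Integer -> [((Integer, Integer), Rational)]
collapsedD20 dc =
  Map.toList (Map.fromListWith (+) [(delta n, 1 % 20) | n <- [1 .. 20]])
  where
    delta :: Integer -> (Integer, Integer)
    delta n
      | n == 1    = (2, 0)   -- natural 1: two failures
      | n == 20   = (0, 3)   -- natural 20: enough successes to end at once
      | n >= dc   = (0, 1)   -- ordinary success
      | otherwise = (1, 0)   -- ordinary failure

main :: IO ()
main = mapM_ print (collapsedD20 6)
```

For DC 6, the natural 1 and natural 20 keep 1/20 each, faces 6 to 19 merge into a single (0, 1) outcome with probability 7/10, and faces 2 to 5 merge into (1, 0) with probability 1/5.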
Upvotes: 3
Reputation: 33519
Although sumprobs lets you keep your lists short, you are calling survivegiven for every possible die roll value.
In other words, if die is a D20,
survivegiven (a,b) n = sumprobs ((survivegiven (1+a,b)) =<< die)
makes 20 recursive calls. So you are still enumerating all dice roll sequences, of which there are only somewhat fewer than 20^6.
You can easily measure the exact number using Debug.Trace, making every call to survivegiven print a line, and then counting the lines in the output. Here is your survivedeath with only one line changed:
import Debug.Trace
...
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = sumprobs (survivegiven (0,0) =<< die)
  where
    survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
    survivegiven (a,_) _ | trace "X" $ a >= 3 = return False
    survivegiven (_,a) _ | a >= 3 = return True
    survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
    survivegiven (a,b) 20 = return True
    survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
    survivegiven (a,b) n = sumprobs ((survivegiven (1+a,b)) =<< die)
You can make your code compilable by adding a main function:
d20 = makedie 20
main :: IO ()
main = print (survivedeath 6 d20)
$ ghc -O A.hs
$ ./A 2> log    # store the stderr output in log
$ wc -l log # count lines in log
8664020 log
There are 8.6 million recursive calls, each working with a rather inefficient representation of probability distributions, so it's expected that this takes at least a few seconds.
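The same count can be predicted without running the traced program. The sketch below assumes a fair d20 and mirrors the clause order of survivegiven: every call made through `=<< die` prints one X, and only a call on an unfinished state with a roll other than 20 triggers 20 more calls. The names traceCount and callsAfter are hypothetical.

```haskell
-- Hypothetical helper: predicts how many X lines the traced survivedeath
-- prints for a fair d20 and the given dc, without running it.
traceCount :: Integer -> Integer
traceCount dc = callsAfter (0, 0)
  where
    -- faces 2..19 that count as an ordinary success / failure for this dc
    hi, lo :: Integer
    hi = toInteger (length [n | n <- [2 .. 19 :: Integer], n >= dc])
    lo = 18 - hi

    -- number of survivegiven calls made by `survivegiven (a, b) =<< die`
    callsAfter :: (Integer, Integer) -> Integer
    callsAfter (a, b)
      | a >= 3 || b >= 3 = 20            -- all 20 calls stop at the first two clauses
      | otherwise =
          20                             -- the 20 calls themselves
            + callsAfter (a + 2, b)      -- face 1 counts as two failures
            + hi * callsAfter (a, b + 1) -- ordinary successes
            + lo * callsAfter (a + 1, b) -- ordinary failures
            -- face 20 returns True immediately and adds no further calls

main :: IO ()
main = print (traceCount 6)
```

For DC 6 this evaluates to 8664020, the same number as the wc -l output above. The recurrence only ever visits a few distinct (a, b) states, which is the sharing the original recursion never exploits.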
A faster solution is to change the purpose of the inner function to compute the state space after n rolls. You only make one recursive call at a time, to get the state space after n-1 rolls. Mix in the next die roll, and wrap the result at every step in sumprobs.
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = (\(Left b) -> b) <$> surviveafter 6
  where
    -- Survival states after (up to) n rolls:
    -- - Left True (survived)
    -- - Left False (died)
    -- - Right (a, b) ('a' failed throws, 'b' successful throws)
    surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
    surviveafter 0 = pure (Right (0, 0))
    surviveafter n = sumprobs (do
      state <- surviveafter (n-1)
      case state of
        Left _ -> pure state
        Right (a, b) -> do
          roll <- die
          case roll of
            1 -> pure (step (a+2, b))
            20 -> pure (Left True)
            n | n >= dc -> pure (step (a, 1+b))
              | otherwise -> pure (step (1+a, b)))
    step (a, b) | a >= 3 = Left False
                | b >= 3 = Left True
                | otherwise = Right (a, b)
Compilable file:
{-# LANGUAGE DeriveFunctor #-}
import Control.Applicative (liftA2)  -- needed before GHC 9.6, whose Prelude lacks liftA2
import Data.Ratio ((%))
import Data.List (nub)
newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving (Show,Eq,Functor)
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
  where multAll (Prob innerxs,p) = map (\(x,r) -> (x,p*r)) innerxs
instance Applicative Prob where
  liftA2 fn (Prob x) (Prob y) = Prob [(fn a b,prob1 *prob2) |(a,prob1) <- x, (b,prob2) <- y]
  pure x = Prob [(x,1%1)]
instance Monad Prob where
  m >>= f = flatten (fmap f m)
makedie x = Prob (zip [1..x] (repeat (1%x)))
proboperate2 op a b= sumprobs (liftA2 op a b)
rerolldie op = proboperate2 op <*> id
sumprobs (Prob a) = Prob [(v,sum (map snd (filter ((==v) . fst) a))) |v <- indivalues]
  where indivalues = nub (map fst a)
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = (\(Left b) -> b) <$> surviveafter 6
  where
    -- Survival state after (up to) n rolls:
    -- - Left True (survived)
    -- - Left False (died)
    -- - Right (a, b) ('a' failed throws, 'b' successful throws)
    surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
    surviveafter 0 = pure (Right (0, 0))
    surviveafter n = sumprobs (do
      state <- surviveafter (n-1)
      case state of
        Left _ -> pure state
        Right (a, b) -> do
          roll <- die
          case roll of
            1 -> pure (step (a+2, b))
            20 -> pure (Left True)
            n | n >= dc -> pure (step (a, 1+b))
              | otherwise -> pure (step (1+a, b)))
    step (a, b) | a >= 3 = Left False
                | b >= 3 = Left True
                | otherwise = Right (a, b)
d20 = makedie 20
main :: IO ()
main = print (survivedeath 6 d20)
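The state-space idea does not depend on the list-based Prob type. As a hedged sketch, here is the same six-step iteration over a Data.Map (containers package) keyed by state, so merging equal states costs O(n log n) instead of the O(n^2) nub in sumprobs. The die is hardcoded as a fair d20, and State, stepDist and survivalChance are hypothetical names.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

-- Left True / Left False once decided, otherwise Right (failures, successes).
type State = Either Bool (Integer, Integer)

-- Hypothetical helper: apply one roll of a fair d20 to a whole distribution.
-- Map.fromListWith (+) merges equal states, playing the role of sumprobs.
stepDist :: Integer -> Map.Map State Rational -> Map.Map State Rational
stepDist dc dist = Map.fromListWith (+)
  [ (next, p * q) | (state, p) <- Map.toList dist, (next, q) <- rollFrom state ]
  where
    rollFrom :: State -> [(State, Rational)]
    rollFrom s@(Left _)     = [(s, 1)]   -- already decided: unchanged
    rollFrom (Right (a, b)) = [ (roll a b r, 1 % 20) | r <- [1 .. 20] ]

    roll :: Integer -> Integer -> Integer -> State
    roll a b r
      | r == 1    = settle (a + 2, b)    -- natural 1: two failures
      | r == 20   = Left True            -- natural 20: survive at once
      | r >= dc   = settle (a, b + 1)
      | otherwise = settle (a + 1, b)

    settle :: (Integer, Integer) -> State
    settle (a, b)
      | a >= 3    = Left False
      | b >= 3    = Left True
      | otherwise = Right (a, b)

-- Probability of surviving; six rolls are enough to decide every state.
survivalChance :: Integer -> Rational
survivalChance dc =
  Map.findWithDefault 0 (Left True) (iterate (stepDist dc) start !! 6)
  where
    start = Map.singleton (Right (0, 0)) 1

main :: IO ()
main = print (survivalChance 6)
```

With DC 6 it should print 34101 % 40000, roughly an 85% chance of surviving.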
Upvotes: 2
Reputation: 92117
When you roll 6d20, there aren't 20*6 possible results, but 20^6, or 64 million. I'm not certain, but it looks to me like you are handling each of those individually via the recursion in survivegiven, except for special cases around 1 and 20.
Almost all of this work is wasted effort, because you evaluate the same function call many, many times. There's no difference between (a) rolling a 3 and then a 4, or (b) rolling a 4 and then a 3, or (c) rolling a 5 and then a 5, assuming all of these are less than your dc. But you compute them all separately, and each requires 20^4 follow-up rolls. Some basic memoization would help, or instead of a fair d20 you could use a die with 4 sides (1, 20, >=dc, and <dc), weighted according to the likelihood of each. That still wastes a little effort, because there's no difference between passing then failing or failing then passing, but 4^6 is a small enough number of possible outcomes that it should hardly matter.
Upvotes: 2