I'm just finishing up "Learn You a Haskell for Great Good!" and am still struggling with how to work with monads.
At the very end of the chapter "For a Few Monads More", the author gives an exercise (in the second-to-last paragraph). Specifically, he encourages us to write a function that will collapse all the (False, Rational) values into a single value (False, sum of the Rationals).
I have written out all the code presented in the chapter and included the relevant portion here:
import Data.Ratio
import Control.Monad
import Control.Applicative
import Data.List (all)

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
    fmap f (Prob xs) = Prob $ map (\(x, p) -> (f x, p)) xs

-- Flatten a distribution of distributions, scaling each inner probability
-- by the probability of the outer outcome it came from
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
    where multAll (Prob innerxs, p) = map (\(x, r) -> (x, p*r)) innerxs

instance Applicative Prob where
    pure x = Prob [(x, 1%1)]
    (<*>) = ap

-- return defaults to pure on modern GHC
instance Monad Prob where
    m >>= f = flatten (fmap f m)

-- Since GHC 8.8, fail is a method of MonadFail rather than Monad
instance MonadFail Prob where
    fail _ = Prob []

data Coin = Heads | Tails deriving (Show, Eq)

coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]

flipThree :: Prob Bool
flipThree = do
    a <- coin
    b <- coin
    c <- loadedCoin
    return (all (== Tails) [a, b, c])
When I run this code, I get:
ghci> getProb flipThree
[(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(True,9 % 40)]
I would like to filter the elements of flipThree that have False in their first position and then sum the associated probabilities. I have written some ugly non-monadic code to do this, but I am confident there is a better way.
The desired output would be:
ghci> getProb flipThree
[(False,31 % 40),(True,9 % 40)]
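Since the probabilities are ordinary Rational values, the filter-then-sum step described above can be written as a plain list pipeline over getProb. A minimal sketch, with the chapter's definitions repeated so it loads on its own in GHCi; probOf is a hypothetical helper name, and the instances are a compact list-comprehension rewrite of the chapter's (behaving like flatten . fmap), not the book's code:

```haskell
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
    fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where
    pure x = Prob [(x, 1 % 1)]
    -- pair every function with every argument, multiplying their probabilities
    Prob fs <*> Prob xs = Prob [ (f x, p * q) | (f, p) <- fs, (x, q) <- xs ]

instance Monad Prob where
    -- same result as flatten (fmap k m): scale each inner probability by the outer one
    Prob xs >>= k = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (k x) ]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

flipThree :: Prob Bool
flipThree = do
    a <- coin
    b <- coin
    c <- loadedCoin
    return (all (== Tails) [a, b, c])

-- Hypothetical helper: keep the outcomes that satisfy the predicate
-- and add up their probabilities
probOf :: (a -> Bool) -> Prob a -> Rational
probOf wanted = sum . map snd . filter (wanted . fst) . getProb
```

In GHCi, `probOf not flipThree` gives `31 % 40` and `probOf id flipThree` gives `9 % 40`, so `[(b, probOf (== b) flipThree) | b <- [False, True]]` rebuilds the desired list. This only works when the possible outcomes are known up front; it does not discover them from the distribution.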
The function that you want is
import Data.Function (on)
import Data.List (groupBy)

-- Merge runs of equal outcomes and add up their probabilities
runProb :: Eq a => Prob a -> [(a, Rational)]
runProb = map (\x -> (fst (head x), sum (map snd x)))
        . groupBy ((==) `on` fst)
        . getProb
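so that
> runProb flipThree
[(False,31 % 40),(True,9 % 40)]
Note that groupBy only merges adjacent elements with equal outcomes. That is enough here because all the False outcomes of flipThree come before the single True outcome, but for an arbitrary Prob you would have to sort by outcome first (which needs Ord a rather than Eq a).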
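If equal outcomes are not guaranteed to be adjacent, one option is to collect them in a Data.Map keyed on the outcome, which adds up the probabilities of equal keys wherever they appear in the list. A sketch assuming the outcome type has an Ord instance (Bool does) and that the containers package, which ships with GHC, is available; collapse is a hypothetical name, and the Prob definitions are repeated so the block loads on its own:

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
    fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where
    pure x = Prob [(x, 1 % 1)]
    Prob fs <*> Prob xs = Prob [ (f x, p * q) | (f, p) <- fs, (x, q) <- xs ]

instance Monad Prob where
    -- same result as flatten (fmap k m)
    Prob xs >>= k = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (k x) ]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

flipThree :: Prob Bool
flipThree = do
    a <- coin
    b <- coin
    c <- loadedCoin
    return (all (== Tails) [a, b, c])

-- Hypothetical helper: merge equal outcomes, summing their probabilities.
-- fromListWith (+) combines duplicate keys; toList returns them in key order.
collapse :: Ord a => Prob a -> Prob a
collapse = Prob . Map.toList . Map.fromListWith (+) . getProb
```

Here `getProb (collapse flipThree)` gives `[(False,31 % 40),(True,9 % 40)]`, and unlike the groupBy version it also turns `Prob [(True, 1 % 4), (False, 1 % 4), (True, 1 % 2)]` into `[(False,1 % 4),(True,3 % 4)]`. Because collapse returns a Prob, it can be applied to the result of any do block; the trade-off is that outcomes come back sorted by key rather than in the order they were generated.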
You could also make your type an instance of MonadPlus to sort of get this result "directly":
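-- Since base 4.8 (GHC 7.10), MonadPlus has Alternative as a superclass,
-- and guard itself only needs Alternative
instance Alternative Prob where
    empty = Prob []
    Prob x <|> Prob y = Prob (x ++ y)

instance MonadPlus Prob where
    mzero = Prob []
    mplus (Prob x) (Prob y) = Prob (x++y)

-- Branches where guard fails contribute no outcomes at all
flipThree' :: Prob ()
flipThree' = do
    a <- coin
    b <- coin
    c <- loadedCoin
    guard (all (== Tails) [a, b, c])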
> runProb flipThree'
[((),9 % 40)]
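Because guard drops the failing branches rather than labelling them False, the probabilities left in flipThree' no longer add up to 1: the surviving mass is the probability of the event, and the missing mass is the (False, 31 % 40) entry from the question. A small self-contained sketch of that bookkeeping; totalMass is a hypothetical helper, and the instances are a compact rewrite of the ones above rather than the book's code:

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus, guard)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
    fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where
    pure x = Prob [(x, 1 % 1)]
    Prob fs <*> Prob xs = Prob [ (f x, p * q) | (f, p) <- fs, (x, q) <- xs ]

instance Monad Prob where
    Prob xs >>= k = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (k x) ]

-- guard False yields empty, so that branch disappears from the distribution
instance Alternative Prob where
    empty = Prob []
    Prob xs <|> Prob ys = Prob (xs ++ ys)

-- mzero and mplus default to empty and (<|>)
instance MonadPlus Prob

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

flipThree' :: Prob ()
flipThree' = do
    a <- coin
    b <- coin
    c <- loadedCoin
    guard (all (== Tails) [a, b, c])

-- Hypothetical helper: total probability remaining in a distribution
totalMass :: Prob a -> Rational
totalMass = sum . map snd . getProb
```

With this, `totalMass flipThree'` is `9 % 40` and `1 - totalMass flipThree'` is `31 % 40`, the two numbers in the desired output.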