spencerlyon2

Using monads: folding over part of its contents

I'm just finishing up Learn You a Haskell for Great Good! and am still struggling with how to work with monads.

At the very end of the chapter "For a Few Monads More", the author gives an exercise (in the second-to-last paragraph). Specifically, he encourages us to write a function that collapses all the (False, Rational) values into a single value (False, sum of the Rationals).

I have written out all the code presented in the chapter; the relevant portion is included here:

import Data.Ratio
import Control.Monad
import Control.Applicative
import Data.List (all)

newtype Prob a = Prob { getProb :: [(a, Rational)]} deriving (Show)

instance Functor Prob where
    fmap f (Prob xs) = Prob $ map (\(x, p) -> (f x, p)) xs

flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob $ concat $ map multAll xs
    where multAll (Prob innerxs, p) = map (\(x, r) -> (x, p*r)) innerxs

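instance Applicative Prob where
    pure x = Prob [(x, 1%1)]
    (<*>)  = ap

instance Monad Prob where
    m >>= f = flatten (fmap f m)

-- Since GHC 8.8, fail is a method of MonadFail (exported by the Prelude),
-- not of Monad, so it needs its own instance.
instance MonadFail Prob where
    fail _ = Prob []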

data Coin = Heads | Tails deriving (Show, Eq)

coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]

flipThree :: Prob Bool
flipThree = do
    a <- coin
    b <- coin
    c <- loadedCoin
    return (all (== Tails) [a, b, c])

When I run this code I get

ghci> getProb  flipThree
[(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(True,9 % 40)]

I would like to somehow filter the elements of flipThree that have a False in their first position and then sum the associated probabilities. I have written some ugly non-monadic code to do this, but am confident there is a better way.

The desired output would be

ghci> getProb  flipThree
[(False,31 % 40),(True,9 % 40)]
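One way to collapse the distribution, sketched under the assumption that the outcome type has an Ord instance (true for Bool), is to build a Data.Map with Map.fromListWith (+), which sums the probabilities attached to equal keys. The block repeats a minimal version of the question's Prob type so it compiles on its own; collapse is a hypothetical name, and >>= is written as a list comprehension that behaves like the question's flatten (fmap f m).

```haskell
import Control.Monad (ap)
import Data.Ratio ((%))
import qualified Data.Map.Strict as Map

-- Minimal copy of the question's Prob type so this sketch compiles on its own.
newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
  fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>) = ap

-- Same behaviour as the question's flatten (fmap f m): multiply each outer
-- probability into the inner distribution and concatenate the results.
instance Monad Prob where
  Prob xs >>= f = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (f x) ]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

flipThree :: Prob Bool
flipThree = do
  a <- coin
  b <- coin
  c <- loadedCoin
  return (all (== Tails) [a, b, c])

-- Hypothetical helper: merge equal outcomes by summing their probabilities.
-- Requires Ord because outcomes become Data.Map keys; the result is sorted
-- by outcome, so False comes before True.
collapse :: Ord a => Prob a -> Prob a
collapse = Prob . Map.toList . Map.fromListWith (+) . getProb
```

Here getProb (collapse flipThree) evaluates to [(False,31 % 40),(True,9 % 40)]. Because collapse returns a Prob rather than a plain list, it can also be applied to intermediate distributions to keep them small.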

Answers (1)

user2407038

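The function you want is:

import Data.Function (on)
import Data.List (groupBy)

-- Sum the probabilities of consecutive entries that share an outcome.
-- groupBy only merges adjacent equal outcomes; this works for flipThree
-- because its seven False outcomes all come before the single True outcome.
-- head is safe here: groupBy never produces an empty group.
runProb :: Eq a => Prob a -> [(a, Rational)]
runProb = map (\grp -> (fst (head grp), sum (map snd grp)))
        . groupBy ((==) `on` fst)
        . getProb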

so that

>runProb flipThree
[(False,31 % 40),(True,9 % 40)]
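runProb relies on groupBy, which only merges adjacent entries with equal outcomes. A fold-based variant, sketched below with the hypothetical name collapseEq, still needs only Eq but also merges equal outcomes that are separated by other entries. It works on the plain list returned by getProb, and example is a made-up input.

```haskell
import Data.Ratio ((%))

-- Hypothetical helper: merge entries that share an outcome by summing their
-- probabilities. Only Eq is required, equal outcomes are merged even when they
-- are not adjacent, and each outcome keeps the position of its first appearance.
-- Each step scans the accumulated list, so the cost grows with
-- (number of entries) x (number of distinct outcomes).
collapseEq :: Eq a => [(a, Rational)] -> [(a, Rational)]
collapseEq = foldl addOutcome []
  where
    -- Add p to an existing entry for x, or append a new entry at the end.
    addOutcome [] (x, p) = [(x, p)]
    addOutcome ((y, q) : rest) (x, p)
      | x == y    = (y, q + p) : rest
      | otherwise = (y, q) : addOutcome rest (x, p)

-- Hypothetical input in which outcome 'a' appears twice, separated by 'b'.
example :: [(Char, Rational)]
example = [('a', 1 % 4), ('b', 1 % 2), ('a', 1 % 4)]
```

collapseEq example gives [('a',1 % 2),('b',1 % 2)], whereas grouping example with groupBy leaves three entries because the two 'a' entries are not adjacent. With the question's definitions in scope, collapseEq (getProb flipThree) gives [(False,31 % 40),(True,9 % 40)].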

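You could also make your type an instance of MonadPlus and use guard to get part of this result "directly". guard discards every branch in which its condition is False, so only the probability of the all-tails outcome survives: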

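instance Alternative Prob where
  empty = Prob []
  Prob x <|> Prob y = Prob (x ++ y)

-- On modern GHC, MonadPlus has Alternative as a superclass,
-- so mzero and mplus default to empty and (<|>).
instance MonadPlus Prob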

-- guard yields pure () when the condition holds and empty otherwise.
flipThree' :: Prob ()
flipThree' = do
    a <- coin
    b <- coin
    c <- loadedCoin
    guard (all (== Tails) [a, b, c])

> runProb flipThree'
[((),9 % 40)]
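The guard approach reports only the probability of the event it tests. A sketch that packages this idea, using the hypothetical names threeCoins and probOf, asks the question once per side of the predicate to recover both numbers from the desired output. It includes its own Alternative instance, which guard needs on current GHC, and repeats a minimal Prob type so it compiles on its own.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus, ap, guard)
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving (Show)

instance Functor Prob where
  fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where
  pure x = Prob [(x, 1 % 1)]
  (<*>) = ap

instance Monad Prob where
  Prob xs >>= f = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (f x) ]

-- guard needs Alternative: empty is the impossible event, <|> concatenates.
instance Alternative Prob where
  empty = Prob []
  Prob xs <|> Prob ys = Prob (xs ++ ys)

instance MonadPlus Prob

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2), (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

-- Hypothetical: the raw outcomes of the three flips, before any predicate.
threeCoins :: Prob [Coin]
threeCoins = do
  a <- coin
  b <- coin
  c <- loadedCoin
  return [a, b, c]

-- Hypothetical helper: the probability that an outcome satisfies `holds`.
-- guard keeps only the branches where `holds` is True; summing their
-- probabilities gives the total mass of the event.
probOf :: (a -> Bool) -> Prob a -> Rational
probOf holds m = sum (map snd (getProb (m >>= guard . holds)))
```

probOf (all (== Tails)) threeCoins is 9 % 40 and probOf (not . all (== Tails)) threeCoins is 31 % 40, matching the True and False entries of the desired output.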
