Skynet90


How to define a state monad?

I want to define a State monad that also handles errors (in a sense like Maybe): if an error occurs during the "do" computation, it should be signalled and propagated by >>=. The error should also carry a string describing it. Then I want to use this monad with mapTreeM, mapping a function that treats the state as a number over a tree of numbers: at each visited leaf, it adds the leaf's value to the current state. The resulting tree must contain, at each leaf, a pair of the original leaf value and the state at that moment. The traversal must fail if the state becomes negative during the computation, and succeed as long as it stays non-negative.

For example, given this tree:

Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))

and starting from state 0, we should obtain:

Branch (Branch (Leaf (7,7)) (Branch (Leaf (-1,6)) (Leaf (3,9)))) (Branch (Leaf (-2,7)) (Leaf (9,16)))

If we put -18 in the second leaf instead, we should obtain an error value signalling that the state became negative (7 + (-18) = -11).
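The expected pairs are just a running sum over the leaves, read left to right. As a quick check of that arithmetic, here is a small sketch; `leaves` and `runningStates` are hypothetical helpers introduced only for this illustration, and the Tree type is copied from the code below.

```haskell
module Main where

-- Same Tree type as in the question's code
data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Leaves in left-to-right order, i.e. the order in which mapTreeM visits them
leaves :: Tree a -> [a]
leaves (Leaf x)     = [x]
leaves (Branch l r) = leaves l ++ leaves r

-- State after each visited leaf: a running sum starting from s0
runningStates :: Num a => a -> Tree a -> [a]
runningStates s0 = drop 1 . scanl (+) s0 . leaves

good, bad :: Tree Integer
good = Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))
bad  = Branch (Branch (Leaf 7) (Branch (Leaf (-18)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))

main :: IO ()
main = do
  print (runningStates 0 good)                        -- [7,6,9,7,16]
  print (take 1 (filter (< 0) (runningStates 0 bad))) -- [-11]
```

With -18 in the second leaf, the first negative running state is -11, which is where the traversal should stop.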

Below is what I have so far: it prints the resulting tree, but without any error handling, because I haven't understood how to add it. Here is my code:

module Main where

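-- State monad: a state transformer st -> (new state, result)
newtype State st a = State (st -> (st, a))

-- Since GHC 7.10, every Monad must also be a Functor and an Applicative
instance Functor (State state) where
    fmap f (State g) = State (\s -> let (s', x) = g s in (s', f x))

instance Applicative (State state) where
    pure x = State (\s -> (s, x))
    State f <*> State g = State (\s ->
        let (s1, h) = f s
            (s2, x) = g s1
        in (s2, h x))

instance Monad (State state) where
    return = pure
    State f >>= g = State (\oldstate ->
        let (newstate, val) = f oldstate
            State newf      = g val
        in newf newstate)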




-- Recursive data structure for representing trees
data Tree a = Leaf a | Branch (Tree a) (Tree a)
          deriving (Show,Eq)


-- Utility methods     
getState :: State state state
getState = State(\state -> (state,state))

putState :: state -> State state ()
putState new = State(\_ -> (new, ()))


mapTreeM :: (a -> State state b) -> Tree a -> State state (Tree b)
mapTreeM f (Leaf a) =
    f a >>= (\b -> return (Leaf b))
mapTreeM f (Branch lhs rhs) = do
    lhs' <- mapTreeM f lhs
    rhs' <- mapTreeM f rhs
    return (Branch lhs' rhs')


numberTree :: (Num a) => Tree a -> State a (Tree (a,a))
numberTree tree = mapTreeM number tree
    where number v = do
              cur <- getState
              putState (cur + v)
              return (v, cur + v)

-- An instance of a tree                         
testTree = (Branch 
              (Branch
                (Leaf 7) (Branch (Leaf (-1)) (Leaf 3)))
           (Branch
                (Leaf (-2)) (Leaf (-20))))


runStateM :: State state a -> state -> a
runStateM (State f) st = snd (f st)


main :: IO()              
main = print $ runStateM (numberTree testTree) 0
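For comparison, a minimal sketch of the error-aware monad described above, in the same hand-rolled style as this code and using only base. The State newtype is changed so that each step returns Either String (st, a), and >>= stops at the first Left. ErrState, failWith and runErrState are hypothetical names chosen for this sketch, not library functions.

```haskell
module Main where

-- Hypothetical ErrState: like the question's State, but each step
-- returns either an error message or (new state, result).
newtype ErrState st a = ErrState (st -> Either String (st, a))

instance Functor (ErrState st) where
  fmap f (ErrState g) = ErrState (\s -> fmap (\(s', x) -> (s', f x)) (g s))

instance Applicative (ErrState st) where
  pure x = ErrState (\s -> Right (s, x))
  ErrState f <*> ErrState g = ErrState (\s -> do
    (s1, h) <- f s     -- uses Either's monad: a Left stops here
    (s2, x) <- g s1
    return (s2, h x))

instance Monad (ErrState st) where
  ErrState f >>= k = ErrState (\s ->
    case f s of
      Left err        -> Left err   -- propagate the error, skip k
      Right (s', val) -> let ErrState g = k val in g s')

getState :: ErrState st st
getState = ErrState (\s -> Right (s, s))

putState :: st -> ErrState st ()
putState new = ErrState (\_ -> Right (new, ()))

-- Abort the whole computation with a message
failWith :: String -> ErrState st a
failWith msg = ErrState (\_ -> Left msg)

-- Run from an initial state, keeping only the result (or the error)
runErrState :: ErrState st a -> st -> Either String a
runErrState (ErrState f) s = fmap snd (f s)

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Unchanged from the question except for the monad
mapTreeM :: (a -> ErrState st b) -> Tree a -> ErrState st (Tree b)
mapTreeM f (Leaf a) = f a >>= (\b -> return (Leaf b))
mapTreeM f (Branch lhs rhs) = do
  lhs' <- mapTreeM f lhs
  rhs' <- mapTreeM f rhs
  return (Branch lhs' rhs')

numberTree :: (Num a, Ord a, Show a) => Tree a -> ErrState a (Tree (a, a))
numberTree = mapTreeM number
  where
    number v = do
      cur <- getState
      let next = cur + v
      if next < 0
        then failWith ("negative state: " ++ show next)
        else do
          putState next
          return (v, next)

good, bad :: Tree Integer
good = Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))
bad  = Branch (Branch (Leaf 7) (Branch (Leaf (-18)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))

main :: IO ()
main = do
  print (runErrState (numberTree good) 0)
  print (runErrState (numberTree bad) 0)  -- Left "negative state: -11"
```

Because the error handling lives entirely in >>=, mapTreeM and number keep their original shape; only the check with failWith is new.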


Answers (2)

pat


By making your Tree datatype an instance of Traversable, you can use mapM (from Data.Traversable) to map an action over a Tree. You can also layer the StateT monad transformer atop the Either monad to provide error handling.

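{-# LANGUAGE FlexibleContexts #-}
-- FlexibleContexts: the inferred type of the local `number` has a MonadError String constraint

import Control.Monad (when)
import Control.Monad.State
import Control.Monad.Except   -- replaces the deprecated Control.Monad.Error
import Control.Applicative
import Data.Monoid
import Data.Foldable
import Data.Traversable
import qualified Data.Traversable as T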

-- our monad which carries state but allows for errors with string message    
type M s = StateT s (Either String)

data Tree a = Leaf a | Branch (Tree a) (Tree a)
          deriving (Show,Eq)

-- Traversable requires Functor
instance Functor Tree where
    fmap f (Leaf a) = Leaf (f a)
    fmap f (Branch lhs rhs) = Branch (fmap f lhs) (fmap f rhs)

-- Traversable requires Foldable
instance Foldable Tree where
    foldMap f (Leaf a) = f a
    foldMap f (Branch lhs rhs) = foldMap f lhs `mappend` foldMap f rhs

-- Finally, we can get to Traversable
instance Traversable Tree where
    traverse f (Leaf a) = Leaf <$> f a
    traverse f (Branch lhs rhs) = Branch <$> traverse f lhs <*> traverse f rhs

testTree = (Branch
              (Branch
                (Leaf 7) (Branch (Leaf (-1)) (Leaf 3)))
           (Branch
                (Leaf (-2)) (Leaf (-20))))

numberTree :: (Num a, Ord a) => Tree a -> M a (Tree (a,a))
numberTree = T.mapM number where
    number v = do
        cur <- get
        let nxt = cur+v
        -- throwError works directly here: StateT lifts MonadError from the Either layer
        when (nxt < 0) $ throwError "state went negative"
        put nxt
        return (v, nxt)

main :: IO ()
main =
    case evalStateT (numberTree testTree) 0 of
        Left e -> putStrLn $ "Error: " ++ e
        Right t -> putStrLn $ "Success: " ++ show t
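A variant of the code above, assuming GHC and the mtl package: the DeriveTraversable extension (which implies DeriveFunctor and DeriveFoldable) lets GHC generate the three hand-written instances, traverse replaces T.mapM, and runStateT is used instead of evalStateT so the final state is returned too. Putting the offending value into the error message is an addition for this sketch.

```haskell
{-# LANGUAGE DeriveTraversable #-}
module Main where

import Control.Monad (when)
import Control.Monad.Except (throwError)
import Control.Monad.State (StateT, get, put, runStateT)

-- GHC derives Functor, Foldable and Traversable for us
data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- State on top of Either String, as in the answer
type M s = StateT s (Either String)

-- Visit one leaf: add it to the state, failing if the state goes negative
number :: (Num a, Ord a, Show a) => a -> M a (a, a)
number v = do
  cur <- get
  let nxt = cur + v
  when (nxt < 0) $ throwError ("state went negative: " ++ show nxt)
  put nxt
  return (v, nxt)

numberTree :: (Num a, Ord a, Show a) => Tree a -> M a (Tree (a, a))
numberTree = traverse number

testTree :: Tree Integer
testTree = Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3)))
                  (Branch (Leaf (-2)) (Leaf (-20)))

main :: IO ()
main =
  -- runStateT returns (result, final state), unlike evalStateT
  case runStateT (numberTree testTree) 0 of
    Left e           -> putStrLn ("Error: " ++ e)
    Right (t, final) -> putStrLn ("Success: " ++ show t ++ ", final state " ++ show final)
```

Giving number a top-level signature also pins the monad to M, so no constraint has to be inferred for it.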


mariop


Can I propose an alternative solution? While monads are good for many things, what you want can be done with a plain recursive function that keeps track of errors. The function transferVal below does this: it traverses the Tree from left to right, carrying the running value, and if an error occurs it returns the error and stops traversing. Instead of Maybe, it is often better to use Either <error_type> <result_type>, so that you get a more informative error when something goes wrong. Here the error type is ([ChildDir],a), where [ChildDir] is the path of directions from the root to the offending leaf and a is the negative state value that triggered the error. The function printErrorsOrTree shows how to use the output of transferVal, and main runs 4 examples: the first three succeed and the last one triggers the error you were expecting.

module Main where

import Data.List     (intercalate)
import Control.Monad (mapM_)

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show,Eq)

-- given a Branch, in which child the error is?
data ChildDir = LeftChild | RightChild
  deriving Show

-- an error is the path of directions from the root to the failing leaf,
-- together with the (negative) state value that triggered the error
type Error a = ([ChildDir],a)

-- util to append a direction to an error
appendDir :: ChildDir -> Error a -> Error a
appendDir d (ds,x) = (d:ds,x)

transferVal :: (Ord a,Num a) => Tree a -> Either (Error a) (Tree (a,a))
transferVal = fmap fst . go 0
  where go :: (Ord a,Num a) => a -> Tree a -> Either (Error a) (Tree (a,a),a)
        go c (Leaf x) = let newC = x + c
                        in if newC < 0
                             then Left ([],newC)
                             else Right (Leaf (x,newC),newC)
        go c (Branch t1 t2) = case go c t1 of
            Left e             -> Left $ appendDir LeftChild e
            Right (newT1,newC) -> case go newC t2 of
                Left              e -> Left $ appendDir RightChild e 
                Right (newT2,newC') -> Right (Branch newT1 newT2,newC')

printErrorsOrTree :: (Show a,Show b) => Either (Error a) (Tree b) -> IO ()
printErrorsOrTree (Left (ds,x)) = putStrLn $ "Error in position " ++ (intercalate " -> " $ map show ds) ++ ". Error value is " ++ show x
printErrorsOrTree (Right     t) = putStrLn $ "Result: " ++ show t

main :: IO ()
main = mapM_ runExample
             [(Leaf 1)
             ,(Branch (Leaf 1) (Leaf 2))
             ,(Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9)))
             ,(Branch (Branch (Leaf 7) (Branch (Leaf (-11)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9)))]
  where runExample orig = do
          let res = transferVal orig
          print orig
          printErrorsOrTree res
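The nested case expressions in go do by hand the short-circuiting that Either's Monad instance already provides. A sketch of the same traversal written with do-notation, assuming Data.Bifunctor.first (from base) to prepend the direction to an error; transferVal' is a hypothetical name for this variant, and ChildDir also derives Eq here only so results can be compared.

```haskell
module Main where

import Data.Bifunctor (first)

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

data ChildDir = LeftChild | RightChild
  deriving (Show, Eq)

type Error a = ([ChildDir], a)

appendDir :: ChildDir -> Error a -> Error a
appendDir d (ds, x) = (d : ds, x)

-- Same traversal as transferVal: Either's do-notation stops at the first
-- Left, and `first` tags that Left with the direction we came from.
transferVal' :: (Ord a, Num a) => Tree a -> Either (Error a) (Tree (a, a))
transferVal' = fmap fst . go 0
  where
    go c (Leaf x)
      | newC < 0  = Left ([], newC)
      | otherwise = Right (Leaf (x, newC), newC)
      where newC = x + c
    go c (Branch t1 t2) = do
      (newT1, c1) <- first (appendDir LeftChild)  (go c  t1)
      (newT2, c2) <- first (appendDir RightChild) (go c1 t2)
      return (Branch newT1 newT2, c2)

main :: IO ()
main = do
  let ok  = Branch (Branch (Leaf 7) (Branch (Leaf (-1)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))
      bad = Branch (Branch (Leaf 7) (Branch (Leaf (-11)) (Leaf 3))) (Branch (Leaf (-2)) (Leaf 9))
  print (transferVal' (ok :: Tree Integer))
  print (transferVal' (bad :: Tree Integer))  -- Left ([LeftChild,RightChild,LeftChild],-4)
```

The recursion and the error type are unchanged; only the plumbing between the two recursive calls is delegated to the Either monad.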

