coder_bro

Reputation: 10773

How to derive a state monad from first principles?

I am trying to come up with an implementation of the State monad, derived from examples of function composition. Here is what I came up with:

First, deriving the concept of a monad:

data Maybe' a = Nothing' | Just' a deriving Show

sqrt' :: (Floating a, Ord a) => a -> Maybe' a
sqrt' x = if x < 0 then Nothing' else Just' (sqrt x)

inv' :: (Floating a, Ord a) => a -> Maybe' a
inv' x = if x == 0 then Nothing' else Just' (1/x)

log' :: (Floating a, Ord a) => a -> Maybe' a
log' x = if x <= 0 then Nothing' else Just' (log x)

We can write a function that composes these functions as follows:

sqrtInvLog' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog' x = case (sqrt' x) of
                  Nothing' -> Nothing'
                  (Just' y) -> case (inv' y) of
                                Nothing' -> Nothing'
                                (Just' z) -> log' z

This could be simplified by factoring out the case statement and function application:

fMaybe' :: (Maybe' a) -> (a -> Maybe' b) -> Maybe' b
fMaybe' Nothing' _ = Nothing'
fMaybe' (Just' x) f = f x

-- Applying fMaybe' =>
sqrtInvLog'' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog'' x = (sqrt' x) `fMaybe'` (inv') `fMaybe'` (log')
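For reference, the Maybe' pieces above can be collected into one self-contained module (a sketch assuming GHC with only the Prelude). Two small deviations from the question's code: Eq is added to the deriving clause so results can be compared, and log' rejects every non-positive input rather than only zero.

```haskell
-- Maybe' as in the question, with Eq added so results can be compared.
data Maybe' a = Nothing' | Just' a deriving (Show, Eq)

sqrt', inv', log' :: (Floating a, Ord a) => a -> Maybe' a
sqrt' x = if x < 0 then Nothing' else Just' (sqrt x)
inv'  x = if x == 0 then Nothing' else Just' (1 / x)
log'  x = if x <= 0 then Nothing' else Just' (log x)

-- The factored-out "case analysis + function application" step.
fMaybe' :: Maybe' a -> (a -> Maybe' b) -> Maybe' b
fMaybe' Nothing'  _ = Nothing'
fMaybe' (Just' x) f = f x

-- Nested-case version.
sqrtInvLog' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog' x = case sqrt' x of
  Nothing' -> Nothing'
  Just' y  -> case inv' y of
    Nothing' -> Nothing'
    Just' z  -> log' z

-- Same pipeline using fMaybe'. A backticked function with no fixity
-- declaration is infixl 9, so this parses as
-- ((sqrt' x `fMaybe'` inv') `fMaybe'` log').
sqrtInvLog'' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog'' x = sqrt' x `fMaybe'` inv' `fMaybe'` log'
```

In GHCi, sqrtInvLog'' 1 should evaluate to Just' 0.0, and sqrtInvLog'' 0 to Nothing', because inv' fails on the zero produced by sqrt'.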

Now we can generalize the concept to any type constructor, not just Maybe', by defining a Monad-like class:

class Monad' m where
  bind' :: m a -> (a -> m b) -> m b
  return' :: a -> m a

instance Monad' Maybe' where
  bind' Nothing' _ = Nothing'
  bind' (Just' x) f = f x
  return' x = Just' x

Using the Monad' implementation, sqrtInvLog'' can be written as:

sqrtInvLog''' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog''' x = (sqrt' x) `bind'` (inv') `bind'` (log')
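A sketch of the same pipeline written against the Monad' class, again self-contained and with Eq derived for comparison. composeM' is a hypothetical helper, analogous to (>=>) from Control.Monad, included to show that code written only against Monad' works for any instance, not just Maybe'.

```haskell
data Maybe' a = Nothing' | Just' a deriving (Show, Eq)

-- The question's class: m is a type constructor applied to one type argument.
class Monad' m where
  bind'   :: m a -> (a -> m b) -> m b
  return' :: a -> m a

instance Monad' Maybe' where
  bind' Nothing'  _ = Nothing'
  bind' (Just' x) f = f x
  return' = Just'

sqrt', inv', log' :: (Floating a, Ord a) => a -> Maybe' a
sqrt' x = if x < 0 then Nothing' else Just' (sqrt x)
inv'  x = if x == 0 then Nothing' else Just' (1 / x)
log'  x = if x <= 0 then Nothing' else Just' (log x)

sqrtInvLog''' :: (Floating a, Ord a) => a -> Maybe' a
sqrtInvLog''' x = sqrt' x `bind'` inv' `bind'` log'

-- Hypothetical helper: compose two "monadic" functions using only bind'.
composeM' :: Monad' m => (a -> m b) -> (b -> m c) -> a -> m c
composeM' f g x = f x `bind'` g
```

For example, composeM' sqrt' inv' 4 should give Just' 0.5.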

Trying to apply the concept to maintain state, I defined something as shown below:

data St a s = St (a,s) deriving Show

sqrtLogInvSt' :: (Floating a, Ord a) => St a a -> St (Maybe' a) a
sqrtLogInvSt' (St (x,s)) = case (sqrt' x) of
                             Nothing' -> St (Nothing', s)
                             (Just' y) -> case (log' y) of
                                            Nothing' -> St (Nothing', s+y)
                                            (Just' z) -> St (inv' z, s+y+z)

It is not possible to define a monad instance using the above definition, because bind takes a value of type "m a", where m is applied to a single type argument, whereas St takes two.

Second attempt, based on Haskell's definition of the State monad:

newtype State s a = State { runState :: s -> (a, s) }

First attempt at defining functions that maintain state and can be composed:

fex1 :: Int -> State Int Int
fex1 x = State { runState = \s -> (r, (s+r)) } where r = x `mod` 2

fex2 :: Int->State Int Int
fex2 x = State { runState = \s-> (r,s+r)} where r = x * 5

A composed function:

fex3 x = (runState (fex2 y)) st where (y, st) = (runState (fex1 x)) 0

But the definition newtype State s a = State { runState :: s -> (a, s) } does not seem to fit bind's pattern, m a -> (a -> m b) -> m b.

An attempt could be made as follows:

instance Monad' (State s) where
   bind' st f = undefined
   return' x = State { runState = \s -> (x,s) }

bind' is undefined above because I did not know how to implement it.

I could derive why monads are useful and apply the idea to the first example (Maybe'), but I cannot seem to apply it to State. It would be useful to understand how I could derive the State monad using the concepts defined above.

Note that I have asked a similar question earlier: Haskell - Unable to define a State monad like function using a Monad like definition but I have expanded here and added more details.

Upvotes: 2

Views: 311

Answers (2)

cibercitizen1

Reputation: 21476

Have a look at this. Summing it up and extending it a bit:

If you have a function

ta -> tb

and want to add "state" to it, then you should pass that state along, and have

(ta, ts) -> (tb, ts)

You can transform this by currying it:

ta -> ts -> (tb, ts)

Comparing this with the original ta -> tb, and adding parentheses, we obtain

ta -> (ts -> (tb, ts))

Summing up, if a function returns tb from ta (i.e. ta -> tb), a "stateful" version of it will return (ts -> (tb, ts)) from ta (i.e. ta -> (ts -> (tb, ts)))

Therefore, a "stateful computation" (a single function, or a chain of functions dealing with state) must produce this:

(ts -> (tb, ts))
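To make the currying step concrete, here is a hypothetical example (square, squareCountedPair and squareCounted are illustrative names, not from the linked material) in which the state is a counter of how many times the function has run.

```haskell
-- Plain function: ta -> tb
square :: Int -> Int
square x = x * x

-- Uncurried stateful version: (ta, ts) -> (tb, ts); the state counts calls.
squareCountedPair :: (Int, Int) -> (Int, Int)
squareCountedPair (x, calls) = (square x, calls + 1)

-- Curried stateful version: ta -> (ts -> (tb, ts))
squareCounted :: Int -> (Int -> (Int, Int))
squareCounted x = \calls -> (square x, calls + 1)

-- The Prelude's curry converts the pair-taking form into the curried form.
squareCounted' :: Int -> Int -> (Int, Int)
squareCounted' = curry squareCountedPair
```

squareCounted 3 0 evaluates to (9, 1). Feeding the returned counter into the next call by hand is exactly the plumbing that the rest of this answer factors out.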

A typical example is a stack of ints, where [Int] is the state:

pop :: [Int] -> (Int, [Int]) -- remove top
pop (v:s) = (v, s)

push :: Int -> [Int] -> (Int, [Int]) -- add to the top
push v s = (v, v:s)

Note that push 5 has the same type as pop :: [Int] -> (Int, [Int]). So we would like to combine these basic operations to get things such as

duplicateTop = do
   v <- pop
   push v
   push v

Note that, as desired, duplicateTop :: [Int] -> (Int, [Int])
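Before introducing any monad, duplicateTop can already be written by naming every intermediate stack by hand. This sketch uses = in pop and push and adds an explicit error case for popping an empty stack; that case is an assumption, since the sketch above leaves it undefined.

```haskell
-- The stack is the state; each operation maps a stack to (result, new stack).
pop :: [Int] -> (Int, [Int])
pop (v:s) = (v, s)
pop []    = error "pop: empty stack"

push :: Int -> [Int] -> (Int, [Int])
push v s = (v, v:s)

-- Without a monad, each intermediate state must be named and passed on manually.
duplicateTop :: [Int] -> (Int, [Int])
duplicateTop s0 =
  let (v, s1) = pop s0
      (_, s2) = push v s1
  in  push v s2
```

duplicateTop [3,4] evaluates to (3, [3,3,4]). The s0, s1, s2 bookkeeping is what a combinator for stateful functions should hide.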

Now, how do we combine two stateful computations into a new one? Let's do it. (Caution: this definition is not the same as the one used for the monad's bind (>>=), but it is equivalent.)

Combine:

f :: ta -> (ts -> (tb, ts))

with

g :: tb -> (ts -> (tc, ts))

to get

h :: ta -> (ts -> (tc, ts))

This is the construction of h (in pseudo-Haskell):

h = \a -> ( \s -> (c, s') )

where we have to compute (c, s') (everything else in the expression is just the parameters a and s). Here is how:

                   -- 1. run f using a and s
  l1 = f a         -- use the parameter a to get the function (ts -> (tb, ts)) returned by f 
  (b, s1) = l1 s   -- use the parameter s to get the pair that l1 returns ( :: (tb, ts) )
                   -- 2. run g using f output, b and s1
  l2  = g b        -- use b to get the function (ts -> (tc, ts)) returned by g
  (c, s') = l2 s1  -- use s1 to get the pair that l2 returns ( :: (tc, ts) ) 
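The pseudo-Haskell construction of h translates almost directly into a real function. combine is a hypothetical name for it; it has the shape of Kleisli composition for state-passing functions. As an assumption for the demo, const pop adapts pop (which takes no input value) to the ta -> (ts -> (tb, ts)) shape.

```haskell
-- combine f g = h: run f with a and s, then run g on f's result and new state.
combine :: (ta -> (ts -> (tb, ts)))
        -> (tb -> (ts -> (tc, ts)))
        -> (ta -> (ts -> (tc, ts)))
combine f g = \a -> \s ->
  let l1      = f a      -- 1. run f using a ...
      (b, s1) = l1 s     --    ... and s
      l2      = g b      -- 2. run g using f's output b ...
      (c, s') = l2 s1    --    ... and the intermediate state s1
  in  (c, s')

pop :: [Int] -> (Int, [Int])
pop (v:s) = (v, s)
pop []    = error "pop: empty stack"

push :: Int -> [Int] -> (Int, [Int])
push v s = (v, v:s)

-- duplicateTop built only from combine: pop, then push the value twice.
duplicateTop' :: [Int] -> (Int, [Int])
duplicateTop' = combine (const pop) (combine push push) ()
```

duplicateTop' [3,4] gives (3, [3,3,4]); no intermediate stack is named explicitly.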

Upvotes: 1

melpomene

Reputation: 85757

Your composed function fex3 has the wrong type:

fex3 :: Int -> (Int, Int)

Unlike with your sqrtInvLog' example for Maybe', State does not appear in the type of fex3.

We could define it as

fex3 :: Int -> State Int Int
fex3 x = State { runState = \s ->
    let (y, st) = runState (fex1 x) s in
        runState (fex2 y) st }

The main difference from your definition is that instead of hardcoding 0 as the initial state, we pass on our own state s.
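A runnable version of this fex3, assuming the question's newtype State and its fex1/fex2 (reformatted slightly). Running it from different initial states shows that the state is now threaded through rather than fixed at 0.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

fex1 :: Int -> State Int Int
fex1 x = State { runState = \s -> (r, s + r) } where r = x `mod` 2

fex2 :: Int -> State Int Int
fex2 x = State { runState = \s -> (r, s + r) } where r = x * 5

-- The initial state is no longer hard-coded; fex3 threads whatever s it receives.
fex3 :: Int -> State Int Int
fex3 x = State { runState = \s ->
  let (y, st) = runState (fex1 x) s
  in  runState (fex2 y) st }
```

runState (fex3 7) 0 gives (5, 6): fex1 7 returns 7 `mod` 2 = 1 and adds it to the state, then fex2 1 returns 5 and adds that as well.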

What if (like in your Maybe example) we wanted to compose three functions? Here I'll just reuse fex2 instead of introducing another intermediate function:

fex4 :: Int -> State Int Int
fex4 x = State { runState = \s ->
        let (y, st) = runState (fex1 x) s in
            let (z, st') = runState (fex2 y) st in
                runState (fex2 z) st' }

SPOILERS:

The generalized version bindState can be extracted as follows:

bindState m f = State { runState = \s ->
    let (x, st) = runState m s in
    runState (f x) st }

fex3' x = fex1 x `bindState` fex2
fex4' x = fex1 x `bindState` fex2 `bindState` fex2
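A self-contained sketch of bindState (the type signature is added here; the answer leaves it inferred). leftNested and rightNested are hypothetical values that bracket the same three-step chain in the two possible ways, which should agree for any lawful bind.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Run m, then feed its result and final state into f.
bindState :: State s a -> (a -> State s b) -> State s b
bindState m f = State { runState = \s ->
  let (x, st) = runState m s
  in  runState (f x) st }

fex1, fex2 :: Int -> State Int Int
fex1 x = State { runState = \s -> (r, s + r) } where r = x `mod` 2
fex2 x = State { runState = \s -> (r, s + r) } where r = x * 5

-- With the default infixl 9 fixity, fex4' parses as
-- ((fex1 x `bindState` fex2) `bindState` fex2).
fex3', fex4' :: Int -> State Int Int
fex3' x = fex1 x `bindState` fex2
fex4' x = fex1 x `bindState` fex2 `bindState` fex2

-- Associativity: both run fex1, fex2, fex2 in that order.
leftNested, rightNested :: State Int Int
leftNested  = (fex1 7 `bindState` fex2) `bindState` fex2
rightNested = fex1 7 `bindState` (\y -> fex2 y `bindState` fex2)
```

runState (fex4' 7) 0 gives (25, 31), and runState leftNested 3 and runState rightNested 3 both give (25, 34).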

We can also start from Monad' and the types involved.

The m in the definition of Monad' is applied to one type argument (m a, m b). We can't set m = State because State requires two arguments. On the other hand, partial application is perfectly valid for types: State s a really means (State s) a, so we can set m = State s:

instance Monad' (State s) where
   -- return' :: a -> m a (where m = State s)
   -- return' :: a -> State s a
   return' x = State { runState = \s -> (x,s) }

   -- bind' :: m a -> (a -> m b) -> m b (where m = State s)
   -- bind' :: State s a -> (a -> State s b) -> State s b
   bind' st f =
   -- Good so far: we have two arguments
   --   st :: State s a
   --   f :: a -> State s b
   -- We also need a result
   --   ... :: State s b
   -- It must be a State, so we can start with:
       State { runState = \s ->
   -- Now we also have
   --   s :: s
   -- That means we can run st:
           let (x, s') = runState st s in
   --   runState :: State s a -> s -> (a, s)
   --   st :: State s a
   --   s :: s
   --   x :: a
   --   s' :: s
   -- Now we have a value of type 'a' that we can pass to f:
   --   f x :: State s b
   -- We are already in a State { ... } context, so we need
   -- to return a (value, state) tuple. We can get that from
   -- 'State s b' by using runState again:
           runState (f x) s'
       }
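Putting the pieces together, here is a compact sketch of the finished instance without the commentary, plus pop and push from the other answer rewritten as State values (hypothetical definitions, not part of this answer) so that duplicateTop can be expressed with bind' alone.

```haskell
newtype State s a = State { runState :: s -> (a, s) }

class Monad' m where
  bind'   :: m a -> (a -> m b) -> m b
  return' :: a -> m a

-- m = State s: the state type is fixed, only the result type varies.
instance Monad' (State s) where
  return' x = State { runState = \s -> (x, s) }
  bind' st f = State { runState = \s ->
    let (x, s') = runState st s
    in  runState (f x) s' }

-- Hypothetical stack operations as State values.
pop :: State [Int] Int
pop = State { runState = popList }
  where
    popList (v:s) = (v, s)
    popList []    = error "pop: empty stack"

push :: Int -> State [Int] Int
push v = State { runState = \s -> (v, v:s) }

-- Lambdas extend to the right, so this is pop `bind'` (\v -> (push v `bind'` (\_ -> push v))).
duplicateTop :: State [Int] Int
duplicateTop = pop `bind'` \v -> push v `bind'` \_ -> push v
```

runState duplicateTop [3,4] gives (3, [3,3,4]). Replacing Monad' with standard Functor, Applicative and Monad instances for the same newtype would additionally allow the do-notation used for duplicateTop in the other answer.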

Upvotes: 3
