Dan T

Reputation: 63

Working with the State monad in Haskell

I have been learning Haskell little by little and am (slowly) working on understanding the State monad. I am attempting to write a function that repeats a State computation until the state meets some boolean test, collecting the returned values in a list as the overall result. I finally succeeded with this:

import Control.Monad (liftM)
import Control.Monad.State

collectUntil :: (s -> Bool) -> State s a -> State s [a]
collectUntil f s = do s0 <- get
                      let (a,s') = runState s s0
                      put s'
                      if (f s') then return [a] else liftM (a:) $ collectUntil f s

so that

simpleState :: State Int Int
simpleState = state (\x -> (x,x+1))

*Main> evalState (collectUntil (>10) simpleState) 0
[0,1,2,3,4,5,6,7,8,9,10]

Is this a reasonable function for this task, or is there a more idiomatic way?

Upvotes: 6

Views: 3234

Answers (3)

Chris Taylor

Reputation: 47402

You are making exactly the same mistakes I made when I first started writing monadic code: making it far too complicated, overusing liftM, and underusing >>= (equivalently, underusing the <- notation).

Ideally, you shouldn't have to mention runState or evalState inside the state monad at all. The functionality you want is as follows:

  • Read the current state
  • If it satisfies the predicate f, then return an empty list
  • If not, then run the computation comp and add its result to the output

You can do this quite directly as:

import Control.Monad.State

collectUntil f comp = do
    s <- get                              -- Get the current state
    if f s then return []                 -- If it satisfies the predicate, return []
           else do                        -- Otherwise...
               x  <- comp                 -- Perform the computation comp
               xs <- collectUntil f comp  -- Perform the rest of the computation
               return (x:xs)              -- Collect the results and return them

Note that you can nest do blocks! This is very useful: it lets you branch within one do block, as long as both branches of the if expression have the same monadic type.
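To make the "underusing >>=" point concrete, here is a sketch of the same function with the do blocks written out as explicit >>= chains, roughly what do-notation desugars to. The name collectUntil' is hypothetical, and the code assumes the mtl package's Control.Monad.State.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.State

-- Hypothetical name: the same logic as collectUntil, without do-notation.
collectUntil' :: MonadState t m => (t -> Bool) -> m a -> m [a]
collectUntil' f comp =
  get >>= \s ->                          -- read the current state
    if f s
      then return []                     -- predicate holds: stop
      else comp >>= \x ->                -- run the computation once
             collectUntil' f comp >>= \xs ->
               return (x : xs)           -- prepend this result to the rest

simpleState :: State Int Int
simpleState = state (\x -> (x, x + 1))

main :: IO ()
main = print (evalState (collectUntil' (> 10) simpleState) 0)
```

Each `<-` in the do version becomes a `>>= \name ->`, and the nested do after `else` is simply another chain inside the else branch.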

The inferred type for this function is:

collectUntil :: MonadState t m => (t -> Bool) -> m a -> m [a]

If you wish, you can specialise that to the State s type, although you don't have to:

collectUntil :: (s -> Bool) -> State s a -> State s [a]

It might even be preferable to keep the more general type, in case you want to use a different monad later.
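As a sketch of why the general type pays off, the same definition can run unchanged in a transformer stack. The example below assumes mtl and uses a hypothetical computation tickIO in StateT Int IO that prints the state before incrementing it.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State

-- The general version from above, with its inferred signature written out.
collectUntil :: MonadState t m => (t -> Bool) -> m a -> m [a]
collectUntil f comp = do
  s <- get
  if f s
    then return []
    else do
      x <- comp
      xs <- collectUntil f comp
      return (x : xs)

-- Hypothetical computation in StateT Int IO: print the state, then increment it.
tickIO :: StateT Int IO Int
tickIO = do
  n <- get
  liftIO (putStrLn ("state = " ++ show n))
  put (n + 1)
  return n

main :: IO ()
main = do
  xs <- evalStateT (collectUntil (> 3) tickIO) 0
  print xs
```

Running main prints the lines state = 0 through state = 3 and then [0,1,2,3]; nothing in collectUntil had to change to accommodate IO.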

What's the intuition?

Whenever s is a stateful computation and you are inside the state monad, you can do

x <- s

and x will now hold the result of the computation (as if you had called runState on s with the current state, kept the result, and made the returned state the new current state). If you ever need to check the state, you can do

s' <- get

and s' will have the value of the current state.
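A small sketch of both points, assuming mtl and the simpleState from the question: each `x <- simpleState` sees the state left by the previous step, and `get` reads whatever the state is at that moment. The name demo is hypothetical.

```haskell
import Control.Monad.State

simpleState :: State Int Int
simpleState = state (\x -> (x, x + 1))

-- Hypothetical demo: run simpleState twice, reading the state after each run.
demo :: State Int (Int, Int, Int, Int)
demo = do
  x1 <- simpleState   -- result of the first run (state was 0)
  s1 <- get           -- state after the first run
  x2 <- simpleState   -- the second run sees the updated state
  s2 <- get           -- state after the second run
  return (x1, s1, x2, s2)

main :: IO ()
main = print (evalState demo 0)   -- prints (0,1,1,2)
```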

Upvotes: 10

Riccardo T.

Reputation: 8937

For such a simple task I would not use the State monad. The others have already explained how you should write the monadic version, but since you asked for the most idiomatic way, I would like to add my personal (simpler) solution.

collectWhile, collectUntil :: (a -> a) -> (a -> Bool) -> a -> [a]
collectWhile f cond z = takeWhile cond $ iterate f z
collectUntil f cond z = collectWhile f (not . cond) z

Alternatively, if you only want collectUntil, this single line is enough:

collectUntil f cond z = takeWhile (not . cond) $ iterate f z

Here takeWhile and iterate come from the Prelude. For completeness, since it is the core of the implementation, here is the (very simple) definition of iterate:

iterate :: (a -> a) -> a -> [a]
iterate f x = x : iterate f (f x)
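A quick sketch showing the definitions above in use. Because iterate produces an infinite list lazily, takeWhile (or take) only forces the prefix it needs, which is why this pure version terminates. The example inputs are illustrative.

```haskell
-- The answer's two definitions, unchanged apart from spacing.
collectWhile, collectUntil :: (a -> a) -> (a -> Bool) -> a -> [a]
collectWhile f cond z = takeWhile cond $ iterate f z
collectUntil f cond z = collectWhile f (not . cond) z

main :: IO ()
main = do
  -- Same list the State version produced for (> 10) starting from 0.
  print (collectUntil (+ 1) (> 10) 0)    -- [0,1,2,3,4,5,6,7,8,9,10]
  -- iterate is lazy: only the demanded prefix of the infinite list is built.
  print (take 5 (iterate (* 2) 1))       -- [1,2,4,8,16]
  print (collectWhile (* 2) (< 100) 1)   -- [1,2,4,8,16,32,64]
```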

Warning: this probably wasn't clear enough from my answer, but this solution isn't really the same, since by working outside State I fuse the state and the result together. Of course, one can do something very similar by using f :: (s, a) -> (s, a) and then projecting with map fst or map snd to get, respectively, the list of intermediate states or results. At that point, though, the State solution may give simpler notation.
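A minimal sketch of the pair-based variant described above. The names step and pairsUntil, and the choice of "counter as state, its square as result", are hypothetical illustrations, not part of the original answer.

```haskell
-- Hypothetical step on (state, result) pairs: the state is a counter and
-- the result stored alongside it is the square of that counter.
step :: (Int, Int) -> (Int, Int)
step (n, _) = (n + 1, (n + 1) * (n + 1))

-- Iterate until the state component satisfies the predicate.
pairsUntil :: ((s, a) -> (s, a)) -> (s -> Bool) -> (s, a) -> [(s, a)]
pairsUntil f cond z = takeWhile (not . cond . fst) (iterate f z)

main :: IO ()
main = do
  let ps = pairsUntil step (> 4) (0, 0)
  print (map fst ps)   -- intermediate states:  [0,1,2,3,4]
  print (map snd ps)   -- intermediate results: [0,1,4,9,16]
```

The predicate only looks at the state half (via fst), so results and states stay separate even though they travel together through iterate.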

Upvotes: 2

Heatsink

Reputation: 7761

Most monads come with a few primitive "run" operations such as runState, execState, and so forth. If you are frequently calling runState inside the state monad, it means you are not really using the functionality the monad provides. You have written

s0 <- get                    -- Read state
let (a,s') = runState s s0   -- Pass state to 's', get new state
put s'                       -- Save new state

You do not have to explicitly pass the state around. This is what the state monad does! You can just write

a <- s
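A sketch confirming that the manual threading and plain `a <- s` behave the same, assuming mtl and the question's simpleState. The helper names explicit and viaBind are hypothetical. The last two lines also show what the "run" functions mentioned above return: runState gives the (result, state) pair, evalState only the result, execState only the final state.

```haskell
import Control.Monad.State

simpleState :: State Int Int
simpleState = state (\x -> (x, x + 1))

-- Hypothetical helper: the question's manual threading, wrapped as a function.
explicit :: State s a -> State s a
explicit s = do
  s0 <- get                      -- read the current state
  let (a, s') = runState s s0    -- run s by hand on that state
  put s'                         -- store the new state
  return a

-- Hypothetical helper: the same thing, letting the monad thread the state.
viaBind :: State s a -> State s a
viaBind s = do
  a <- s
  return a

main :: IO ()
main = do
  print (runState (explicit simpleState) 5)   -- (5,6)
  print (runState (viaBind simpleState) 5)    -- (5,6)
  print (evalState simpleState 5)             -- 5: result only
  print (execState simpleState 5)             -- 6: final state only
```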

Otherwise, the function looks reasonable. Since a is part of the result in both branches of the if, I would suggest factoring it out for clarity:

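import Control.Monad (liftM)
import Control.Monad.State

collectUntil :: (s -> Bool) -> State s a -> State s [a]
collectUntil f s = step
  where
    step = do a <- s                  -- run the computation once
              liftM (a:) continue     -- prepend its result to the rest
    continue = do s' <- get           -- then inspect the new state
                  if f s' then return [] else step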
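One observable difference between this version (like the question's) and the check-first version in the other answer is worth sketching: this one always runs the computation at least once before testing the predicate. The names collectUntilH and collectUntilC are hypothetical labels for the two variants, and mtl is assumed.

```haskell
import Control.Monad (liftM)
import Control.Monad.State

-- Run first, then check (this answer's structure).
collectUntilH :: (s -> Bool) -> State s a -> State s [a]
collectUntilH f s = step
  where
    step = do a <- s
              liftM (a:) continue
    continue = do s' <- get
                  if f s' then return [] else step

-- Check first, then run (the other answer's structure).
collectUntilC :: (s -> Bool) -> State s a -> State s [a]
collectUntilC f comp = do
  s <- get
  if f s
    then return []
    else do
      x <- comp
      xs <- collectUntilC f comp
      return (x : xs)

simpleState :: State Int Int
simpleState = state (\x -> (x, x + 1))

main :: IO ()
main = do
  print (evalState (collectUntilH (> 10) simpleState) 0)   -- [0..10]
  print (evalState (collectUntilC (> 10) simpleState) 0)   -- [0..10]
  -- When the initial state already satisfies the predicate, they differ:
  print (evalState (collectUntilH (> 10) simpleState) 20)  -- [20]
  print (evalState (collectUntilC (> 10) simpleState) 20)  -- []
```

Either behaviour can be right; which one you want depends on whether the predicate should guard the very first run.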

Upvotes: 4
