crockeea


New scope in 'do' notation

I'm trying to write a recursive function that mutates a Data.Vector.Unboxed.Mutable 'MVector', though I think the question applies to any monadic code.

As a contrived example:

import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Data.Vector.Unboxed.Mutable (MVector)
import Control.Monad.ST
import Control.Monad.Primitive

f :: U.Vector Int -> U.Vector Int
f x = runST $ do
        y <- U.thaw x     -- mutable copy of x
        add1 y 0          -- increment every element in place
        U.freeze y        -- immutable result

-- Add 1 to each element of v, starting at index i.
add1 :: (PrimMonad m) => MVector (PrimState m) Int -> Int -> m ()
add1 v i | i == M.length v = return ()
add1 v i = do
     c <- M.unsafeRead v i
     M.unsafeWrite v i (c + 1)
     add1 v (i+1)

However, v is the same in every recursive call. I would like to remove v as a parameter and inline 'add1' into f, but for that I need 'y' to be in scope.

I can get one step closer by changing add1 (and keeping f the same) so that v is not passed in the recursion:

add1 :: (PrimMonad m) => MVector (PrimState m) Int -> m ()
add1 v = add1_ 0
    where len = M.length v
          add1_ i | i == len = return ()
          add1_ i = do
                x <- M.unsafeRead v i
                M.unsafeWrite v i (x + 1)
                add1_ (i+1)

What I would really like, though, is to inline add1 completely. Here's an attempt that doesn't quite compile yet:

f x = let len = U.length x
          y = U.thaw x
          add1 i | i == len = return ()
          add1 i = do
             y' <- y
             c <- M.unsafeRead y' i
             M.unsafeWrite y' i (c+1)
             add1 (i+1)
      in runST $ do
            add1 0
            y' <- y
            U.freeze y'

GHC errors:

couldn't match type 'm0' with 'ST s'
couldn't match type 's' with 'PrimState m0'

Errors aside, this isn't optimal yet: I don't want to have to write (y' <- y) in every do block (especially since add1 is recursive). I'd really like y' (the 'non-monadic' version of y) to simply be in scope. Is there any way to do this?
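One point about the attempt above is worth making explicit: y = U.thaw x does not name a vector, it names an action, so every (y' <- y) runs the thaw again and yields a new, independent copy of x. Even with the type errors fixed, writes made through one y' would not be visible through the next. The type errors themselves appear to come from y and add1 being defined outside the argument of runST, whose rank-2 type keeps the state thread s from escaping; putting the let inside the runST block lets it typecheck, as in the sketch below, but the copying problem remains. The function name thawTwice is hypothetical, chosen only for this illustration:

```haskell
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Control.Monad.ST (runST)

-- Hypothetical helper: runs the same thaw action twice, the way each
-- (y' <- y) in the attempt above would.
thawTwice :: U.Vector Int -> (U.Vector Int, U.Vector Int)
thawTwice x = runST $ do
  let y = U.thaw x      -- y is an action that makes a copy, not a vector
  a <- y                -- first run: one fresh mutable copy
  b <- y                -- second run: another, independent copy
  M.write a 0 100       -- write to the first copy only (x must be non-empty)
  fa <- U.freeze a
  fb <- U.freeze b
  return (fa, fb)
```

For an input like U.fromList [1,2,3], only the first returned vector starts with 100; the second is still equal to the input, because the two runs of y produced separate copies.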

(I apologize if I am horribly misusing monads in some way)


Answers (1)

Daniel Wagner

How about this?

f :: U.Vector Int -> U.Vector Int
f x = runST $ do
    y <- U.thaw x          -- y is the mutable copy, in scope from here on
    let add1 i | i == U.length x = return ()
               | otherwise       = do
            c <- M.unsafeRead y i
            M.unsafeWrite y i (c+1)
            add1 (i+1)
    add1 0
    U.freeze y
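The key point of the answer is that y, bound with <- inside the runST block, is a plain mutable vector for the rest of that block, so a loop defined after it can simply mention it. For comparison, here is a sketch of the same idea using U.modify from Data.Vector.Unboxed, which applies an ST action to a mutable version of the vector and returns the frozen result, so the explicit thaw/freeze pair disappears; the loop uses forM_ from Control.Monad instead of explicit recursion. This is an alternative, not the answerer's code, and addAll is a hypothetical helper name:

```haskell
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Control.Monad (forM_)
import Control.Monad.ST (ST)

-- Same result as f in the question: add 1 to every element.
-- U.modify runs addAll on a mutable version of the input (a copy,
-- or in place when that is safe) and returns the frozen result.
f :: U.Vector Int -> U.Vector Int
f = U.modify addAll

-- Hypothetical helper. v is in scope for the whole body, so the loop
-- never has to pass it along. The signature is polymorphic in s, which
-- is what U.modify's rank-2 argument type requires.
addAll :: M.MVector s Int -> ST s ()
addAll v = forM_ [0 .. M.length v - 1] $ \i -> do
  c <- M.unsafeRead v i       -- indices stay in [0, length v - 1]
  M.unsafeWrite v i (c + 1)
```

Here v plays the role of y: it is in scope throughout addAll, and an empty input simply produces an empty index list, so no bounds check is needed.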
