Abhijit Sarkar

Reputation: 24593

How to test Monad instance for custom StateT?

I'm learning monad transformers, and one of the exercises asks me to implement the Monad instance for StateT. I want to test that my implementation satisfies the Monad laws using the validity package, which is similar to the checkers package.

The problem is that my Arbitrary instance doesn't compile. I saw this question, but it doesn't quite do what I want: the test essentially duplicates the implementation and doesn't check the laws. There's also this question, but it's unanswered, and I've already figured out how to test monad transformers that don't wrap functions (like MaybeT).

{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE InstanceSigs #-}

module Ch11.MonadT (StT (..)) where

import Control.Monad.Trans.State (StateT (..))

newtype StT s m a = StT (s -> m (a, s))
  deriving
    (Functor, Applicative)
    via StateT s m

instance (Monad m) => Monad (StT s m) where
  return :: a -> StT s m a
  return = pure

  (>>=) :: StT s m a -> (a -> StT s m b) -> StT s m b
  StT x >>= f = StT $ \s -> do
    (k, s') <- x s
    let StT y = f k
    y s'

  (>>) :: StT s m a -> StT s m b -> StT s m b
  (>>) = (*>)
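Before testing the laws, it can help to run the `>>=` above by hand. The sketch below copies the question's definitions and runs two small computations, one in `Maybe` and one in the list monad. `runStT`, `tick` and `branch` are hypothetical helpers added only for illustration, and the block assumes the `transformers` package (which provides `StateT`) is available.

```haskell
{-# LANGUAGE DerivingVia #-}

module Main (main) where

import Control.Monad.Trans.State (StateT (..))

-- Same definition as in the question.
newtype StT s m a = StT (s -> m (a, s))
  deriving (Functor, Applicative) via StateT s m

instance (Monad m) => Monad (StT s m) where
  StT x >>= f = StT $ \s -> do
    (k, s') <- x s
    let StT y = f k
    y s'

-- Hypothetical helper: unwrap and run with an initial state.
runStT :: StT s m a -> s -> m (a, s)
runStT (StT g) = g

-- Hypothetical: return the current counter and increment it.
tick :: (Monad m) => StT Int m Int
tick = StT $ \n -> pure (n, n + 1)

-- Hypothetical: a nondeterministic step that adds 1 or 10 to the state.
branch :: StT Int [] ()
branch = StT $ \n -> [((), n + 1), ((), n + 10)]

main :: IO ()
main = do
  -- two ticks from 0 return 0 and 1, final state is 2: prints Just (1,2)
  print (runStT (tick >>= \a -> tick >>= \b -> pure (a + b)) 0 :: Maybe (Int, Int))
  -- each branch is threaded through tick separately: prints [(1,2),(10,11)]
  print (runStT (branch >> tick) 0)
```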

My test:

{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeApplications #-}

module Ch11.MonadTSpec (spec) where

import Ch11.MonadT (StT (..))
import Test.Hspec
import Test.QuickCheck
import Test.Validity.Monad

spec :: Spec
spec = do
  monadSpecOnArbitrary @(StTArbit Int [] Int)

-- create wrapper to avoid orphan instance error
newtype StTArbit s m a = StTArbit (StT s m a)
  deriving
    (Functor, Applicative, Monad)

instance (Arbitrary s, Function s, Arbitrary1 m, Arbitrary a) => Arbitrary (StTArbit s m a) where
  arbitrary = do
    f <- arbitrary :: Fun s (m (a, s))
    StTArbit . StT <$> f

Error:

• Couldn't match type: (a0, s0)
                 with: s -> m (a, s)
  Expected: Gen (s -> m (a, s))
    Actual: Gen (a0, s0)
• In the second argument of ‘(<$>)’, namely ‘f’
  In a stmt of a 'do' block: StTArbit . StT <$> f

Upvotes: 3

Views: 103

Answers (2)

Abhijit Sarkar

Reputation: 24593

OP here, this is what I ended up doing.

-- https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/explicit_forall.html
{-# LANGUAGE ExplicitForAll #-}
-- https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/type_applications.html
{-# LANGUAGE TypeApplications #-}

module Ch11.MonadTSpec (spec) where

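-- assumes Ch11.MonadT also exports runStT :: StT s m a -> s -> m (a, s) (not shown in the question)
import Ch11.MonadT (StT (..), runStT)
import qualified Data.Function as F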
import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck

spec :: Spec
spec = do
  describe "Monad (StT Int [])" $ do
    describe "satisfies Monad laws" $ do
      -- the types are in the same order as in `forall`
      prop "right identity law" (prop_monadRightId @Int @Int @[])
      prop "left identity law" (prop_monadLeftId @Int @Int @Int @[])
      prop "associative law" (prop_monadAssoc @Int @Int @Int @Int @[])

{- HLINT ignore -}

{-
With TypeApplications, `@` arguments fill the variables of an explicit
`forall` in the order they are written there. Each signature lists `m`
last, so the element types come first and `@[]` goes at the end.
-}

-- (x >>= return) == x
prop_monadRightId ::
  forall a s m.
  (Monad m, Eq (m (a, s)), Show (m (a, s))) =>
  s ->
  Fun s (m (a, s)) ->
  Property
prop_monadRightId s f = ((===) `F.on` go) (m >>= return) m
  where
    m = StT $ applyFun f
    go st = runStT st s

-- (return x >>= f) == (f x)
prop_monadLeftId ::
  forall a b s m.
  (Monad m, Eq (m (b, s)), Show (m (b, s))) =>
  a ->
  s ->
  Fun (a, s) (m (b, s)) ->
  Property
prop_monadLeftId a s f = ((===) `F.on` go) (return a >>= h) m
  where
    g = applyFun2 f
    m = StT $ g a
    h = StT . g
    go st = runStT st s

-- ((x >>= f) >>= g) == (x >>= (\x' -> f x' >>= g))
prop_monadAssoc ::
  forall a b c s m.
  (Monad m, Eq (m (b, s)), Show (m (b, s)), Eq (m (c, s)), Show (m (c, s))) =>
  s ->
  Fun s (m (a, s)) ->
  Fun (a, s) (m (b, s)) ->
  Fun (b, s) (m (c, s)) ->
  Property
prop_monadAssoc s h f g =
  ((===) `F.on` go)
    ((m >>= f') >>= g')
    (m >>= (\x -> f' x >>= g'))
  where
    m = StT $ applyFun h
    f' = StT . applyFun2 f
    g' = StT . applyFun2 g
    go st = runStT st s
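The `@Int @Int @[]` arguments in the spec are matched to type variables purely by their position in the explicit `forall`, so `prop_monadRightId @Int @Int @[]` sets `a = Int`, `s = Int` and `m = []`. A small self-contained sketch of that rule; `describeAB` and `describeBA` are hypothetical functions that differ only in the order of their `forall`:

```haskell
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE TypeApplications #-}

module Main (main) where

-- Same body, forall order a then b.
describeAB :: forall a b. (Show a, Show b) => a -> b -> String
describeAB x y = show x ++ " / " ++ show y

-- Same body, forall order b then a.
describeBA :: forall b a. (Show a, Show b) => a -> b -> String
describeBA x y = show x ++ " / " ++ show y

main :: IO ()
main = do
  -- first @ fills a (type of the first argument): prints 1 / 2.0
  putStrLn (describeAB @Int @Double 1 2)
  -- first @ fills b (type of the second argument): prints 1.0 / 2
  putStrLn (describeBA @Int @Double 1 2)
```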

Upvotes: 0

Daniel Wagner

Reputation: 153172

I think you want pure, not (<$>). (But I haven't checked with my local compiler, so I'm not sure.) You probably also have to turn your Fun into an actual function.

arbitrary = do
  f <- arbitrary
  pure (StTArbit . StT . applyFun $ f)
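Filled out into a complete instance, the snippet also needs a context that lets QuickCheck build a `Fun s (m (a, s))`: `Function s` and `CoArbitrary s` for the argument, and `Arbitrary (m (a, s))` for the result. The question's `Arbitrary1 m` does not provide that last constraint on its own. A sketch under those assumptions; because `Arbitrary (m (a, s))` is no smaller than the instance head, GHC's termination check should also ask for `UndecidableInstances`. `runStTArbit` is a hypothetical helper.

```haskell
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}

module Main (main) where

import Test.QuickCheck

newtype StT s m a = StT (s -> m (a, s))

newtype StTArbit s m a = StTArbit (StT s m a)

-- Fun s b needs Function s and CoArbitrary s for its argument and
-- Arbitrary b for its results; here b = m (a, s).
instance (Function s, CoArbitrary s, Arbitrary (m (a, s))) => Arbitrary (StTArbit s m a) where
  arbitrary = do
    f <- arbitrary -- f :: Fun s (m (a, s)), inferred from its use below
    pure (StTArbit . StT . applyFun $ f)

-- Hypothetical helper to run the wrapped computation.
runStTArbit :: StTArbit s m a -> s -> m (a, s)
runStTArbit (StTArbit (StT g)) = g

main :: IO ()
main = do
  x <- generate (arbitrary :: Gen (StTArbit Int Maybe Bool))
  print (runStTArbit x 0)
```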

I'd also point out that there's not much point in making a newtype here. I guess it avoids an orphan instance warning? But you defined the type you're writing an instance for yourself, presumably even in the same package, so an orphan instance seems pretty benign; if it lives in a separate cabal component that people can't depend on, like a test suite, even more so.

Upvotes: -1
