lukstafi
lukstafi

Reputation: 1895

Inheritance for functors

Please excuse the lengthy example:

(** Minimal monad interface over a two-parameter type.  The first type
    parameter is the same in every argument and result of [bind]; only the
    second one changes from ['a] to ['b]. *)
module type MONAD =
sig
  type ('s, 'x) t

  val return : 'x -> ('s, 'x) t

  val bind : ('s, 'x) t -> ('x -> ('s, 'y) t) -> ('s, 'y) t
end

(** Derived monadic combinators built on top of any [MONAD].

    The result re-exports everything in [Monad] (through [include]) and adds
    the alias [monad] for [Monad.t], so that the functor can later be sealed
    with a signature that only mentions [monad]. *)
module MonadOps (Monad : MONAD) = struct
  (* Brings [t], [return] and [bind] into scope and into the result, so no
     separate re-binding of [return] / [bind] is needed. *)
  include Monad

  type ('r, 'a) monad = ('r, 'a) t

  (** Identity function from [monad] to the underlying type. *)
  let run x = x

  (** Infix form of [bind]. *)
  let (>>=) a b = bind a b

  (** [foldM f a xs] threads the accumulator [a] through [f] for each
      element of [xs], left to right, sequencing the steps with [(>>=)]. *)
  let rec foldM f a = function
    | [] -> return a
    | x::xs -> f a x >>= fun a' -> foldM f a' xs

  (** [whenM p s] is [s] when [p] holds and [return ()] otherwise. *)
  let whenM p s = if p then s else return ()

  (* [lift] and [join] are written with [(>>=)] directly instead of the
     camlp4 [perform x <-- m; ...] notation, so the module parses with the
     plain OCaml compiler. *)

  (** [lift f m] binds the result of [m] and returns [f] applied to it. *)
  let lift f m = m >>= fun x -> return (f x)

  (** [join m] binds the result of [m], which is itself a computation, and
      continues with it. *)
  let join m = m >>= fun x -> x

  (** [(f >=> g) x] passes the result of [f x] on to [g]. *)
  let (>=>) f g = fun x -> f x >>= g
end

(** [MonadOps] sealed with an explicit functor signature: [monad] becomes
    abstract and [run] is the only value whose type mentions [M.t]. *)
module Monad : functor (M : MONAD) -> sig
  type ('s, 'a) monad

  val run : ('s, 'a) monad -> ('s, 'a) M.t

  val return : 'a -> ('s, 'a) monad
  val bind : ('s, 'a) monad -> ('a -> ('s, 'b) monad) -> ('s, 'b) monad
  val ( >>= ) : ('s, 'a) monad -> ('a -> ('s, 'b) monad) -> ('s, 'b) monad

  val foldM : ('a -> 'b -> ('s, 'a) monad) -> 'a -> 'b list -> ('s, 'a) monad
  val whenM : bool -> ('s, unit) monad -> ('s, unit) monad
  val lift : ('a -> 'b) -> ('s, 'a) monad -> ('s, 'b) monad
  val join : ('s, ('s, 'a) monad) monad -> ('s, 'a) monad
  val ( >=> ) :
    ('a -> ('s, 'b) monad) -> ('b -> ('s, 'c) monad) -> 'a -> ('s, 'c) monad
end = MonadOps

(** [MONAD] extended with a constant [mzero] and a binary operation
    [mplus], both over the same type [t]. *)
module type MONAD_PLUS =
sig
  include MONAD

  val mzero : ('s, 'x) t

  val mplus : ('s, 'x) t -> ('s, 'x) t -> ('s, 'x) t
end

(** [MonadOps] applied to a [MONAD_PLUS], plus the operations specific to
    [MONAD_PLUS] and the derived [fail], [(++)] and [guard]. *)
module MonadPlusOps (MonadPlus : MONAD_PLUS) = struct
  include MonadOps (MonadPlus)

  (* [mzero] and [mplus] are re-exported explicitly here. *)
  let mzero = MonadPlus.mzero
  let mplus = MonadPlus.mplus

  (** Alias for [mzero]. *)
  let fail = MonadPlus.mzero

  (** Infix form of [mplus]. *)
  let (++) = MonadPlus.mplus

  (** [guard p] is [return ()] when [p] holds and [fail] otherwise. *)
  let guard = function
    | true -> return ()
    | false -> fail
end

Is there a way to have MonadPlus analogous to Monad without excessive signature code duplication? Along the lines of (wrong solution):

(* First attempt, described by the author as a wrong solution: take the
   inferred signature of [MonadPlusOps (M)] and destructively substitute its
   type [t] with the [monad] type of the same functor application. *)
module MonadPlus = (MonadPlusOps : functor (M : MONAD_PLUS) -> sig
  include module type of MonadPlusOps (M)
    with type ('a, 'b) t := ('a, 'b) MonadPlusOps (M).monad
end)

or (does not type-check):

(* Second attempt, reported by the author not to type-check: start from the
   signature of [Monad(M)] and add the extra values such as [mzero]. *)
module MonadPlus = (MonadPlusOps : functor (M : MONAD_PLUS) -> sig
  include module type of Monad(M)
  val mzero : ('a, 'b) monad
  (* ... *)
end)

Edit: updated -- better final solution

(** Base monad signature: a two-parameter type [t] with [return] and
    [bind].  The first type parameter is left untouched by both values. *)
module type MONAD =
sig
  type ('st, 'v) t

  val return : 'v -> ('st, 'v) t

  val bind : ('st, 'v) t -> ('v -> ('st, 'w) t) -> ('st, 'w) t
end

(** Signature of the derived operations, stated purely in terms of an
    abstract type [monad].  [MONAD] is included with its type [t]
    destructively replaced by [monad], so [return] and [bind] are part of
    this signature while [t] is not. *)
module type MONAD_OPS =
sig
  type ('s, 'a) monad

  include MONAD with type ('s, 'a) t := ('s, 'a) monad

  val ( >>= ) : ('s, 'a) monad -> ('a -> ('s, 'b) monad) -> ('s, 'b) monad

  val foldM : ('a -> 'b -> ('s, 'a) monad) -> 'a -> 'b list -> ('s, 'a) monad

  val whenM : bool -> ('s, unit) monad -> ('s, unit) monad

  val lift : ('a -> 'b) -> ('s, 'a) monad -> ('s, 'b) monad

  val join : ('s, ('s, 'a) monad) monad -> ('s, 'a) monad

  val ( >=> ) :
    ('a -> ('s, 'b) monad) -> ('b -> ('s, 'c) monad) -> 'a -> ('s, 'c) monad
end

(** Derived combinators for any [MONAD].  [M] is only opened, not included,
    so the result contains [monad], [run] and the combinators below but not
    [M.t], [M.return] or [M.bind]. *)
module MonadOps (M : MONAD) = struct
  open M

  type ('s, 'a) monad = ('s, 'a) t

  (** Identity function from [monad] to the underlying [M.t]. *)
  let run x = x

  (** Infix form of [bind]. *)
  let (>>=) a b = bind a b

  (** [foldM f a xs] threads the accumulator [a] through [f] for each
      element of [xs], left to right, sequencing the steps with [(>>=)]. *)
  let rec foldM f a = function
    | [] -> return a
    | x::xs -> f a x >>= fun a' -> foldM f a' xs

  (** [whenM p s] is [s] when [p] holds and [return ()] otherwise. *)
  let whenM p s = if p then s else return ()

  (* [lift] and [join] use [(>>=)] directly rather than the camlp4
     [perform x <-- m; ...] notation, so the module parses with the plain
     OCaml compiler. *)

  (** [lift f m] binds the result of [m] and returns [f] applied to it. *)
  let lift f m = m >>= fun x -> return (f x)

  (** [join m] binds the result of [m], which is itself a computation, and
      continues with it. *)
  let join m = m >>= fun x -> x

  (** [(f >=> g) x] passes the result of [f x] on to [g]. *)
  let (>=>) f g = fun x -> f x >>= g
end

(** Sealed monad operations for [M].  The result signature is [MONAD_OPS]
    (with [monad] abstract) plus [run], the only value whose type mentions
    [M.t]. *)
module Monad (M : MONAD) :
sig
  include MONAD_OPS
  val run : ('s, 'a) monad -> ('s, 'a) M.t
end = struct
  (* [M] supplies [return] and [bind]; [MonadOps(M)] supplies [monad],
     [run] and the derived combinators. *)
  include M
  include MonadOps(M)
end

(** [MONAD] extended with a constant [mzero] and a binary operation
    [mplus] on [t]. *)
module type MONAD_PLUS =
sig
  include MONAD

  val mzero : ('st, 'v) t

  val mplus : ('st, 'v) t -> ('st, 'v) t -> ('st, 'v) t
end

(** [MONAD_OPS] extended with [mzero], [mplus] and the derived values
    [fail], [(++)] and [guard], all expressed over [monad]. *)
module type MONAD_PLUS_OPS =
sig
  include MONAD_OPS

  val mzero : ('st, 'v) monad

  val mplus : ('st, 'v) monad -> ('st, 'v) monad -> ('st, 'v) monad

  val fail : ('st, 'v) monad

  val (++) : ('st, 'v) monad -> ('st, 'v) monad -> ('st, 'v) monad

  val guard : bool -> ('st, unit) monad
end

(** Sealed operations for a [MONAD_PLUS]: the result signature is
    [MONAD_PLUS_OPS] (with [monad] abstract) plus [run], whose result type
    is [M.t]. *)
module MonadPlus (M : MONAD_PLUS) :
sig
  include MONAD_PLUS_OPS

  val run : ('s, 'a) monad -> ('s, 'a) M.t
end = struct
  (* [M] contributes [return], [bind], [mzero] and [mplus]. *)
  include M
  include MonadOps(M)

  (** Alias for [mzero]. *)
  let fail = M.mzero

  (** Infix form of [mplus]. *)
  let (++) left right = M.mplus left right

  (** [guard p] is [return ()] when [p] holds and [fail] otherwise. *)
  let guard = function
    | true -> M.return ()
    | false -> fail
end

Upvotes: 2

Views: 203

Answers (2)

gasche
gasche

Reputation: 31469

As a complement to Andreas' answer, I wished to show that you can use functors to produce signatures. I haven't exactly followed the discussion on which exact level of type abstraction you want, so this code is to be compared with Andreas' version.

(** Wrapper module whose only content is the module type [S], the basic
    monad signature ([t], [return], [bind]). *)
module MonadSig = struct
  module type S =
  sig
    type ('s, 'x) t

    val return : 'x -> ('s, 'x) t

    val bind : ('s, 'x) t -> ('x -> ('s, 'y) t) -> ('s, 'y) t
  end
end

(** Functor producing a signature: for a given [M], [MonadOpsSig(M).S]
    describes the derived operations with [monad] equal to [M.t]. *)
module MonadOpsSig (M : MonadSig.S) = struct
  module type S =
  sig
    type ('s, 'x) monad = ('s, 'x) M.t

    val run : ('s, 'x) monad -> ('s, 'x) monad

    val (>>=) : ('s, 'x) monad -> ('x -> ('s, 'y) monad) -> ('s, 'y) monad

    (* ... *)
  end
end

(** Derived operations for [M], constrained by the signature obtained from
    [MonadOpsSig(M)]. *)
module MonadOps (M : MonadSig.S) : MonadOpsSig(M).S = struct
  open M

  type ('r, 'a) monad = ('r, 'a) t

  (** Identity function. *)
  let run m = m

  (** Infix form of [bind]. *)
  let (>>=) = bind

  (** [foldM f a xs] threads the accumulator [a] through [f] for each
      element of [xs], left to right. *)
  let rec foldM step acc = function
    | [] -> return acc
    | hd :: tl -> step acc hd >>= fun acc' -> foldM step acc' tl

  (* ... *)
end

(** Wrapper module whose module type [S] is [MonadSig.S] extended with
    [mzero] and [mplus]. *)
module MonadPlusSig = struct
  module type S =
  sig
    include MonadSig.S

    val mzero : ('s, 'x) t

    val mplus : ('s, 'x) t -> ('s, 'x) t -> ('s, 'x) t
  end
end

(** Functor producing a signature: [MonadPlusOpsSig(Monad).S] is
    [MonadOpsSig(Monad).S] extended with [fail] and [(++)]. *)
module MonadPlusOpsSig (Monad : MonadPlusSig.S) = struct
  module type S =
  sig
    include MonadOpsSig(Monad).S

    val fail : ('s, 'x) monad

    val (++) : ('s, 'x) monad -> ('s, 'x) monad -> ('s, 'x) monad

    (* ... *)
  end
end

(** [MonadOps(M)] extended with [fail] and [(++)], constrained by the
    signature obtained from [MonadPlusOpsSig(M)]. *)
module MonadPlusOps (M : MonadPlusSig.S) : MonadPlusOpsSig(M).S = struct
  include MonadOps(M)

  (** Alias for [M.mzero]. *)
  let fail = M.mzero

  (** Alias for [M.mplus]. *)
  let (++) = M.mplus

  (* ... *)
end

The idea is that to provide a signature parametrized on something, you can either embed this signature into a parametrized functor (I'd call this the "functor style"), or define the parameters as abstract (but they're really inputs rather than outputs) and, at use site, equate them with the actual parameters (I'd call this the "mixin style"). I'm not saying the code above is better than Andreas', in fact I'd probably rather use his version, but it's interesting to compare them.

Upvotes: 2

Andreas Rossberg
Andreas Rossberg

Reputation: 36118

I'm not entirely sure what you are trying to achieve, but I would perhaps try to factor it as follows:

(** Basic monad signature: a two-parameter type [t] with [return] and
    [bind]; the first type parameter is unchanged by both. *)
module type MONAD = sig
  type ('s, 'x) t

  val return : 'x -> ('s, 'x) t

  val bind : ('s, 'x) t -> ('x -> ('s, 'y) t) -> ('s, 'y) t
end

(** Signature of the derived operations over an abstract type [monad]; it
    is meant to be specialised with [with type monad = ...] at use sites. *)
module type MONAD_OPS = sig
  type ('s, 'x) monad

  val run : ('s, 'x) monad -> ('s, 'x) monad

  val (>>=) : ('s, 'x) monad -> ('x -> ('s, 'y) monad) -> ('s, 'y) monad

  (* ... *)
end

(** Derived operations for [Monad].  The result signature exposes the
    [MONAD] values with [t] replaced by [Monad.t], and [MONAD_OPS] with
    [monad] equal to [Monad.t]. *)
module MonadOps (Monad : MONAD) :
sig
  include MONAD with type ('s, 'x) t := ('s, 'x) Monad.t
  include MONAD_OPS with type ('s, 'x) monad = ('s, 'x) Monad.t
end =
struct
  include Monad

  type ('r, 'a) monad = ('r, 'a) t

  (** Identity function. *)
  let run m = m

  (** Infix form of [bind]. *)
  let (>>=) = bind

  (** [foldM f a xs] threads the accumulator [a] through [f] for each
      element of [xs], left to right. *)
  let rec foldM step acc = function
    | [] -> return acc
    | hd :: tl -> step acc hd >>= fun acc' -> foldM step acc' tl

  (* ... *)
end

(** [MONAD] extended with a constant [mzero] and a binary operation
    [mplus] on [t]. *)
module type MONAD_PLUS =
sig
  include MONAD

  val mzero : ('s, 'x) t

  val mplus : ('s, 'x) t -> ('s, 'x) t -> ('s, 'x) t
end

(** [MONAD_OPS] extended with [fail] and [(++)] over [monad]. *)
module type MONAD_PLUS_OPS = sig
  include MONAD_OPS

  val fail : ('s, 'x) monad

  val (++) : ('s, 'x) monad -> ('s, 'x) monad -> ('s, 'x) monad

  (* ... *)
end

(** [MonadOps] extended with the [MONAD_PLUS] values and the derived [fail]
    and [(++)].  Both the substituted [t] and [monad] are tied to
    [MonadPlus.t], the type supplied by the functor argument. *)
module MonadPlusOps (MonadPlus : MONAD_PLUS) :
sig
  (* The type equations must name the functor parameter [MonadPlus]; no
     module called [Monad] is bound by this functor. *)
  include MONAD_PLUS with type ('a ,'b) t := ('a, 'b) MonadPlus.t
  include MONAD_PLUS_OPS with type ('a ,'b) monad = ('a, 'b) MonadPlus.t
end =
struct
  (* [MonadPlus] supplies [mzero] and [mplus]; [MonadOps (MonadPlus)]
     supplies [monad], [run], [(>>=)] and the other derived operations. *)
  include MonadPlus
  include MonadOps (MonadPlus)

  (** Alias for [mzero]. *)
  let fail = mzero

  (** Alias for [mplus]. *)
  let (++) = mplus

  (* ... *)
end

Upvotes: 2

Related Questions