Sassa NF

Reputation: 5406

Type Inference for declared instance

I came up with a nice exercise, but can't make it work.

The idea is to try and express Roman numerals in such a way that the type checker will tell me whether the numeral is valid.

    {-# LANGUAGE RankNTypes
               , MultiParamTypeClasses #-}

    data One a b c = One a deriving (Show, Eq)
    data Two a b c = Two (One a b c) (One a b c) deriving (Show, Eq)
    data Three a b c = Three (One a b c) (Two a b c) deriving (Show, Eq)
    data Four a b c = Four (One a b c) (Five a b c) deriving (Show, Eq)
    data Five a b c = Five b deriving (Show, Eq)
    data Six a b c = Six (Five a b c) (One a b c) deriving (Show, Eq)
    data Seven a b c = Seven (Five a b c) (Two a b c) deriving (Show, Eq)
    data Eight a b c = Eight (Five a b c) (Three a b c) deriving (Show, Eq)
    data Nine a b c d e = Nine (One a b c) (One c d e) deriving (Show, Eq)

    data Z = Z deriving (Show, Eq) -- dummy for the last level
    data I = I deriving (Show, Eq)
    data V = V deriving (Show, Eq)
    data X = X deriving (Show, Eq)
    data L = L deriving (Show, Eq)
    data C = C deriving (Show, Eq)
    data D = D deriving (Show, Eq)
    data M = M deriving (Show, Eq)

    i :: One I V X
    i = One I

    v :: Five I V X
    v = Five V

    x :: One X L C
    x = One X

    l :: Five X L C
    l = Five L

    c :: One C D M
    c = One C

    d :: Five C D M
    d = Five D

    m :: One M Z Z
    m = One M

    infixr 4 #

    class RomanJoiner a b c where
      (#) :: a -> b -> c

    instance RomanJoiner (One a b c) (One a b c) (Two a b c) where
      (#) = Two

    instance RomanJoiner (One a b c) (Two a b c) (Three a b c) where
      (#) = Three

    instance RomanJoiner (One a b c) (Five a b c) (Four a b c) where
      (#) = Four

    instance RomanJoiner (Five a b c) (One a b c) (Six a b c) where
      (#) = Six

    instance RomanJoiner (Five a b c) (Two a b c) (Seven a b c) where
      (#) = Seven

    instance RomanJoiner (Five a b c) (Three a b c) (Eight a b c) where
      (#) = Eight

    instance RomanJoiner (One a b c) (One c d e) (Nine a b c d e) where
      (#) = Nine

    main = print $ v # i # i

This could probably be done differently, and the solution is incomplete, but right now I need to understand why GHC complains that there is no instance for `RomanJoiner (One I V X) (One I V X) b0`, when I think I declared exactly such a joiner.
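One way to confirm that the declared instances themselves are not the problem is to spell out every result type by hand. The sketch below is a cut-down copy of the code above, keeping only the types and the two instances that `v # i # i` needs, with explicit annotations added. `seven` is a hypothetical name, and `FlexibleInstances` is added because instance heads such as `(One a b c) (One a b c)` repeat type variables.

```haskell
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}

data One a b c = One a deriving (Show, Eq)
data Two a b c = Two (One a b c) (One a b c) deriving (Show, Eq)
data Five a b c = Five b deriving (Show, Eq)
data Seven a b c = Seven (Five a b c) (Two a b c) deriving (Show, Eq)

data I = I deriving (Show, Eq)
data V = V deriving (Show, Eq)
data X = X deriving (Show, Eq)

i :: One I V X
i = One I

v :: Five I V X
v = Five V

infixr 4 #

class RomanJoiner a b c where
  (#) :: a -> b -> c

instance RomanJoiner (One a b c) (One a b c) (Two a b c) where
  (#) = Two

instance RomanJoiner (Five a b c) (Two a b c) (Seven a b c) where
  (#) = Seven

-- Every intermediate result type is written out, so GHC never has to
-- guess the third class parameter and each constraint matches an instance.
seven :: Seven I V X
seven = v # (i # i :: Two I V X)

main :: IO ()
main = print seven
```

Dropping the `:: Two I V X` annotation brings back the ambiguous `b0`: nothing else tells GHC what `i # i` should produce.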

Upvotes: 1

Views: 148

Answers (1)

aavogt

Reputation: 1308

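The issue is that GHC does not pick an instance just because it is the only one that could work. In `v # i # i` (which parses as `v # (i # i)`), nothing fixes the result type of `i # i`, so the constraint is `RomanJoiner (One I V X) (One I V X) b0` with an unknown `b0`, and no instance head matches it. The FunctionalDependencies extension gives type inference more to work with: enabling it and writing `| a b -> c` in the class declaration says that the type of `a # b` is determined by the types of `a` and `b`. Unfortunately, that is not the only thing you have to do, because you would then get the error "Functional dependencies conflict between instance declarations": the `Two` and `Nine` instances both apply to a pair of `One`s. Using some classes defined in HList (these could be defined anywhere else), the two conflicting instances can be combined into a single one, where the two possible results (three if you count an error) are chosen based on whether certain types are equal.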
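The effect of `| a b -> c` can be seen on a toy version with plain types. In this minimal sketch, `One`, `Two`, `Five` and `Seven` are hypothetical parameterless stand-ins for the real ones, and `Joiner` is a hypothetical class name; `Five # One # One` type-checks with no annotations because each result type follows from the argument types.

```haskell
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

-- Toy, parameterless stand-ins for the numeral types.
data One = One deriving Show
data Two = Two One One deriving Show
data Five = Five deriving Show
data Seven = Seven Five Two deriving Show

infixr 4 #

-- "a b -> c": the argument types determine the result type.
class Joiner a b c | a b -> c where
  (#) :: a -> b -> c

instance Joiner One One Two where
  (#) = Two

instance Joiner Five Two Seven where
  (#) = Seven

-- No annotation needed: One # One is improved to Two, then Five # Two to Seven.
main :: IO ()
main = print (Five # One # One)
```

Adding a second instance such as `Joiner One One Nine` (for some hypothetical `Nine` type) to this sketch would be rejected with the functional-dependency conflict described above, which is why the two `One # One` cases have to be merged into one instance.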

A couple of comments about why this solution is ugly:

  1. You wouldn't have to replicate at the value level what is already happening at the type level (`hCond` vs `HCond`) if you had lazier Show instances (like `instance Show I where show _ = "I"`).
  2. With the more modern TypeFamilies extension, many of the intermediate type variables (`ba`, `bb`, `bc`, `babc`, ...) can be eliminated.
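A minimal sketch of the lazier `Show` instances the first comment has in mind, assuming `ScopedTypeVariables`. The instances never inspect their argument, so a value that exists only as a type can still be printed; `result` is a hypothetical stand-in for the `z` computed in the instance context. Presumably, with instances like these, the combined instance could simply return `undefined :: z` instead of rebuilding the value with `hCond`.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Tag types with no constructors: they only exist at the type level.
data I
data V
data X

-- "Lazy" Show instances: the argument is ignored, so undefined is fine.
instance Show I where show _ = "I"
instance Show V where show _ = "V"
instance Show X where show _ = "X"

data One a b c = One a

-- Prints the digit from the type alone; the value is never forced.
instance Show a => Show (One a b c) where
  show _ = "One " ++ show (undefined :: a)

-- Hypothetical result of type-level computation: only its type matters.
result :: One I V X
result = undefined

main :: IO ()
main = print result
```

Because `show` here depends only on the type, the value-level `hCond` that mirrors `HCond` would no longer be needed just to get something printable.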

    {-# LANGUAGE RankNTypes, MultiParamTypeClasses, FunctionalDependencies, ScopedTypeVariables, UndecidableInstances, FlexibleContexts, FlexibleInstances #-}
    import Data.HList hiding ((#))
    import Data.HList.TypeEqGeneric1
    import Data.HList.TypeCastGeneric1
    import Unsafe.Coerce
    
    data One a b c = One a deriving (Show, Eq)
    data Two a b c = Two (One a b c) (One a b c) deriving (Show, Eq)
    data Three a b c = Three (One a b c) (Two a b c) deriving (Show, Eq)
    data Four a b c = Four (One a b c) (Five a b c) deriving (Show, Eq)
    data Five a b c = Five b deriving (Show, Eq)
    data Six a b c = Six (Five a b c) (One a b c) deriving (Show, Eq)
    data Seven a b c = Seven (Five a b c) (Two a b c) deriving (Show, Eq)
    data Eight a b c = Eight (Five a b c) (Three a b c) deriving (Show, Eq)
    data Nine a b c d e = Nine (One a b c) (One c d e) deriving (Show, Eq)
    
    data Z = Z deriving (Show, Eq) -- dummy for the last level
    data I = I deriving (Show, Eq)
    data V = V deriving (Show, Eq)
    data X = X deriving (Show, Eq)
    data L = L deriving (Show, Eq)
    data C = C deriving (Show, Eq)
    data D = D deriving (Show, Eq)
    data M = M deriving (Show, Eq)
    
    i :: One I V X
    i = One I
    
    v :: Five I V X
    v = Five V
    
    x :: One X L C
    x = One X
    
    l :: Five X L C
    l = Five L
    
    c :: One C D M
    c = One C
    
    d :: Five C D M
    d = Five D
    
    m :: One M Z Z
    m = One M
    
    infixr 4 #
    
    class RomanJoiner a b c | a b -> c where
        (#) :: a -> b -> c
    
    
    instance RomanJoiner (One a b c) (Two a b c) (Three a b c) where
        (#) = Three
    
    instance RomanJoiner (One a b c) (Five a b c) (Four a b c) where
        (#) = Four
    
    instance RomanJoiner (Five a b c) (One a b c) (Six a b c) where
        (#) = Six
    
    instance RomanJoiner (Five a b c) (Two a b c) (Seven a b c) where
        (#) = Seven
    
    instance RomanJoiner (Five a b c) (Three a b c) (Eight a b c) where
        (#) = Eight
    
    data Error = Error
    instance forall a b c a' b' c' ba bb bc bab babc z bn nine.
       (TypeEq a a' ba,
        TypeEq b b' bb,
        TypeEq c c' bc,
        HAnd ba bb bab,
        HAnd bab bc babc,
    
        TypeEq c a' bn,
        HCond bn (Nine a b c b' c') Error nine,
    
        HCond babc (Two a b c) nine  z) =>
            RomanJoiner (One a b c) (One a' b' c') z where
        (#) x y = hCond (undefined :: babc)
                    (Two (uc x :: One a b c) (uc y :: One a b c)) $
                  hCond (undefined :: bn)
                    (Nine (uc x :: One a b c) (uc y :: One c b' c'))
                    Error
            where uc = unsafeCoerce
    
    main = print $ v # i # i
    {-
    Prints, with GHC 7.6.2 and HList-0.2.3:
    
    *Main> main
    Seven (Five V) (Two (One I) (One I))
    
    -}
    

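For the second comment, here is one possible shape of the combined `One # One` instance using `TypeFamilies`, assuming GHC 7.8 or later (closed type families, `DataKinds` and `Data.Proxy`, all newer than the GHC 7.6.2 used above). The names `SameLevel`, `JoinOnes` and `joinOnes` are hypothetical. A single closed type family replaces the `TypeEq`/`HAnd`/`HCond` chain, and a small helper class selects `Two` or `Nine` from its result. Only the `One # One` case is shown; the other instances from the answer do not overlap with it and could be kept as they are.

```haskell
{-# LANGUAGE TypeFamilies, DataKinds, KindSignatures, MultiParamTypeClasses,
             FunctionalDependencies, FlexibleInstances, FlexibleContexts,
             UndecidableInstances, ScopedTypeVariables #-}
-- Assumes GHC 7.8 or later.
import Data.Proxy (Proxy (..))

data One a b c = One a deriving Show
data Two a b c = Two (One a b c) (One a b c) deriving Show
data Nine a b c d e = Nine (One a b c) (One c d e) deriving Show

data I = I deriving Show
data V = V deriving Show
data X = X deriving Show
data L = L deriving Show
data C = C deriving Show

i :: One I V X
i = One I

x :: One X L C
x = One X

infixr 4 #

class RomanJoiner a b c | a b -> c where
  (#) :: a -> b -> c

-- Replaces TypeEq/HAnd: 'True when both digits are on the same level.
type family SameLevel x y :: Bool where
  SameLevel (One a b c) (One a b c) = 'True
  SameLevel x y = 'False

-- Helper class that picks the result type from the computed flag.
class JoinOnes (same :: Bool) x y r | same x y -> r where
  joinOnes :: Proxy same -> x -> y -> r

instance JoinOnes 'True (One a b c) (One a b c) (Two a b c) where
  joinOnes _ = Two

instance JoinOnes 'False (One a b c) (One c d e) (Nine a b c d e) where
  joinOnes _ = Nine

-- The single combined instance; r is fixed through JoinOnes' dependency.
instance JoinOnes (SameLevel (One a b c) (One a' b' c')) (One a b c) (One a' b' c') r
      => RomanJoiner (One a b c) (One a' b' c') r where
  p # q = joinOnes (Proxy :: Proxy (SameLevel (One a b c) (One a' b' c'))) p q

main :: IO ()
main = do
  print (i # i)
  print (i # x)
```

Unlike the `Error` value in the answer's version, an invalid pair in this sketch, such as joining `One I` with the question's `c :: One C D M`, matches neither `JoinOnes` instance and should therefore be rejected at compile time.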
Upvotes: 4
