I have a type something like this:
data MyType = I Int | C Char -- and lots of other options
I want to be able to find out whether a value of this type is a specific variant. I could define functions isInt, isChar, and so on using pattern matching, but I'd rather write just one function, something like this:
hasBaseType :: MyType -> (a -> MyType) -> Bool
hasBaseType (f _) = True
hasBaseType _ = False
I would pass the appropriate constructor (I or C) as the second parameter. Unfortunately, you can't pattern match like that.
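For comparison, here is roughly what the per-constructor predicates look like when written with ordinary pattern matching. This is a minimal sketch that assumes a two-constructor MyType (the real type has more variants). The I{} pattern matches a constructor no matter how many fields it has, so each predicate stays one line even for larger constructors.

```haskell
-- Assumed two-constructor version of MyType; the real type has more variants.
data MyType = I Int | C Char

-- One predicate per constructor. The record-wildcard-style pattern I{}
-- matches the constructor regardless of its arity or field types.
isInt :: MyType -> Bool
isInt I{} = True
isInt _   = False

isChar :: MyType -> Bool
isChar C{} = True
isChar _   = False

main :: IO ()
main = do
  print (isInt (I 3))   -- True
  print (isChar (I 3))  -- False
```

The boilerplate grows linearly with the number of constructors, which is exactly what the single hasBaseType function is meant to avoid.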
I also want to "unwrap" the value. I could write the functions unwrapInt, unwrapChar, and so on, again using pattern matching, but I'd rather write just one function, something like this:
unwrap :: MyType -> (a -> MyType) -> a
unwrap (f x) = x
unwrap _ = error "wrong base type"
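The hand-written unwrappers would be similar. This sketch again assumes only the two constructors I and C and reuses the "wrong base type" error from the pseudocode above. Note that each function has a different result type, which is why a single unwrap needs something like a type parameter a for the field.

```haskell
-- Assumed two-constructor version of MyType.
data MyType = I Int | C Char

-- Each unwrapper extracts the field of one constructor and fails otherwise.
unwrapInt :: MyType -> Int
unwrapInt (I x) = x
unwrapInt _     = error "wrong base type"

unwrapChar :: MyType -> Char
unwrapChar (C x) = x
unwrapChar _     = error "wrong base type"

main :: IO ()
main = do
  print (unwrapInt (I 10))    -- 10
  print (unwrapChar (C 'a'))  -- 'a'
```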
Is there some fancy type magic that would allow me to do this? I thought PatternSynonyms might help here, but I couldn't figure out how.
I think you'll find these functions unwieldy in practice, but this can be accomplished with generics. Using Data.Data seems easiest. For hasBaseType, all you need to do is use the supplied pattern to construct a skeleton value of MyType (i.e., using undefined as the field) and compare the constructors:
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data
data MyType = I Int | C Char deriving (Data)
hasBaseType :: MyType -> (a -> MyType) -> Bool
hasBaseType val pat = toConstr val == toConstr (pat undefined)
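To see why passing undefined is harmless, the sketch below prints the constructor of the skeleton value directly. The derived toConstr only looks at the outermost constructor, so the undefined field is never evaluated. The use of showConstr here is just for illustration and is not part of the answer's code.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data

data MyType = I Int | C Char deriving (Data)

hasBaseType :: MyType -> (a -> MyType) -> Bool
hasBaseType val pat = toConstr val == toConstr (pat undefined)

main :: IO ()
main = do
  -- toConstr inspects only the constructor tag, so the undefined
  -- field inside the skeleton value I undefined is never forced.
  putStrLn (showConstr (toConstr (I undefined)))  -- I
  print (hasBaseType (I 10) I)   -- True
  print (hasBaseType (C 'a') I)  -- False
```

One caveat: this relies on the fields being lazy. If a constructor were declared with a strict field (e.g. I !Int), evaluating the skeleton value to find its constructor would also force undefined and throw.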
The unwrap function is a little trickier, but you can query the first field of the constructor and cast it. The fromJust is safe here because hasBaseType has already checked that val was built with the same constructor as pat, so its first field really does have type a:
import Data.Maybe
unwrap :: (Typeable a) => MyType -> (a -> MyType) -> a
unwrap val pat
| hasBaseType val pat = gmapQi 0 (fromJust . cast) val
| otherwise = error "wrong base type"
The full code:
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data
import Data.Maybe
data MyType = I Int | C Char deriving (Data)
hasBaseType :: MyType -> (a -> MyType) -> Bool
hasBaseType val pat = toConstr val == toConstr (pat undefined)
unwrap :: (Typeable a) => MyType -> (a -> MyType) -> a
unwrap val pat
| hasBaseType val pat = gmapQi 0 (fromJust . cast) val
| otherwise = error "wrong base type"
main :: IO ()
main = do
print $ unwrap (C 'a') C -- 'a'
print $ unwrap (I 10) I -- 10
print $ unwrap (I 10) C -- throws "wrong base type" error
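If you would rather not call error, the same query can return a Maybe instead, because cast already produces Maybe a. The name unwrapMaybe below is hypothetical; this is a sketch of one possible variant, not part of the original answer.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data

data MyType = I Int | C Char deriving (Data)

hasBaseType :: MyType -> (a -> MyType) -> Bool
hasBaseType val pat = toConstr val == toConstr (pat undefined)

-- Hypothetical total variant of unwrap: Nothing instead of an error.
unwrapMaybe :: Typeable a => MyType -> (a -> MyType) -> Maybe a
unwrapMaybe val pat
  | hasBaseType val pat = gmapQi 0 cast val  -- cast the first field to a
  | otherwise           = Nothing

main :: IO ()
main = do
  print (unwrapMaybe (I 10) I)   -- Just 10
  print (unwrapMaybe (I 10) C)   -- Nothing
```

The hasBaseType guard is still needed even though cast can fail on its own: if two constructors had fields of the same type, cast alone would happily return the wrong constructor's field, and for a constructor with no fields gmapQi 0 would fail outright.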