Reputation: 373
I am trying to find a way to pick a server handler function in a Servant API specification by specifying its URL type. This is different from Servant.Utils.Links, in that I don't want the link in text form; instead I want to select a handler function by a type-level link.
So I have the API and an endpoint in the API (similar to Servant.Utils.Links). Now I want to "walk" through the API, picking out the server handler function that matches the endpoint. This is what I came up with:
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
module Gonimo.GetEndpoint where
import GHC.TypeLits
import Servant.API
import Servant.Utils.Links
import Data.Proxy
import Servant.Server
-- | Pick the handler for @endpoint@ out of the handler tree for @api@,
-- for an arbitrary handler monad @m@.
class GetEndpoint api endpoint where
  getEndpoint
    :: Proxy m          -- ^ handler monad
    -> Proxy api        -- ^ API being searched
    -> Proxy endpoint   -- ^ endpoint to extract
    -> ServerT api m    -- ^ handlers for the whole API
    -> ServerT endpoint m
-- Left branch: look for the endpoint in the handlers of @b1@.
-- NOTE(review): this snippet does not enable ScopedTypeVariables, so the
-- @b1@ in the body below is a fresh type variable, not the instance head's.
instance GetEndpoint b1 endpoint => GetEndpoint (b1 :<|> b2) endpoint where
  getEndpoint monadP _ endpointP (leftServer :<|> _) =
    getEndpoint monadP (Proxy :: Proxy b1) endpointP leftServer
-- Right branch: look for the endpoint in the handlers of @b2@.
--
-- This instance has exactly the same head, @GetEndpoint (b1 :<|> b2) endpoint@,
-- as the left-branch instance. GHC selects instances by their head alone and
-- never consults the context (@GetEndpoint b2 endpoint@ here), so the two are
-- reported as duplicates. The choice has to be moved into the instance head,
-- e.g. via an extra type-level Bool parameter.
instance GetEndpoint b2 endpoint => GetEndpoint (b1 :<|> b2) endpoint where
  getEndpoint monadP _ endpointP (_ :<|> rightServer) =
    -- The API proxy must describe the server we recurse into, which is the
    -- right-hand one (@ServerT b2 m@), so pass @Proxy b2@, not @Proxy b1@.
    -- Like the left instance, this needs ScopedTypeVariables for @b2@ to
    -- refer to the instance head's variable.
    getEndpoint monadP (Proxy :: Proxy b2) endpointP rightServer
but GHC complains about duplicate instances:
Duplicate instance declarations:
instance forall (k :: BOX) b1 b2 (endpoint :: k).
GetEndpoint b1 endpoint =>
GetEndpoint (b1 :<|> b2) endpoint
-- Defined at src/Gonimo/GetEndpoint.hs:22:10
instance forall (k :: BOX) b1 b2 (endpoint :: k).
GetEndpoint b2 endpoint =>
GetEndpoint (b1 :<|> b2) endpoint
-- Defined at src/Gonimo/GetEndpoint.hs:26:10
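I partly understand why: both instances have the same head, `GetEndpoint (b1 :<|> b2) endpoint`, and GHC selects instances by matching the head only. The contexts `GetEndpoint b1 endpoint` and `GetEndpoint b2 endpoint` are not consulted when choosing. But how else should I pick the left or the right route of `:<|>` at the type level?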
Thanks for any pointers!
Upvotes: 1
Views: 132
Reputation: 373
Thank you user2407038 that did the trick, the following code actually compiles!
The trick, as user2407038 suggested, is to use a type-level Bool computed by IsElem. This moves the choice, which was previously hidden in the instance context, into an extra class parameter. GHC can then select an instance based on the value of that Bool.
Some boilerplate:
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Lib where
import GHC.TypeLits
import Servant.API hiding (IsElem)
import Servant.Utils.Links hiding (IsElem, Or)
import Data.Proxy
import Servant.Server
import GHC.Exts (Constraint)
import Network.Wai (Application)
import Control.Monad.Trans.Except (ExceptT)
We need an Or and an And at the type level:
-- | Type-level disjunction.
--
-- Defined by cases on the first argument, with symmetric equations for the
-- second. It therefore reduces as soon as either argument is known:
-- @Or 'True x@ is @'True@ and @Or 'False x@ is @x@, even while @x@ is still
-- unreduced. A plain four-row truth table only reduces once both arguments
-- are literal @'True@/@'False@.
type family Or (a :: Bool) (b :: Bool) :: Bool where
  Or 'False b = b
  Or 'True b = 'True
  Or a 'False = a
  Or a 'True = 'True
-- | Type-level conjunction.
--
-- Defined by cases on the first argument, with symmetric equations for the
-- second. It therefore reduces as soon as either argument is known:
-- @And 'False x@ is @'False@ and @And 'True x@ is @x@, even while @x@ is still
-- unreduced. A plain four-row truth table only reduces once both arguments
-- are literal @'True@/@'False@.
type family And (a :: Bool) (b :: Bool) :: Bool where
  And 'False b = 'False
  And 'True b = b
  And a 'False = 'False
  And a 'True = a
-- | Type-level boolean negation.
type family Not (b :: Bool) :: Bool where
  Not 'True = 'False
  Not 'False = 'True
-- | Type-level membership test: is @endpoint@ one of the routes of @api@?
--
-- It appears to be adapted from servant's own @IsElem@, which this module
-- hides. The rules, tried in order:
--
--   * an alternative matches if either side matches;
--   * identical leading combinators are stripped from both sides;
--   * on the API side, 'Header', 'ReqBody', 'QueryParam', 'QueryParams' and
--     'QueryFlag' are skipped when the endpoint does not start with the
--     identical combinator;
--   * 'Capture' names are not compared, only the captured type;
--   * a 'Verb' matches when the endpoint's content types are a sub-list of the
--     API's;
--   * otherwise only identical types match.
--
-- NOTE(review): the 'GetEndpoint' instances require the endpoint to repeat
-- those skipped combinators and the exact 'Verb', so this test is looser than
-- what 'GetEndpoint' can actually extract.
type family IsElem endpoint api :: Bool where
  IsElem ep (alt1 :<|> alt2) = Or (IsElem ep alt1) (IsElem ep alt2)
  IsElem (piece :> epRest) (piece :> apiRest) = IsElem epRest apiRest
  IsElem ep (Header name hdr :> apiRest) = IsElem ep apiRest
  IsElem ep (ReqBody types body :> apiRest) = IsElem ep apiRest
  IsElem (Capture epName arg :> epRest) (Capture apiName arg :> apiRest) =
    IsElem epRest apiRest
  IsElem ep (QueryParam name val :> apiRest) = IsElem ep apiRest
  IsElem ep (QueryParams name val :> apiRest) = IsElem ep apiRest
  IsElem ep (QueryFlag name :> apiRest) = IsElem ep apiRest
  IsElem (Verb method status epTypes result) (Verb method status apiTypes result) =
    IsSubList epTypes apiTypes
  IsElem same same = 'True
  IsElem ep other = 'False
-- | @IsSubList xs ys@ is @'True@ when every element of the type-level list
-- @xs@ also occurs in @ys@. The empty list is a sub-list of anything.
type family IsSubList a b :: Bool where
  IsSubList '[] ys = 'True
  IsSubList (x ': rest) ys = And (Elem x ys) (IsSubList rest ys)
-- | @Elem x xs@ is @'True@ when @x@ occurs in the type-level list @xs@.
type family Elem e es :: Bool where
  Elem needle '[] = 'False
  Elem needle (needle ': rest) = 'True
  Elem needle (other ': rest) = Elem needle rest
-- | Switch a constraint on or off with a type-level Bool: @'True@ yields the
-- constraint itself, @'False@ yields the empty constraint @()@.
type family EnableConstraint (c :: Constraint) (enable :: Bool) :: Constraint where
  EnableConstraint constraint 'False = ()
  EnableConstraint constraint 'True = constraint
Use our IsElem to figure out whether to take the right or the left branch:
-- | Decide which side of a ':<|>' holds @endpoint@. The result is @'True@ for
-- "descend into the right branch" and @'False@ for "descend into the left
-- branch". For an API that is not an alternative the result is @'True@, which
-- is the flag every non-':<|>' 'GetEndpoint' instance is declared with.
--
-- The decision uses 'MatchesEndpoint' rather than 'IsElem'. 'IsElem' skips
-- headers, request bodies and query combinators on the API side and accepts
-- content-type sub-lists. The 'GetEndpoint' instances, however, require the
-- endpoint to repeat the API structure exactly. Steering with 'IsElem' could
-- therefore send the search into a right branch that only matches loosely,
-- and from which no 'GetEndpoint' instance applies, even when the left branch
-- contains an exact match.
type family PickLeftRight endpoint api :: Bool where
  PickLeftRight endpoint (sa :<|> sb) = MatchesEndpoint endpoint sb
  PickLeftRight endpoint sa = 'True

-- | Structural match that mirrors the 'GetEndpoint' instance heads. Leading
-- combinators and path segments must be identical on both sides. 'Capture'
-- names may differ, but the captured type must agree. Leaves ('Verb', 'Raw',
-- ...) must be equal. An alternative in the API matches if either side does.
type family MatchesEndpoint endpoint api :: Bool where
  MatchesEndpoint endpoint (sa :<|> sb) =
    Or (MatchesEndpoint endpoint sa) (MatchesEndpoint endpoint sb)
  MatchesEndpoint (Capture epName a :> endpoint) (Capture apiName a :> sa) =
    MatchesEndpoint endpoint sa
  MatchesEndpoint (seg :> endpoint) (seg :> sa) = MatchesEndpoint endpoint sa
  MatchesEndpoint endpoint endpoint = 'True
  MatchesEndpoint endpoint sa = 'False
Our entry point:
-- | Select a handler from an API by specifying a type-level link.
--
-- Given the handler tree for @api@ in the @ExceptT ServantErr IO@ monad and a
-- proxy for @endpoint@, this returns the handler serving @endpoint@. The first
-- 'getEndpoint' step is told which branch to take via 'PickLeftRight'.
callHandler
  :: forall api endpoint. (GetEndpoint api endpoint (PickLeftRight endpoint api))
  => Proxy api
  -> ServerT api (ExceptT ServantErr IO)
  -> Proxy endpoint
  -> ServerT endpoint (ExceptT ServantErr IO)
callHandler apiProxy serverTree endpointProxy =
  getEndpoint branchProxy monadProxy apiProxy endpointProxy serverTree
  where
    branchProxy :: Proxy (PickLeftRight endpoint api)
    branchProxy = Proxy
    monadProxy :: Proxy (ExceptT ServantErr IO)
    monadProxy = Proxy
The trick: an additional class parameter of kind Bool!
-- | Extract the handler for @endpoint@ from the handler tree of @api@.
--
-- The extra @chooseRight@ index lets instance selection depend on a computed
-- type-level Bool. For an alternative @a ':<|>' b@, @'True@ selects the right
-- branch and @'False@ the left one.
class GetEndpoint api endpoint (chooseRight :: Bool) where
  getEndpoint
    :: forall m
     . Proxy chooseRight   -- ^ which branch to take
    -> Proxy m             -- ^ handler monad
    -> Proxy api           -- ^ API being searched
    -> Proxy endpoint      -- ^ endpoint to extract
    -> ServerT api m       -- ^ handlers for the whole API
    -> ServerT endpoint m
Now use it to select an instance, either left:
-- Left choice: the flag is 'False, so recurse into the left handlers and let
-- PickLeftRight decide the next step inside b1.
instance GetEndpoint b1 endpoint (PickLeftRight endpoint b1)
    => GetEndpoint (b1 :<|> b2) endpoint 'False where
  getEndpoint _ monadP _ endpointP (leftServer :<|> _) =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint b1))
      monadP
      (Proxy :: Proxy b1)
      endpointP
      leftServer
Or right, if our parameter is 'True:
-- Right choice: the flag is 'True, so recurse into the right handlers and let
-- PickLeftRight decide the next step inside b2.
instance GetEndpoint b2 endpoint (PickLeftRight endpoint b2)
    => GetEndpoint (b1 :<|> b2) endpoint 'True where
  getEndpoint _ monadP _ endpointP (_ :<|> rightServer) =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint b2))
      monadP
      (Proxy :: Proxy b2)
      endpointP
      rightServer
Other instances, not relevant to the original problem but included here for completeness:
-- Pathpiece: the static segment sym must appear in the endpoint as well. The
-- server is handed unchanged to the search in the rest of the API.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (sym :> sa) (sym :> endpoint) 'True where
  getEndpoint _ monadP _ _ =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
-- Capture: the capture names may differ (sym vs sym1), but the captured type
-- a must agree. Apply the handler to the captured value, then search the rest.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (Capture sym a :> sa) (Capture sym1 a :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- QueryParam: must match exactly in the endpoint. Feed the parameter value
-- to the handler, then search the rest of the API.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (QueryParam sym a :> sa) (QueryParam sym a :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- QueryParams: must match exactly in the endpoint. Feed the parameter values
-- to the handler, then search the rest of the API.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (QueryParams sym a :> sa) (QueryParams sym a :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- QueryFlag: must match exactly in the endpoint. Feed the flag to the
-- handler, then search the rest of the API.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (QueryFlag sym :> sa) (QueryFlag sym :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- Header: must match exactly in the endpoint. Feed the header value to the
-- handler, then search the rest of the API.
instance (KnownSymbol sym, GetEndpoint sa endpoint (PickLeftRight endpoint sa))
    => GetEndpoint (Header sym a :> sa) (Header sym a :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- ReqBody: must match exactly in the endpoint. Feed the decoded body to the
-- handler, then search the rest of the API.
instance GetEndpoint sa endpoint (PickLeftRight endpoint sa)
    => GetEndpoint (ReqBody ct a :> sa) (ReqBody ct a :> endpoint) 'True where
  getEndpoint _ monadP _ _ server =
    getEndpoint
      (Proxy :: Proxy (PickLeftRight endpoint sa))
      monadP
      (Proxy :: Proxy sa)
      (Proxy :: Proxy endpoint)
      . server
-- Verb: leaf of the search. The endpoint equals this exact Verb, so the
-- handler found here is the result.
instance GetEndpoint (Verb n s ct a) (Verb n s ct a) 'True where
  getEndpoint _ _ _ _ = id
-- Raw: leaf of the search. The raw application is returned unchanged.
instance GetEndpoint Raw Raw 'True where
  getEndpoint _ _ _ _ = id
The full code is on GitHub.
Thanks again for the hint user2407038!
Upvotes: 2