I'm trying to write a Haskell function that automatically distributes a binary operation over a list, kind of how arithmetic operations work in the J language.
You can think of it as a "deep zipWith" that works on nested lists of any depth, including non-lists and lists of different depths.
For example:
distr (+) 1 10 === 11 -- Non-list values are added together
distr (+) [1,2] 10 === [11,12] -- Non-list values distribute over lists
distr (+) [1,2] [10,20] === [11,22] -- Two lists get zipped
distr (+) [[1,2],[3,4]] [[10,20],[30,40]] === [[11,22],[33,44]] -- Nested lists get zipped
Lists of different lengths get truncated, as with zipWith, but this is not important.
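The sketch below pins down the semantics of these examples at run time, separately from any type-level encoding. It assumes a hypothetical Nested type (not part of the question) that tags each value as an atom or a list, so depth is checked by pattern matching rather than by type classes. The four equations follow the question's rules. On inputs of mixed depth it reproduces the current "zip first" behavior described further down.

```haskell
-- Hypothetical runtime model of "deep zipWith", for illustration only.
data Nested a = Atom a | List [Nested a]
  deriving (Eq, Show)

distrN :: (a -> b -> c) -> Nested a -> Nested b -> Nested c
-- Two non-list values: apply the operator.
distrN f (Atom x) (Atom y) = Atom (f x y)
-- A non-list value on the right is distributed over the left list.
distrN f (List xs) y@(Atom _) = List (map (\x -> distrN f x y) xs)
-- A non-list value on the left is distributed over the right list.
distrN f x@(Atom _) (List ys) = List (map (distrN f x) ys)
-- Two lists are zipped (truncating to the shorter one).
distrN f (List xs) (List ys) = List (zipWith (distrN f) xs ys)

main :: IO ()
main = do
  print (distrN (+) (Atom 1) (Atom 10))
  -- Atom 11
  print (distrN (+) (List [Atom 1, Atom 2]) (Atom 10))
  -- List [Atom 11,Atom 12]
  print (distrN (+) (List [List [Atom 1, Atom 2], List [Atom 3, Atom 4]])
                    (List [Atom 10, Atom 20]))
  -- List [List [Atom 11,Atom 12],List [Atom 23,Atom 24]]
```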
Now, I have already written this:
{-# LANGUAGE
    MultiParamTypeClasses,
    FunctionalDependencies,
    UndecidableInstances,
    FlexibleInstances
#-}

class Distr a b c x y z | a b c x y -> z
  where distr :: (a -> b -> c) -> (x -> y -> z)

instance Distr a b c a b c where distr = id

instance {-# OVERLAPPING #-}
  (Distr a b c x y z) => Distr a b c [x] [y] [z]
  where distr = zipWith . distr

instance (Distr a b c x y z) => Distr a b c [x] y [z]
  where distr f xs y = map (\x -> distr f x y) xs

instance (Distr a b c x y z) => Distr a b c x [y] [z]
  where distr f x ys = map (\y -> distr f x y) ys
This defines a 6-parameter typeclass Distr with a function distr :: (Distr a b c x y z) => (a -> b -> c) -> (x -> y -> z), and some instances of Distr on nested lists.
It works well on the examples above, but its behavior on lists of unequal nesting depth is not exactly what I want.
It does this (which works if you add type annotations to (+) and both lists):
distr (+) [[1,2],[3,4]] [10,20] === [[11,12],[23,24]] -- Zip and distribute
What I want is this:
distr (+) [[1,2],[3,4]] [10,20] === [[11,22],[13,24]] -- Distribute and zip
The current implementation applies zipWith until one of its arguments is a non-list value, which is then distributed over the other list.
I would prefer it to distribute one argument (the one with fewer list layers) over the other until both reach the same nesting depth, and then use zipWith to reduce them to non-list values.
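Writing the two strategies out by hand for one fixed pair of types makes the difference concrete. This is a sketch only. It assumes the types [[Int]] and [Int] and the operator (+), and zipThenDistribute and distributeThenZip are hypothetical names, not part of any general solution.

```haskell
-- Current behavior: zip the outer layers first, then distribute each
-- element of the shallower list over the matching inner list.
zipThenDistribute :: [[Int]] -> [Int] -> [[Int]]
zipThenDistribute = zipWith (\xs y -> map (+ y) xs)

-- Desired behavior: distribute the whole shallower list over the
-- outer layer of the deeper one, then zip at equal depth.
distributeThenZip :: [[Int]] -> [Int] -> [[Int]]
distributeThenZip xss ys = map (\xs -> zipWith (+) xs ys) xss

main :: IO ()
main = do
  print (zipThenDistribute [[1, 2], [3, 4]] [10, 20])  -- [[11,12],[23,24]]
  print (distributeThenZip [[1, 2], [3, 4]] [10, 20])  -- [[11,22],[13,24]]
```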
My question is: Can I achieve the second kind of behavior?
I'm happy with a solution that requires me to explicitly tell Haskell the types of the operator and each argument, as my current solution does.
I will not call distr on an operator that takes lists as inputs, so that case need not be handled.
However, I don't want to give extra arguments to distr that serve as type hints, or have several different versions of distr for different use cases.
I know my problem could be solved this way, but I'd prefer a solution where it isn't necessary.
(As a gist in Literate Haskell)
{-# LANGUAGE DataKinds, FlexibleContexts, FlexibleInstances,
TypeFamilies, MultiParamTypeClasses, UndecidableInstances,
RankNTypes, ScopedTypeVariables, FunctionalDependencies, TypeOperators #-}
module Zip where
import Data.Proxy
import GHC.TypeLits
Let's first assume that the two nested lists have the same depth. E.g., depth 2:
zipDeep0 ((+) :: Int -> Int -> Int) [[1,2],[3,4,5]] [[10,20],[30,40]] :: [[Int]]
[[11,22],[33,44]]
Implementation:
zipDeep0
  :: forall n a b c x y z
  .  (ZipDeep0 n a b c x y z, n ~ Levels a x, n ~ Levels b y, n ~ Levels c z)
  => (a -> b -> c) -> (x -> y -> z)
zipDeep0 = zipDeep0_ (Proxy :: Proxy n)
Levels a x computes the depth of a in the nested list type x.
Closed type families allow us to do some non-linear type-level pattern matching (where a occurs twice in a clause).
type family Levels a x :: Nat where
  Levels a a = 0
  Levels a [x] = 1 + Levels a x
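To check what Levels computes, its result can be reflected to a term with natVal from GHC.TypeLits. This is a minimal sketch. It assumes the Levels family exactly as defined above, placed in a standalone Main module.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators, UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Same closed family as above. The first equation matches only when
-- both arguments are the same type (non-linear match); otherwise one
-- list layer is peeled off and counted.
type family Levels a x :: Nat where
  Levels a a = 0
  Levels a [x] = 1 + Levels a x

main :: IO ()
main = do
  print (natVal (Proxy :: Proxy (Levels Int Int)))      -- 0
  print (natVal (Proxy :: Proxy (Levels Int [Int])))    -- 1
  print (natVal (Proxy :: Proxy (Levels Int [[Int]])))  -- 2
```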
We use that depth to select the ZipDeep0 instance implementing the zip.
This is neater than relying only on the six other ordinary type parameters, because it avoids a problem with type inference and overlapping instances when some list is empty (so we can't infer its actual type from the list itself), or when a, b, c are themselves list types.
class ZipDeep0 (n :: Nat) a b c x y z where
  zipDeep0_ :: proxy n -> (a -> b -> c) -> x -> y -> z

-- Moving the equality constraints into the context helps type inference.
instance {-# OVERLAPPING #-} (a ~ x, b ~ y, c ~ z) => ZipDeep0 0 a b c x y z where
  zipDeep0_ _ = id

instance (ZipDeep0 (n - 1) a b c x y z, xs ~ [x], ys ~ [y], zs ~ [z])
  => ZipDeep0 n a b c xs ys zs where
  zipDeep0_ _ f = zipWith (zipDeep0_ (Proxy :: Proxy (n - 1)) f)
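At a fixed depth, the ZipDeep0 instance chain unfolds to one zipWith per layer. At depth 2, the steps n = 2 and n = 1 each add a zipWith, and n = 0 contributes id f = f. The hand-written equivalent below sketches that unfolding for depth 2 only. zipDepth2 is a hypothetical name, not a replacement for the class.

```haskell
-- Depth-2 unfolding of zipDeep0: zipWith (zipWith (id f)) == zipWith (zipWith f).
zipDepth2 :: (a -> b -> c) -> [[a]] -> [[b]] -> [[c]]
zipDepth2 f = zipWith (zipWith f)

main :: IO ()
main = print (zipDepth2 (+) [[1, 2], [3, 4, 5]] [[10, 20], [30, 40 :: Int]])
-- [[11,22],[33,44]]  (the extra 5 is dropped, as with zipWith)
```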
When the two lists are not of the same depth, let's first assume the second one is deeper, so we must distribute over it.
We start losing some type inference: we must know at least Levels a x (and thus a and x) and either Levels b y or Levels c z before this function can be applied.
Example:
zipDeep1 (+) [10,20 :: Int] [[1,2],[3,4]] :: [[Int]]
[[11,22],[13,24]]
Implementation:
zipDeep1
  :: forall n m a b c x y z
  .  (n ~ Levels a x, m ~ Levels b y, m ~ Levels c z, ZipDeep1 (m - n) a b c x y z)
  => (a -> b -> c) -> x -> y -> z
zipDeep1 = zipDeep1_ (Proxy :: Proxy (m - n))
The difference between levels (m - n) tells us how many layers we must "distribute" through before falling back to zipDeep0.
class ZipDeep1 (n :: Nat) a b c x y z where
  zipDeep1_ :: proxy n -> (a -> b -> c) -> x -> y -> z

instance {-# OVERLAPPING #-} ZipDeep0 (Levels a x) a b c x y z => ZipDeep1 0 a b c x y z where
  zipDeep1_ _ = zipDeep0_ (Proxy :: Proxy (Levels a x))

instance (ZipDeep1 (n - 1) a b c x y z, ys ~ [y], zs ~ [z]) => ZipDeep1 n a b c x ys zs where
  zipDeep1_ proxy f xs = fmap (zipDeep1_ (Proxy :: Proxy (n - 1)) f xs)
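If zipDeep1 is unfolded by hand for the example above (x = [Int], y = [[Int]], so m - n = 1), it becomes one fmap for the distributed layer followed by one zipWith at equal depth. This is a sketch of that single case. distributeOnce is a hypothetical name.

```haskell
-- Unfolding for m - n = 1:
--   zipDeep1_ @1 f xs = fmap (zipDeep1_ @0 f xs)   -- distribute one layer
--   zipDeep1_ @0 f xs = zipDeep0_ @1 f xs          -- depths now equal
--                     = zipWith f xs               -- one zip layer
distributeOnce :: (a -> b -> c) -> [a] -> [[b]] -> [[c]]
distributeOnce f xs = fmap (zipWith f xs)

main :: IO ()
main = print (distributeOnce (+) [10, 20 :: Int] [[1, 2], [3, 4]])
-- [[11,22],[13,24]]
```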
Finally, we can do a type-level comparison when either list may be the deeper one. We lose all type inference though.
Example:
zipDeep ((+) :: Int -> Int -> Int) [[1,2 :: Int],[3,4]] [10 :: Int,20] :: [[Int]]
[[11,22],[13,24]]
Some type inference is recovered by instead specifying the expected depth of each list with TypeApplications.
zipDeep @2 @1 ((+) :: Int -> Int -> Int) [[1,2],[3,4]] [10,20]
[[11,22],[13,24]]
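The depth comparison relies on CmpNat from GHC.TypeLits. On literal Nats, CmpNat reduces to a promoted Ordering, so an instance can be selected from its result. The following is a minimal sketch of that dispatch, built on a hypothetical class Which. ZipDeep2 has only an 'LT instance and a catch-all for the other two cases. This sketch gives each of the three orderings its own instance for illustration.

```haskell
{-# LANGUAGE DataKinds, FlexibleContexts, KindSignatures, ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (CmpNat)

-- Hypothetical class: picks a description from a promoted Ordering,
-- the same way ZipDeep2 picks an instance from CmpNat n m.
class Which (o :: Ordering) where
  which :: Proxy o -> String

instance Which 'LT where
  which _ = "second argument is deeper: distribute the first"

instance Which 'EQ where
  which _ = "equal depth: zip"

instance Which 'GT where
  which _ = "first argument is deeper: distribute the second"

-- CmpNat n m reduces once n and m are known literals.
whichFor :: forall n m. Which (CmpNat n m) => Proxy n -> Proxy m -> String
whichFor _ _ = which (Proxy :: Proxy (CmpNat n m))

main :: IO ()
main = do
  putStrLn (whichFor (Proxy :: Proxy 2) (Proxy :: Proxy 1))
  putStrLn (whichFor (Proxy :: Proxy 1) (Proxy :: Proxy 2))
  putStrLn (whichFor (Proxy :: Proxy 2) (Proxy :: Proxy 2))
```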
Implementation:
zipDeep
  :: forall n m a b c x y z
  .  (ZipDeep2 (CmpNat n m) n m a b c x y z, n ~ Levels a x, m ~ Levels b y)
  => (a -> b -> c) -> x -> y -> z
zipDeep = zipDeep2_ (Proxy :: Proxy '(CmpNat n m, n, m))

class ZipDeep2 (cmp :: Ordering) (n :: Nat) (m :: Nat) a b c x y z where
  zipDeep2_ :: proxy '(cmp, n, m) -> (a -> b -> c) -> x -> y -> z

instance {-# OVERLAPPING #-} (n ~ Levels a x, m ~ Levels b y, m ~ Levels c z, ZipDeep1 (m - n) a b c x y z)
  => ZipDeep2 'LT n m a b c x y z where
  zipDeep2_ _ = zipDeep1

instance (n ~ Levels a x, m ~ Levels b y, n ~ Levels c z, ZipDeep1 (n - m) b a c y x z)
  => ZipDeep2 cmp n m a b c x y z where
  zipDeep2_ _ = flip . zipDeep1 . flip
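The catch-all instance reuses the shallower-first logic. It swaps the operator's arguments, distributes, and then swaps the list arguments back. The sketch below unfolds this by hand for a depth difference of one layer. distribute1 and distribute1Flipped are hypothetical names. It uses the non-commutative (-) so that the argument order is visible in the result.

```haskell
-- Shallower-first distribution for one layer of difference:
-- fmap over the deeper list, then zip at equal depth.
distribute1 :: (a -> b -> c) -> [a] -> [[b]] -> [[c]]
distribute1 f xs = fmap (zipWith f xs)

-- First argument deeper: flip f so the shallower list comes first,
-- reuse distribute1, then flip the list arguments back.
distribute1Flipped :: (a -> b -> c) -> [[a]] -> [b] -> [[c]]
distribute1Flipped = flip . distribute1 . flip

main :: IO ()
main = print (distribute1Flipped (-) [[1, 2], [3, 4]] [10, 20 :: Int])
-- [[-9,-18],[-7,-16]]
```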