Zgarb

Overloaded "zipWith" supporting nested lists

I'm trying to write a Haskell function that automatically distributes a binary operation over a list, much like arithmetic operations work in the J language. You can think of it as a "deep zipWith" that works on nested lists of any depth, including non-list values and lists of different depths. For example:

distr (+) 1 10 === 11  -- Non-list values are added together
distr (+) [1,2] 10 === [11,12]  -- Non-list values distribute over lists
distr (+) [1,2] [10,20] === [11,22]  -- Two lists get zipped
distr (+) [[1,2],[3,4]] [[10,20],[30,40]] === [[11,22],[33,44]]  -- Nested lists get zipped

Lists of different lengths get truncated, like with zipWith, but this is not important.

Now, I have already written this:

{-# LANGUAGE
    MultiParamTypeClasses,
    FunctionalDependencies,
    UndecidableInstances,
    FlexibleInstances
#-}

class Distr a b c x y z | a b c x y -> z
  where distr :: (a -> b -> c) -> (x -> y -> z)

instance Distr a b c a b c where distr = id
instance {-# OVERLAPPING #-}
         (Distr a b c x y z) => Distr a b c [x] [y] [z]
  where distr = zipWith . distr
instance (Distr a b c x y z) => Distr a b c [x]  y  [z]
  where distr f xs y = map (\x -> distr f x y) xs
instance (Distr a b c x y z) => Distr a b c  x  [y] [z]
  where distr f x ys = map (\y -> distr f x y) ys

This defines a six-parameter type class Distr with a function distr :: (Distr a b c x y z) => (a -> b -> c) -> (x -> y -> z), and some instances of Distr on nested lists. It works well on the examples above, but its behavior on lists of unequal nesting depth is not exactly what I want. For example, it does the following (this type-checks once you add type annotations to (+) and to both lists):

distr (+) [[1,2],[3,4]] [10,20] === [[11,12],[23,24]]  -- Zip and distribute
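Unfolding the current instances by hand for this call shows where the result comes from: the overlapping [x] [y] [z] instance zips the outer lists first, and the [x] y [z] instance then distributes each number of the second list over one row of the first. The sketch below writes that out directly; zipThenDistribute is a hypothetical name, and the Int types stand in for the annotations mentioned above.

```haskell
-- Hand-unfolded version of what the current Distr instances do for
-- distr (+) [[1,2],[3,4]] [10,20]  (zipThenDistribute is a hypothetical name).
zipThenDistribute :: (Int -> Int -> Int) -> [[Int]] -> [Int] -> [[Int]]
zipThenDistribute f =
  -- Outer layer: zip row i of the first list with element i of the second...
  zipWith (\row y -> map (\x -> f x y) row)
  -- ...inner layer: distribute that single element over the whole row.
```

Evaluating zipThenDistribute (+) [[1,2],[3,4]] [10,20] gives [[11,12],[23,24]], matching the behavior shown above.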

What I want is this:

distr (+) [[1,2],[3,4]] [10,20] === [[11,22],[13,24]]  -- Distribute and zip

The current implementation applies zipWith until one of its arguments is a non-list value, which is then distributed over the other list. I would prefer it to distribute one argument (the one with fewer list layers) over the other until it reaches equal nesting depth, and then use zipWith to reduce them to non-list values.
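To pin down the intended "distribute, then zip" semantics independently of the type class machinery, here is a value-level model. It deliberately swaps in a different technique: nesting depth is tracked at runtime in a rose-tree type rather than in the list types, so it is not a solution to the question as posed, only an executable statement of the desired behavior. Nested, depth and distrN are hypothetical names.

```haskell
-- Runtime model of the desired semantics (hypothetical names throughout).
-- A Leaf is a non-list value; a Node is one list layer.
data Nested a = Leaf a | Node [Nested a]
  deriving (Eq, Show)

-- Nesting depth, measured along the first branch (a Leaf has depth 0).
depth :: Nested a -> Int
depth (Leaf _) = 0
depth (Node []) = 1
depth (Node (t : _)) = 1 + depth t

-- Distribute the shallower argument over the deeper one until the depths
-- match, then zip layer by layer (truncating like zipWith).
distrN :: (a -> b -> c) -> Nested a -> Nested b -> Nested c
distrN f (Leaf x) (Leaf y) = Leaf (f x y)
distrN f x y
  | dx > dy, Node xs <- x = Node [distrN f x' y | x' <- xs]
  | dx < dy, Node ys <- y = Node [distrN f x y' | y' <- ys]
  | Node xs <- x, Node ys <- y = Node (zipWith (distrN f) xs ys)
  where
    dx = depth x
    dy = depth y
```

With this model, distrN (+) (Node [Node [Leaf 1, Leaf 2], Node [Leaf 3, Leaf 4]]) (Node [Leaf 10, Leaf 20]) evaluates to Node [Node [Leaf 11, Leaf 22], Node [Leaf 13, Leaf 24]], the "distribute and zip" result.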

My question is: Can I achieve the second kind of behavior? I'm happy with a solution that requires me to explicitly tell Haskell the types of the operator and each argument, as my current solution does. I will not call distr on an operator that takes lists as inputs, so that case need not be handled. However, I don't want to give extra arguments to distr that serve as type hints, or have several different versions of distr for different use cases. I know my problem could be solved this way, but I'd prefer a solution where it isn't necessary.

Answers (1)

Li-yao Xia

(As a gist in Literate Haskell)

{-# LANGUAGE DataKinds, FlexibleContexts, FlexibleInstances,
    TypeFamilies, MultiParamTypeClasses, UndecidableInstances,
    RankNTypes, ScopedTypeVariables, FunctionalDependencies, TypeOperators #-}

module Zip where

import Data.Proxy
import GHC.TypeLits

Let's first assume that the two nested lists have the same depth. E.g., depth 2:

zipDeep0 ((+) :: Int -> Int -> Int) [[1,2],[3,4,5]] [[10,20],[30,40]] :: [[Int]]
[[11,22],[33,44]]

Implementation:

zipDeep0
  :: forall n a b c x y z
  .  (ZipDeep0 n a b c x y z, n ~ Levels a x, n ~ Levels b y, n ~ Levels c z)
  => (a -> b -> c) -> (x -> y -> z)
zipDeep0 = zipDeep0_ (Proxy :: Proxy n)

Levels a x computes the depth of a in the nested list type x, i.e. how many list layers wrap a. Closed type families allow us to do some non-linear type-level pattern matching (a occurs twice in the first equation).

type family Levels a x :: Nat where
  Levels a a = 0
  Levels a [x] = 1 + Levels a x
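One way to check what Levels computes is to reflect its result back to a value with natVal from GHC.TypeLits. The sketch below repeats the type family so it stands alone; levelsIntInInt and levelsIntInNested are hypothetical names.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators, UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Same closed type family as above: count the list layers around a.
type family Levels a x :: Nat where
  Levels a a = 0
  Levels a [x] = 1 + Levels a x

-- natVal turns the (fully reduced) type-level Nat into an ordinary Integer.
levelsIntInInt, levelsIntInNested :: Integer
levelsIntInInt = natVal (Proxy :: Proxy (Levels Int Int))
levelsIntInNested = natVal (Proxy :: Proxy (Levels Int [[Int]]))
```

Here levelsIntInInt is 0 (the first, non-linear equation matches) and levelsIntInNested is 2 (the second equation fires twice before the first one applies).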

We use that depth to select the ZipDeep0 instance implementing the zip. This is neater than relying only on the other six type parameters, as it avoids problems with type inference and overlapping instances when some list is empty (so its element type cannot be inferred from the value itself), or when a, b, c are themselves list types.

class ZipDeep0 (n :: Nat) a b c x y z where
  zipDeep0_ :: proxy n -> (a -> b -> c) -> x -> y -> z

-- Moving the equality constraints into the context helps type inference.
instance {-# OVERLAPPING #-} (a ~ x, b ~ y, c ~ z) => ZipDeep0 0 a b c x y z where
  zipDeep0_ _ = id

instance (ZipDeep0 (n - 1) a b c x y z, xs ~ [x], ys ~ [y], zs ~ [z])
  => ZipDeep0 n a b c xs ys zs where
  zipDeep0_ _ f = zipWith (zipDeep0_ (Proxy :: Proxy (n - 1)) f)
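For a fixed depth, the instance chain collapses to one zipWith per list layer: the n instance adds a zipWith and recurses at n - 1, and the 0 instance is id. At depth 2 this amounts to the following hand-written function (a sketch; zipDeep0At2 is a hypothetical name, not part of the module).

```haskell
-- What ZipDeep0 2 unfolds to:
--   zipDeep0_ @2 f = zipWith (zipDeep0_ @1 f)
--                  = zipWith (zipWith (zipDeep0_ @0 f))
--                  = zipWith (zipWith (id f))
zipDeep0At2 :: (a -> b -> c) -> [[a]] -> [[b]] -> [[c]]
zipDeep0At2 f = zipWith (zipWith f)
```

zipDeep0At2 (+) [[1,2],[3,4,5]] [[10,20],[30,40]] gives [[11,22],[33,44]], the same result as the zipDeep0 example above, including the truncation of [3,4,5].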

When the two lists do not have the same depth, let's first assume the second one is deeper, so we must distribute the first one over it. Here we start losing some type inference: we must know at least Levels a x (and thus a and x), and either Levels b y or Levels c z, before this function can be applied.

Example:

zipDeep1 (+) [10,20 :: Int] [[1,2],[3,4]] :: [[Int]]
[[11,22],[13,24]]

Implementation:

zipDeep1
 :: forall n m a b c x y z
 .  (n ~ Levels a x, m ~ Levels b y, m ~ Levels c z, ZipDeep1 (m - n) a b c x y z)
 => (a -> b -> c) -> x -> y -> z
zipDeep1 = zipDeep1_ (Proxy :: Proxy (m - n))

The difference between levels (m - n) tells us how many layers we must "distribute" through before falling back to zipDeep0.

class ZipDeep1 (n :: Nat) a b c x y z where
  zipDeep1_ :: proxy n -> (a -> b -> c) -> x -> y -> z

instance {-# OVERLAPPING #-} ZipDeep0 (Levels a x) a b c x y z => ZipDeep1 0 a b c x y z where
  zipDeep1_ _ = zipDeep0_ (Proxy :: Proxy (Levels a x))

instance (ZipDeep1 (n - 1) a b c x y z, ys ~ [y], zs ~ [z]) => ZipDeep1 n a b c x ys zs where
  -- Distribute the shallower argument x over each element of the deeper list.
  zipDeep1_ _ f x = fmap (zipDeep1_ (Proxy :: Proxy (n - 1)) f x)
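For the example above (x with one list layer, y with two, so m - n = 1), the instances unfold to one fmap that distributes x over the outer layer of y, followed by the depth-1 ZipDeep0 step, which is a plain zipWith. Written out by hand (distributeThenZip is a hypothetical name):

```haskell
-- What ZipDeep1 1 unfolds to when x :: [a] and y :: [[b]]:
--   zipDeep1_ @1 f x = fmap (zipDeep1_ @0 f x)
--                    = fmap (zipDeep0_ @1 f x)
--                    = fmap (zipWith f x)
distributeThenZip :: (a -> b -> c) -> [a] -> [[b]] -> [[c]]
distributeThenZip f x = fmap (zipWith f x)
```

distributeThenZip (+) [10,20] [[1,2],[3,4]] gives [[11,22],[13,24]], matching the zipDeep1 example.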

Finally, when either list may be the deeper one, we compare the two depths at the type level with CmpNat. We lose all type inference, though.
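The comparison is resolved entirely at compile time. As a small standalone check (repeating Levels; firstDeeper and secondDeeper are hypothetical names), the following definitions only type-check if CmpNat reduces the two depths to the expected Ordering, which is what selects between the 'LT instance and the general one below.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators, UndecidableInstances #-}

import Data.Type.Equality ((:~:) (..))
import GHC.TypeLits

-- Same Levels family as in the answer.
type family Levels a x :: Nat where
  Levels a a = 0
  Levels a [x] = 1 + Levels a x

-- Depth 2 vs depth 1: the first list is deeper.
firstDeeper :: CmpNat (Levels Int [[Int]]) (Levels Int [Int]) :~: 'GT
firstDeeper = Refl

-- Depth 1 vs depth 2: the second list is deeper.
secondDeeper :: CmpNat (Levels Int [Int]) (Levels Int [[Int]]) :~: 'LT
secondDeeper = Refl
```

If either equation were wrong, Refl would be rejected and the module would fail to compile.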

Example:

zipDeep ((+) :: Int -> Int -> Int) [[1,2 :: Int],[3,4]] [10 :: Int,20] :: [[Int]]
[[11,22],[13,24]]

Some type inference is recovered by instead specifying the expected depth of each list with TypeApplications (the extension must be enabled where zipDeep is called, e.g. with :set -XTypeApplications in GHCi).

zipDeep @2 @1 ((+) :: Int -> Int -> Int) [[1,2],[3,4]] [10,20]
[[11,22],[13,24]]

Implementation:

zipDeep
  :: forall n m a b c x y z
  .  (ZipDeep2 (CmpNat n m) n m a b c x y z, n ~ Levels a x, m ~ Levels b y)
  => (a -> b -> c) -> x -> y -> z
zipDeep = zipDeep2_ (Proxy :: Proxy '(CmpNat n m, n, m))

class ZipDeep2 (cmp :: Ordering) (n :: Nat) (m :: Nat) a b c x y z where
  zipDeep2_ :: proxy '(cmp, n, m) -> (a -> b -> c) -> x -> y -> z

instance {-# OVERLAPPING #-} (n ~ Levels a x, m ~ Levels b y, m ~ Levels c z, ZipDeep1 (m - n) a b c x y z)
  => ZipDeep2 'LT n m a b c x y z where
  zipDeep2_ _ = zipDeep1

instance (n ~ Levels a x, m ~ Levels b y, n ~ Levels c z, ZipDeep1 (n - m) b a c y x z)
  => ZipDeep2 cmp n m a b c x y z where
  zipDeep2_ _ = flip . zipDeep1 . flip
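The general instance reuses zipDeep1 by swapping the arguments: flip . zipDeep1 . flip applied to f is flip (zipDeep1 (flip f)), so the shallower list is passed first while f still receives its arguments in the original order. A hand-unfolded sketch for a depth-2 first list and a depth-1 second list (distributeThenZip and zipDeepGT are hypothetical names):

```haskell
-- Depth-1 distribute step: what ZipDeep1 1 unfolds to for x :: [a], y :: [[b]].
distributeThenZip :: (a -> b -> c) -> [a] -> [[b]] -> [[c]]
distributeThenZip f x = fmap (zipWith f x)

-- The 'GT case: swap the arguments so the shallower list comes first,
-- and flip f so that it still sees an element of the first list first.
zipDeepGT :: (a -> b -> c) -> [[a]] -> [b] -> [[c]]
zipDeepGT = flip . distributeThenZip . flip
```

zipDeepGT (+) [[1,2],[3,4]] [10,20] gives [[11,22],[13,24]], and with a non-commutative operator, zipDeepGT (-) [[1,2],[3,4]] [10,20] gives [[-9,-18],[-7,-16]], confirming that the double flip preserves argument order.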
