And Pos

Generalizing functions in Haskell

These two functions are almost identical:

dig :: MappingClassifierM FileSummary IO Partitioner -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig classifier =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify classifier clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig classifier
        else do
          mapM_ (yield . cluster) categories
          dig classifier
      where
        cluster (key, val) = Cluster (key : clusterKey) val
        classify = classifyM

dig' :: BinaryClassifierM FileSummary IO -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig' classifier =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify classifier clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig' classifier
        else do
          mapM_ (yield . cluster) categories
          dig' classifier
      where
        cluster = Cluster (Content : clusterKey)
        classify = classifyBinary

The only difference is in the functions defined in the where clause.

The following constraints apply:

I want to generalize the two functions so as to create a single function that handles both implementations to avoid duplication.

I don't know if I'm heading in the right direction. Based on my limited knowledge of Haskell so far, I think I need to create a Classifier class, of which BinaryClassifierM and MappingClassifierM would be instances, but I'm running into several compilation errors when I try to implement it.

So, my question is: how would an experienced Haskell programmer generalize these two functions to avoid duplication?

For additional context, below are the relevant type signatures for the two cases I'm trying to generalize:

type MappingClassifierM a m k = a -> m k
classifyM :: (Monad m, Ord k) => MappingClassifierM a m k -> [a] -> m [(k, [a])]
dig :: MappingClassifierM FileSummary IO Partitioner -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)

type BinaryClassifierM a m = a -> a -> m Bool
classifyBinary :: Monad m => BinaryClassifierM a m -> [a] -> m [[a]]
dig' :: BinaryClassifierM FileSummary IO -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
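The question only shows the signatures of classifyM and classifyBinary. As a hedged illustration, here is one plausible way they could behave, using hypothetical FileSummary and Partitioner stand-ins (the BySize constructor and the fsName/fsSize fields are invented for the example). It makes the key difference concrete: the mapping classifier returns labelled groups [(k, [a])], while the binary classifier returns unlabelled groups [[a]], which is why the two cluster helpers in the where clauses differ.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical stand-ins: the real FileSummary and Partitioner are not shown in the question.
data FileSummary = FileSummary { fsName :: String, fsSize :: Int }
  deriving (Eq, Show)

data Partitioner = Content | BySize Int
  deriving (Eq, Ord, Show)

type MappingClassifierM a m k = a -> m k
type BinaryClassifierM a m = a -> a -> m Bool

-- Assumed behaviour: compute a key for every element and group elements by key,
-- keeping input order inside each group. Result: labelled groups.
classifyM :: (Monad m, Ord k) => MappingClassifierM a m k -> [a] -> m [(k, [a])]
classifyM keyOf xs = do
  keyed <- mapM (\x -> fmap (\k -> (k, [x])) (keyOf x)) xs
  -- fromListWith calls (f new old); flip (++) appends the newer element at the end
  pure (Map.toList (Map.fromListWith (flip (++)) keyed))

-- Assumed behaviour: put every element that the monadic test says "matches"
-- the first element into its class, then recurse on the rest. Result: unlabelled groups.
classifyBinary :: Monad m => BinaryClassifierM a m -> [a] -> m [[a]]
classifyBinary _ [] = pure []
classifyBinary same (x : xs) = do
  flags <- mapM (same x) xs
  let matches = [y | (y, True) <- zip xs flags]
      rest = [y | (y, False) <- zip xs flags]
  groups <- classifyBinary same rest
  pure ((x : matches) : groups)
```

With these, classifyM (pure . BySize . fsSize) tags each group with its size, while classifyBinary on a size-equality test yields the same groups without any labels.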

Upvotes: 3

Answers (2)

And Pos

I was able to find an alternative solution using a type class, with a few language extensions enabled:

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}

class Classifier classifier a m c | classifier -> m c where
  classify :: (Monad m) => classifier -> [a] -> m [c]
  cluster :: classifier -> [Partitioner] -> c -> Cluster Partitioner a

instance Classifier (MappingClassifierM FileSummary IO Partitioner) FileSummary IO (Partitioner, [FileSummary]) where
  classify = classifyM
  cluster _ x (key, val) = Cluster (key : x) val

instance Classifier (BinaryClassifierM FileSummary IO) FileSummary IO [FileSummary] where
  classify = classifyBinary
  cluster _ clusterKey = Cluster (Content : clusterKey)

dig :: (Classifier classifier FileSummary IO c) => classifier -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig classifier =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify classifier clusterValue

      if length clusterValue == length categories
        then yield $ Cluster clusterKey clusterValue
        else mapM_ (yield . cluster classifier clusterKey) categories

      dig classifier
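To check that the class-based approach type-checks on its own, here is a conduit-free sketch. It assumes hypothetical stand-in types (FileSummary with fsName/fsSize fields, a Partitioner with an extra BySize constructor, Cluster as a key path plus a list of values) and assumed implementations of classifyM and classifyBinary, since none of these are shown in the question. digList and digOne are hypothetical helpers that replace the await/yield loop with mapM over a plain list, so only the classification step is exercised.

```haskell
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}

import qualified Data.Map.Strict as Map

-- Hypothetical stand-ins for the asker's types.
data FileSummary = FileSummary { fsName :: String, fsSize :: Int }
  deriving (Eq, Show)
data Partitioner = Content | BySize Int
  deriving (Eq, Ord, Show)
data Cluster k a = Cluster [k] [a]
  deriving (Eq, Show)

type MappingClassifierM a m k = a -> m k
type BinaryClassifierM a m = a -> a -> m Bool

-- Assumed implementations (the question only gives their signatures).
classifyM :: (Monad m, Ord k) => MappingClassifierM a m k -> [a] -> m [(k, [a])]
classifyM keyOf xs = do
  keyed <- mapM (\x -> fmap (\k -> (k, [x])) (keyOf x)) xs
  pure (Map.toList (Map.fromListWith (flip (++)) keyed))

classifyBinary :: Monad m => BinaryClassifierM a m -> [a] -> m [[a]]
classifyBinary _ [] = pure []
classifyBinary same (x : xs) = do
  flags <- mapM (same x) xs
  let matches = [y | (y, True) <- zip xs flags]
      rest = [y | (y, False) <- zip xs flags]
  groups <- classifyBinary same rest
  pure ((x : matches) : groups)

-- The class from the answer: the classifier type determines the monad m and category type c.
class Classifier classifier a m c | classifier -> m c where
  classify :: Monad m => classifier -> [a] -> m [c]
  cluster :: classifier -> [Partitioner] -> c -> Cluster Partitioner a

instance Classifier (MappingClassifierM FileSummary IO Partitioner) FileSummary IO (Partitioner, [FileSummary]) where
  classify = classifyM
  cluster _ key (k, vals) = Cluster (k : key) vals

instance Classifier (BinaryClassifierM FileSummary IO) FileSummary IO [FileSummary] where
  classify = classifyBinary
  cluster _ key = Cluster (Content : key)

-- Hypothetical list-based stand-in for dig: handles every cluster in a list once.
digList :: Classifier classifier FileSummary IO c
        => classifier -> [Cluster Partitioner FileSummary] -> IO [Cluster Partitioner FileSummary]
digList classifier clusters = concat <$> mapM (digOne classifier) clusters

digOne :: Classifier classifier FileSummary IO c
       => classifier -> Cluster Partitioner FileSummary -> IO [Cluster Partitioner FileSummary]
digOne classifier (Cluster key vals) = do
  categories <- classify classifier vals
  pure $ if length vals == length categories
    then [Cluster key vals] -- as many categories as values: pass the cluster through
    else map (cluster classifier key) categories
```

The functional dependency classifier -> m c is what lets GHC determine the category type c from the classifier alone, so the c that appears only in the constraint of digList is not ambiguous.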

Upvotes: 1

rprospero

My method was to let the compiler do most of the thinking for me. First, I took the common code and pulled it out into its own function:

dig'' classifier =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify classifier clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig'' classifier
        else do
          mapM_ (yield . cluster) categories
          dig'' classifier

This leads to a couple of problems. First, the compiler whines that cluster and classify are undefined, so we'll need to add them to the function as parameters. You'll also notice that classifier is never used except as the argument to classify, so we can combine the two into one value (classify with its classifier already applied). Finally, since we've changed the function's parameters, the recursive calls change too, so we'll need to update those.

dig'' cluster classify =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig'' cluster classify
        else do
          mapM_ (yield . cluster) categories
          dig'' cluster classify

Here's where I noticed that the cluster in the where clause needed clusterKey, so it will need to be passed in as a parameter somehow.

dig'' cluster classify =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig'' cluster classify
        else do
          mapM_ (yield . cluster clusterKey) categories
          dig'' cluster classify

Finally, I used a type wildcard in the signature (dig'' :: _) to let GHC infer the type for me.

dig'' :: _
dig'' cluster classify =
  await >>= \case
    Nothing -> return ()
    Just ( Cluster clusterKey clusterValue ) -> do
      categories <- liftIO $ classify clusterValue
      if (length clusterValue == length categories)
        then do
          yield $ Cluster clusterKey clusterValue
          dig'' cluster classify
        else do
          mapM_ (yield . cluster clusterKey) categories
          dig'' cluster classify

GHC reported the type as

(Foldable t, MonadIO m) =>
 ([a] -> a1 -> Cluster a b)
 -> ([b] -> IO (t a1)) -> Conduit (Cluster a b) m (Cluster a b)

but your results may differ, since I made up most of my data types. For example, I suspect that the [a] and [b] will be different in your code.
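The inferred type shows that dig'' no longer mentions classifiers at all: it only needs a cluster-building function and a classify function with its classifier already applied. A conduit-free sketch of the same idea, assuming a hypothetical list-based digList in place of the await/yield loop and a minimal stand-in Cluster type, could look like this:

```haskell
-- Hypothetical stand-in for the asker's Cluster type: a key path plus the grouped values.
data Cluster k v = Cluster [k] [v]
  deriving (Eq, Show)

-- List-based analogue of dig'' (digList is a hypothetical name). The two pieces that
-- differed between dig and dig' are now ordinary function arguments.
digList
  :: (Foldable t, Monad m)
  => ([k] -> c -> Cluster k v) -- builds a child cluster from the parent key and one category
  -> ([v] -> m (t c))          -- the classify step, classifier already applied
  -> [Cluster k v]
  -> m [Cluster k v]
digList mkCluster classify clusters = concat <$> mapM digOne clusters
  where
    digOne (Cluster key vals) = do
      categories <- classify vals
      pure $
        if length vals == length categories
          then [Cluster key vals] -- as many categories as values: keep the cluster as-is
          else foldr (\cat acc -> mkCluster key cat : acc) [] categories
```

Keeping t Foldable mirrors the Foldable t in the inferred type; since both classifyM and classifyBinary return lists, restricting it to [c] would work just as well.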

Now, to get back to the original functions. Here we have the advantage of already knowing the types for our result.

dig :: MappingClassifierM FileSummary IO Partitioner -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig classifier = dig'' ????

Now we just need the two parameters for dig''. The first parameter stands in for the cluster definition from the where clause, now taking clusterKey as an explicit argument, so we get

(\clusterKey (key, val) -> Cluster (key : clusterKey) val)

The second parameter is classify, into which we rolled the classifier, so it's simply (classifyM classifier).

Thus, the final definition is

dig :: MappingClassifierM FileSummary IO Partitioner -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig classifier = dig'' (\clusterKey (key, val) -> Cluster (key : clusterKey) val) (classifyM classifier)

Similarly, for dig' you get

dig' :: BinaryClassifierM FileSummary IO -> Conduit (Cluster Partitioner FileSummary) IO (Cluster Partitioner FileSummary)
dig' classifier = dig'' (\clusterKey -> Cluster (Content : clusterKey)) (classifyBinary classifier)
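Putting it together, here is a self-contained, conduit-free sketch in which dig and dig' become one-line specializations of a single generic function, using exactly the two lambdas above. As before, the FileSummary/Partitioner/Cluster stand-ins and the bodies of classifyM and classifyBinary are assumptions, and digAll is a hypothetical list-based replacement for dig'' (mapM over a list instead of the await/yield recursion).

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical stand-ins; the asker's real types are not shown.
data FileSummary = FileSummary { fsName :: String, fsSize :: Int }
  deriving (Eq, Show)
data Partitioner = Content | BySize Int
  deriving (Eq, Ord, Show)
data Cluster k a = Cluster [k] [a]
  deriving (Eq, Show)

type MappingClassifierM a m k = a -> m k
type BinaryClassifierM a m = a -> a -> m Bool

-- Assumed implementations (only the signatures appear in the question).
classifyM :: (Monad m, Ord k) => MappingClassifierM a m k -> [a] -> m [(k, [a])]
classifyM keyOf xs = do
  keyed <- mapM (\x -> fmap (\k -> (k, [x])) (keyOf x)) xs
  pure (Map.toList (Map.fromListWith (flip (++)) keyed))

classifyBinary :: Monad m => BinaryClassifierM a m -> [a] -> m [[a]]
classifyBinary _ [] = pure []
classifyBinary same (x : xs) = do
  flags <- mapM (same x) xs
  let matches = [y | (y, True) <- zip xs flags]
      rest = [y | (y, False) <- zip xs flags]
  groups <- classifyBinary same rest
  pure ((x : matches) : groups)

-- Hypothetical list-based stand-in for dig'': one pass over a list of clusters.
digAll :: Monad m
       => ([k] -> c -> Cluster k v) -> ([v] -> m [c]) -> [Cluster k v] -> m [Cluster k v]
digAll mkCluster classify = fmap concat . mapM step
  where
    step (Cluster key vals) = do
      categories <- classify vals
      pure $ if length vals == length categories
        then [Cluster key vals]
        else map (mkCluster key) categories

-- The two original functions as one-line specializations, using the lambdas from the answer.
dig :: MappingClassifierM FileSummary IO Partitioner
    -> [Cluster Partitioner FileSummary] -> IO [Cluster Partitioner FileSummary]
dig classifier = digAll (\clusterKey (key, val) -> Cluster (key : clusterKey) val) (classifyM classifier)

dig' :: BinaryClassifierM FileSummary IO
     -> [Cluster Partitioner FileSummary] -> IO [Cluster Partitioner FileSummary]
dig' classifier = digAll (\clusterKey -> Cluster (Content : clusterKey)) (classifyBinary classifier)
```

The design choice here is plain higher-order functions rather than a type class: each call site passes exactly the two pieces that vary, so no extensions are needed.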

Upvotes: 1
