Oliver

Reputation: 2400

Memoising based on a key

I've been playing around with the MemoCombinators and MemoTrie packages lately, and I was trying to memoise a function that takes a tree (really a DAG in disguise, since several of the nodes are shared). The tree has the form:

data Tree k a = Branch (Tree k a) (k, a) (Tree k a) | Leaf (k, a)

So I want to memoise a function of this type, based on its key:

Tree k a -> b

Now I have a vague understanding that these memoisation combinators are used to turn your function f :: a -> a into a structure of lazy (unevaluated) values of a, so that once you have pulled one out it stays evaluated. So that wouldn't be a problem with my Tree; somehow I'd need to turn it into a structure of values indexed by k.

I couldn't figure out how to do it with the combinator libraries. One easy workaround is to make a function k -> a that indexes a map; that fits in just fine, but it seems a little clunky.

Am I misguided in this goal, or have I missed something obvious?

I can easily see how to write this function out with this kind of style, explicitly threading my "table" through all the computations:

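-- assumes: import Prelude hiding (lookup); import Data.Map (Map, insert, lookup)
f :: Tree Int Int -> Map Int Int -> (Int, Map Int Int)
f (Leaf (_, x)) m = (x, m)                       -- leaf case (assumed): its own value
f (Branch l (k, _) r) m
  | Just v <- lookup k m = (v, m)                -- cached: reuse the stored result
  | otherwise            = (v', insert k v' m'') -- compute, then record it under k
  where
    (rl, m')  = f l m
    (rr, m'') = f r m'
    v'        = max rl rr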

But that is not so nice.

Upvotes: 3

Views: 219

Answers (1)

rampion

Reputation: 89053

So, most memoization techniques use state. The memoized version of a function keeps a collection mapping inputs to memoized outputs. When it gets an input, it checks the collection, returning the memoized value if available. Otherwise, it computes the output using the original version of the function, saves the output in the collection, and returns the newly memoized output. The memoized collection, therefore, grows over the lifespan of the function.
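For comparison, here is a minimal sketch of that stateful scheme written in Haskell with an IORef. The names memoIO, memoized and the extra "cache size" action are hypothetical and only there for illustration; this is not how either package works.

```haskell
import Data.IORef (modifyIORef', newIORef, readIORef)
import qualified Data.Map as Map

-- Build a memoized version of f whose cache lives in a mutable reference.
-- Also returns an action reporting how many entries the cache holds.
memoIO :: Ord a => (a -> b) -> IO (a -> IO b, IO Int)
memoIO f = do
  ref <- newIORef Map.empty
  let memoized x = do
        cache <- readIORef ref
        case Map.lookup x cache of
          Just y  -> pure y                      -- seen before: reuse it
          Nothing -> do
            let y = f x                          -- new input: compute it
            modifyIORef' ref (Map.insert x y)    -- and remember it
            pure y
  pure (memoized, Map.size <$> readIORef ref)
```

Calling the memoized function twice with the same argument adds only one entry, and each new argument adds another, so the cache grows as the function is used.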

Haskell memoizers like the ones you mention eschew state, and instead precompute a data structure that holds the collection of memoized outputs, using laziness to ensure that the value of a particular output isn't computed until needed. This has a lot in common with the stateful approach, except for a few key points:

  • Since the collection is immutable, it never grows. Unmemoized outputs are recomputed each time.
  • Since the collection is created before the function is used, it doesn't know which inputs will be used. So the memoizer has to provide a collection of inputs over which to memoize.
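Both points can be seen in a small sketch. The names fibStep, fibTable, memoFib and the bound limit are made up for this example, and a lazy Data.Map stands in for whatever structure a real memoizer would build.

```haskell
import qualified Data.Map as Map

-- Inputs 0..limit are the collection we choose to memoize over (assumed bound).
limit :: Int
limit = 200

-- One step of the Fibonacci recurrence, with recursive calls going
-- through the memoized function.
fibStep :: Int -> Integer
fibStep n
  | n < 2     = toInteger n
  | otherwise = memoFib (n - 1) + memoFib (n - 2)

-- Built once and never modified. Data.Map is lazy in its values, so each
-- entry stays an unevaluated thunk until it is first looked up.
fibTable :: Map.Map Int Integer
fibTable = Map.fromList [(n, fibStep n) | n <- [0 .. limit]]

-- Inputs inside the table are computed at most once; inputs outside it
-- fall back to fibStep and are recomputed on every call.
memoFib :: Int -> Integer
memoFib n = Map.findWithDefault (fibStep n) n fibTable
```

Here memoFib 50 is looked up in the table, while memoFib 201 is outside it and is recomputed each time (cheaply, since its recursive calls land back in the table).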

A memoizer that takes its collection of inputs explicitly is fairly straightforward to implement by hand:

module Temp where
import Prelude hiding (lookup)
import Control.Arrow ((&&&))
import Data.Map (fromList, lookup)

data Tree k a = Branch (Tree k a) (k, a) (Tree k a) | Leaf (k, a)

key :: Tree k a -> k
key (Leaf (k, _)) = k
key (Branch _ (k,_) _) = k

-- memoize a given function over the given trees
memoFor :: Ord k => [Tree k a] -> (Tree k a -> b) -> Tree k a -> b
memoFor ts f = f'
  where f' t = maybe (f t) id $ lookup (key t) m
        m = fromList $ map (key &&& f) ts
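As a usage sketch, here is memoFor applied to a small DAG. The definitions above are repeated so the snippet stands alone; the trees shared and root and the function treeSum are hypothetical. Note that the recursive calls go back through treeSum, so shared subtrees hit the table too.

```haskell
import qualified Data.Map as Map

data Tree k a = Branch (Tree k a) (k, a) (Tree k a) | Leaf (k, a)

key :: Tree k a -> k
key (Leaf (k, _))       = k
key (Branch _ (k, _) _) = k

-- Same memoFor as above, with Data.Map imported qualified.
memoFor :: Ord k => [Tree k a] -> (Tree k a -> b) -> Tree k a -> b
memoFor ts f = f'
  where
    f' t = maybe (f t) id (Map.lookup (key t) m)
    m    = Map.fromList (map (\t -> (key t, f t)) ts)

-- A tiny DAG: the subtree `shared` appears twice under `root`.
shared, root :: Tree Int Int
shared = Branch (Leaf (1, 10)) (2, 20) (Leaf (3, 30))
root   = Branch shared (4, 40) shared

-- Sum of all values, memoized for the two branch nodes only.
-- The leaves are not in the list, so they are recomputed on each visit.
treeSum :: Tree Int Int -> Int
treeSum = memoFor [shared, root] go
  where
    go (Leaf (_, x))       = x
    go (Branch l (_, x) r) = treeSum l + x + treeSum r
```

With these definitions, treeSum root evaluates the sum for shared once and reuses it for the second occurrence.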

What the MemoCombinators and MemoTrie packages try to do is make the collection of inputs implicit (using functions and type classes, respectively). If all the possible inputs can be enumerated, then we can use that enumeration to build our data structure.

In your case, since you want to memoize on just the key of your trees, the easiest way might be to use the wrap function from the MemoCombinators package:

wrap :: (a -> b) -> (b -> a) -> Memo a -> Memo b

Given a memoizer for a and an isomorphism between a and b, build a memoizer for b.

So if your key values have a corresponding Memo value (like, say, type Key = Int), and you have a bijection from Keys to Tree Key Val, then you can use that bijection to make a memoizer for your Tree Key Val functions:

memoize :: (Tree Key Val -> b) -> (Tree Key Val -> b)
memoize = wrap keyToTree treeToKey memoForKey
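To show how the pieces could fit together without pulling in the package, here is a self-contained sketch. The Memo type and wrap below are local stand-ins that follow the signature quoted above (with the result type r made explicit to avoid rank-N types); nodes, keyToTree, treeToKey, memoForKey and depth are all hypothetical. It assumes every node is registered in nodes ahead of time and that keys identify nodes uniquely.

```haskell
import qualified Data.Map as Map

data Tree k a = Branch (Tree k a) (k, a) (Tree k a) | Leaf (k, a)

type Key = Int
type Val = Int

-- Local stand-ins mirroring the shape of the library's Memo and wrap.
type Memo a r = (a -> r) -> (a -> r)

wrap :: (a -> b) -> (b -> a) -> Memo a r -> Memo b r
wrap i j m f = m (f . i) . j

-- A tiny DAG in which `shared` is used twice.
leafA, leafB, shared, root :: Tree Key Val
leafA  = Leaf (1, 10)
leafB  = Leaf (3, 30)
shared = Branch leafA (2, 20) leafB
root   = Branch shared (4, 40) shared

treeToKey :: Tree Key Val -> Key
treeToKey (Leaf (k, _))       = k
treeToKey (Branch _ (k, _) _) = k

-- The key-to-node direction of the bijection, backed by a map of known nodes.
nodes :: Map.Map Key (Tree Key Val)
nodes = Map.fromList [(treeToKey t, t) | t <- [leafA, leafB, shared, root]]

-- Partial: fails on a key that is not registered in `nodes`.
keyToTree :: Key -> Tree Key Val
keyToTree k = nodes Map.! k

-- A memoizer for keys: a lazy table over the known keys.
memoForKey :: Memo Key r
memoForKey f = \k -> Map.findWithDefault (f k) k table
  where
    table = Map.fromList [(k, f k) | k <- Map.keys nodes]

memoize :: (Tree Key Val -> b) -> (Tree Key Val -> b)
memoize = wrap keyToTree treeToKey memoForKey

-- Example: depth of a tree, with shared nodes computed once.
depth :: Tree Key Val -> Int
depth = memoize go
  where
    go (Leaf _)       = 1
    go (Branch l _ r) = 1 + max (depth l) (depth r)
```

With the real package you would presumably replace memoForKey with one of its own memoizers for your key type; the map-backed keyToTree is essentially the "function that indexes a map" mentioned in the question.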

Update: If you can't create such a mapping ahead of time, perhaps the solution is to use a state monad so you can memoize on the go:

{-# LANGUAGE FlexibleContexts #-}
-- ... 

import Control.Monad.State (MonadState, gets, modify)
import Data.Map (Map, insert)
-- ... 

memoM :: (Ord k, MonadState (Map k b) m) => (Tree k a -> m b) -> (Tree k a -> m b)
memoM f = f'
  where f' t = do
        let k = key t
        v <- gets $ lookup k
        case v of
          Just b -> return b
          Nothing -> do
            b <- f t
            modify $ insert k b
            return b

-- example of use
sumM :: (Ord k, MonadState (Map k Int) m) => Tree k Int -> m Int
sumM = memoM $ \t -> case t of
        Leaf (_,b) -> return b
        Branch l (_,b) r -> do
          lsum <- sumM l
          rsum <- sumM r
          return $ lsum + b + rsum
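To actually run sumM you pick a concrete state monad and start from an empty table. The following is a sketch that repeats the definitions above so it compiles on its own; the trees shared and root and the name result are hypothetical.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.State (MonadState, gets, modify, runState)
import qualified Data.Map as Map

data Tree k a = Branch (Tree k a) (k, a) (Tree k a) | Leaf (k, a)

key :: Tree k a -> k
key (Leaf (k, _))       = k
key (Branch _ (k, _) _) = k

-- Same memoM as above: consult the table first, fill it in on a miss.
memoM :: (Ord k, MonadState (Map.Map k b) m) => (Tree k a -> m b) -> (Tree k a -> m b)
memoM f t = do
  let k = key t
  v <- gets (Map.lookup k)
  case v of
    Just b  -> return b
    Nothing -> do
      b <- f t
      modify (Map.insert k b)
      return b

sumM :: (Ord k, MonadState (Map.Map k Int) m) => Tree k Int -> m Int
sumM = memoM $ \t -> case t of
  Leaf (_, b)       -> return b
  Branch l (_, b) r -> do
    lsum <- sumM l
    rsum <- sumM r
    return (lsum + b + rsum)

-- A tiny DAG: `shared` appears twice under `root`.
shared, root :: Tree Int Int
shared = Branch (Leaf (1, 10)) (2, 20) (Leaf (3, 30))
root   = Branch shared (4, 40) shared

-- Run in the State monad, starting from an empty memo table.
-- The second component is the table that was built along the way.
result :: (Int, Map.Map Int Int)
result = runState (sumM root) Map.empty
```

The returned table has one entry per distinct key, so the second visit to shared is a lookup rather than a recomputation.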

Upvotes: 5
