Manuel Eberl

Reputation: 8258

Memoisation with auxiliary parameter in Haskell

I have a recursive function f that takes two parameters x and y. The result is uniquely determined by the first parameter; the second one merely makes things easier.

I now want to memoise that function w.r.t. its first parameter while ignoring the second one (i.e. f is evaluated at most once for every value of x).

What is the easiest way to do that? At the moment, I simply define an array containing all values recursively, but that is a somewhat ad-hoc solution. I would prefer some kind of memoisation combinator that I can just throw at my function.

EDIT: to clarify, the function f takes a pair of integers and a list. The first integer is some parameter value, the second one denotes the index of an element in some global list xs to consume.

To avoid indexing the list, I pass the partially consumed list to f as well, but obviously, the invariant is that if the first parameter is (m, n), the second one will always be drop n xs, so the result is uniquely determined by the first parameter.

Just using a memoisation combinator on the partially applied function will not work, since that will leave an unevaluated thunk \xs -> … lying around. I could probably wrap the two parameters in a datatype whose Eq instance ignores the second value (and similarly for other instances), but that seems like a very ad-hoc solution. Is there not an easier way?

EDIT2: The concrete function I want to memoise:

g :: [(Int, Int)] -> Int -> Int
g xs n = f 0 n
  where f :: Int -> Int -> Int
        f _ 0 = 0
        f m n
            | m == length xs  = 0
            | w > n           = f (m + 1) n
            | otherwise       = maximum [f (m + 1) n, v + f (m + 1) (n - w)]
          where (w, v) = xs !! m

To avoid the expensive indexing operation, I instead pass the partially-consumed list to f as well:

g' :: [(Int, Int)] -> Int -> Int
g' xs n = f xs 0 n
  where f :: [(Int, Int)] -> Int -> Int -> Int
        f []           _ _ = 0
        f _            _ 0 = 0
        f ((w,v) : xs) m n
            | w > n           = f xs (m + 1) n
            | otherwise       = maximum [f xs (m + 1) n, v + f xs (m + 1) (n - w)]

Memoisation of f w.r.t. the list parameter is, of course, unnecessary, since the list does not (morally) influence the result. I would therefore like the memoisation to simply ignore the list parameter.
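
For comparison, the array-based approach mentioned at the start might look roughly like the sketch below. The name gArr and the use of Data.Array are assumptions for illustration, not the code actually in use; it also assumes n and all weights are non-negative. The table is indexed by (m, n) only, so the list never becomes part of the key.

```haskell
import Data.Array (Array, listArray, range, (!))

-- gArr is a hypothetical name; assumes n >= 0 and non-negative weights.
gArr :: [(Int, Int)] -> Int -> Int
gArr xs n = table ! (0, n)
  where
    len   = length xs
    -- Items in an array so that looking up item m is O(1).
    items = listArray (0, len - 1) xs :: Array Int (Int, Int)
    bnds  = ((0, 0), (len, n))
    -- Lazy array: each cell is a thunk that is evaluated at most once.
    table = listArray bnds [f m k | (m, k) <- range bnds] :: Array (Int, Int) Int
    f m k
        | m == len || k == 0 = 0
        | w > k              = table ! (m + 1, k)
        | otherwise          = max (table ! (m + 1, k)) (v + table ! (m + 1, k - w))
      where (w, v) = items ! m
```

Because the cells are lazy, only the entries reachable from (0, n) are ever computed.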

Upvotes: 1

Views: 198

Answers (1)

Aadit M Shah

Reputation: 74204

Your function is unnecessarily complicated. You don't need the index m at all:

foo :: [(Int, Int)] -> Int -> Int
foo []         _ = 0
foo _          0 = 0
foo ((w,v):xs) n
    | w > n      = foo xs n
    | otherwise  = foo xs n `max` (foo xs (n - w) + v)
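
A note on precedence, since the last clause mixes `max` with `+`: a backticked function without a fixity declaration is infixl 9, so it binds tighter than (+), which is infixl 6. The small sketch below shows the two readings; the names looseReading and intendedReading are illustrative only. For this recurrence, v should be added only to the branch that takes the item.

```haskell
-- No fixity is declared for max, so `max` defaults to infixl 9.
looseReading :: Int
looseReading = 3 `max` 1 + 10        -- parses as (3 `max` 1) + 10

-- Explicit parentheses add the value only to the second branch.
intendedReading :: Int
intendedReading = 3 `max` (1 + 10)
```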

Now, if you want to memoize foo, both arguments must be considered (as they should be).

We'll use the monadic memoization mixin method to memoize foo:

  1. First, we create an uncurried version of foo (because we want to memoize both arguments):

    foo' :: ([(Int, Int)], Int) -> Int
    foo' ([],       _) = 0
    foo' (_,        0) = 0
    foo' ((w,v):xs, n)
        | w > n       = foo' (xs, n)
        | otherwise   = foo' (xs, n) `max` (foo' (xs, n - w) + v)
    
  2. Next, we monadify the function foo' (because we want to thread a memo table in the function):

    foo' :: Monad m => ([(Int, Int)], Int) -> m Int
    foo' ([],       _) = return 0
    foo' (_,        0) = return 0
    foo' ((w,v):xs, n)
        | w > n        = foo' (xs, n)
        | otherwise    = do
            a <- foo' (xs, n)
            b <- foo' (xs, n - w)
            return (a `max` (b + v))
    
  3. Then, we open the self-reference in foo' (because we want to call the memoized function):

    type Endo a = a -> a
    
    foo' :: Monad m => Endo (([(Int, Int)], Int) -> m Int)
    foo' _    ([],       _) = return 0
    foo' _    (_,        0) = return 0
    foo' self ((w,v):xs, n)
        | w > n             = self (xs, n)
        | otherwise         = do
            a <- self (xs, n)
            b <- self (xs, n - w)
            return (a `max` (b + v))
    
  4. We'll use the following memoization mixin to memoize our function foo':

    type Dict a b m = (a -> m (Maybe b), a -> b -> m ())
    
    memo :: Monad m => Dict a b m -> Endo (a -> m b)
    memo (check, store) super a = do
        b <- check a
        case b of
            Just b  -> return b
            Nothing -> do
                b <- super a
                store a b
                return b
    
  5. Our dictionary (memo table) will use the State monad and a Map data structure:

    import Prelude hiding (lookup)
    import Control.Monad.State
    import Data.Map.Strict (Map, empty, insert, lookup)
    
    mapDict :: Ord a => Dict a b (State (Map a b))
    mapDict = (check, store) where
        check a   = gets (lookup a)
        store a b = modify (insert a b)
    
  6. Finally, we combine everything to create a memoized function memoFoo:

    import Data.Function (fix)
    
    type MapMemoized a b = a -> State (Map a b) b
    
    memoFoo :: MapMemoized ([(Int, Int)], Int) Int
    memoFoo = fix (memo mapDict . foo')
    
  7. We can recover the original function foo as follows:

    foo :: [(Int, Int)] -> Int -> Int
    foo xs n = evalState (memoFoo (xs, n)) empty
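
Putting the steps above together, a complete module might look like the following sketch. It assumes the mtl and containers packages are available, uses explicit import lists, and parenthesises b + v so that the value is added only when the item is taken; treat it as one possible assembly rather than the definitive version.

```haskell
import Prelude hiding (lookup)
import Control.Monad.State (State, evalState, gets, modify)
import Data.Function (fix)
import Data.Map.Strict (Map, empty, insert, lookup)

type Endo a = a -> a
type Dict a b m = (a -> m (Maybe b), a -> b -> m ())

-- Memoization mixin: consult the table first, compute and store on a miss.
memo :: Monad m => Dict a b m -> Endo (a -> m b)
memo (check, store) super a = do
    cached <- check a
    case cached of
        Just b  -> return b
        Nothing -> do
            b <- super a
            store a b
            return b

-- Memo table kept in a State monad over a Map.
mapDict :: Ord a => Dict a b (State (Map a b))
mapDict = (check, store) where
    check a   = gets (lookup a)
    store a b = modify (insert a b)

-- Open-recursive step: every recursive call goes through self.
foo' :: Monad m => Endo (([(Int, Int)], Int) -> m Int)
foo' _    ([],       _) = return 0
foo' _    (_,        0) = return 0
foo' self ((w,v):xs, n)
    | w > n     = self (xs, n)
    | otherwise = do
        a <- self (xs, n)
        b <- self (xs, n - w)
        return (a `max` (b + v))

type MapMemoized a b = a -> State (Map a b) b

memoFoo :: MapMemoized ([(Int, Int)], Int) Int
memoFoo = fix (memo mapDict . foo')

foo :: [(Int, Int)] -> Int -> Int
foo xs n = evalState (memoFoo (xs, n)) empty
```

With these definitions, foo [(2,3),(3,4),(4,5)] 5 evaluates to 7 (the first two items fit exactly). Note that the table is keyed on the whole remaining list, so every lookup compares lists, and the table is discarded after each top-level call to foo.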
    

Hope that helps.

Upvotes: 2
