Gaston Gilbert

Reputation: 33

Memoization of a single parameter in a multi-parameter function in Haskell

I am solving this programming problem, and I am exceeding the time limit with my current solution. I believe that the solution to my problem is memoization. However, I do not understand the memoization solutions documented here.

Here is the primary function in my current solution.

maxCuts :: Int -> Int -> Int -> Int -> Int
maxCuts n a b c  
    | n == 0    = 0
    | n < 0     = -10000
    | otherwise = max (max amax bmax) cmax
    where 
        amax = 1 + maxCuts (n - a) a b c
        bmax = 1 + maxCuts (n - b) a b c
        cmax = 1 + maxCuts (n - c) a b c

This function takes too long to run if a, b, and c are small relative to n. I would just copy the solution they used for the factorial function, but that function only takes one parameter. I have four parameters, but I only want to key the memoization on the first parameter, n. Notice that a, b, and c do not change in the recursive calls.

Upvotes: 0

Views: 66

Answers (2)

ErikR

Reputation: 52029

Aside: Isn't your algorithm just computing something like div n (minimum [a,b,c])?
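
One way to check this guess is to compare it against the function from the question on a small input. The sketch below copies maxCuts from the question unchanged; the name divGuess is hypothetical and only stands for the formula suggested in the aside.

```haskell
-- maxCuts exactly as posted in the question (unmemoized, so keep n small).
maxCuts :: Int -> Int -> Int -> Int -> Int
maxCuts n a b c
    | n == 0    = 0
    | n < 0     = -10000
    | otherwise = max (max amax bmax) cmax
    where
        amax = 1 + maxCuts (n - a) a b c
        bmax = 1 + maxCuts (n - b) a b c
        cmax = 1 + maxCuts (n - c) a b c

-- Hypothetical helper: the closed form guessed in the aside.
divGuess :: Int -> Int -> Int -> Int -> Int
divGuess n a b c = n `div` minimum [a, b, c]
```

For n = 7 with lengths 5, 5 and 2, maxCuts 7 5 5 2 evaluates to 2 (one piece of 5 and one of 2), while divGuess 7 5 5 2 gives 3, so the shortcut does not hold for every input.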

As you pointed out, the parameters a, b and c don't change, so first rewrite the function to place the parameter n at the end.

If you decide to memoize the function values in a list, a little care is needed to make sure GHC keeps the mapped list around instead of rebuilding it:

import Debug.Trace

maxCuts' :: Int -> Int -> Int -> Int -> Int
maxCuts' a b c n = memoized_go n
  where
    memoized_go n
      | n < 0 = -10000
      | otherwise =  mapped_list !! n

    mapped_list = map go [0..]

    go n | trace msg False = undefined
      where msg = "go called for " ++ show n
    go 0 = 0
    go n = maximum [amax, bmax, cmax]
      where
        amax = 1 + memoized_go (n-a)
        bmax = 1 + memoized_go (n-b)
        cmax = 1 + memoized_go (n-c)

test1 = print $ maxCuts' 1 2 3 10

Note the circular dependency of the definitions: memoized_go depends on mapped_list, which depends on go, which depends on memoized_go.

Since lists only admit non-negative indexes, the case n < 0 has to be handled in a separate guard pattern.

The trace calls show that go is only called once per value of n. By contrast, consider trying to do this without defining mapped_list:

maxCuts2 :: Int -> Int -> Int -> Int -> Int
maxCuts2 a b c n = memoized_go n
  where
    memoized_go n
      | n < 0 = -10000
      | otherwise =  (map go [0..]) !! n
    -- mapped_list = map go [0..]
    go n | trace msg False = undefined
      where msg = "go called for " ++ show n
    go 0 = 0
    go n = maximum [amax, bmax, cmax]
      where
        amax = 1 + memoized_go (n-a)
        bmax = 1 + memoized_go (n-b)
        cmax = 1 + memoized_go (n-c)

test2 = print $ maxCuts2 1 2 3 11

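Running test2 shows that go is called multiple times for the same value of n: the expression map go [0..] is no longer bound to a name, so a fresh list can be built on each call to memoized_go.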

Update

To avoid creating large unevaluated thunks, I would use BangPatterns for amax, bmax and cmax:

{-# LANGUAGE BangPatterns #-}

maxCuts' ... =
  ...
  where
    !amax = 1 + ...
    !bmax = 1 + ...
    !cmax = 1 + ...
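
Filled in against the list-based maxCuts' above, the strict bindings might look like the following minimal sketch. The name maxCutsStrict is hypothetical, the trace call is dropped for brevity, and a, b and c are assumed to be positive.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hypothetical name; same structure as the list-memoized maxCuts'.
-- Assumes a, b and c are all positive.
maxCutsStrict :: Int -> Int -> Int -> Int -> Int
maxCutsStrict a b c n = memoizedGo n
  where
    -- Negative arguments cannot index the list, so handle them first.
    memoizedGo k
      | k < 0     = -10000
      | otherwise = mappedList !! k

    -- Named once, so every lookup shares the same lazily built list.
    mappedList = map go [0 ..]

    go 0 = 0
    go k = maximum [amax, bmax, cmax]
      where
        -- The bangs force each sum when go's result is demanded,
        -- rather than leaving three thunks inside the list.
        !amax = 1 + memoizedGo (k - a)
        !bmax = 1 + memoizedGo (k - b)
        !cmax = 1 + memoizedGo (k - c)
```

Called as maxCutsStrict 1 2 3 10, this follows the same argument order as test1.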

Upvotes: 0

n. m. could be an AI

Reputation: 119857

Rewrite your function definition like this:

 maxCuts :: Int -> Int -> Int -> Int -> Int
 maxCuts n a b c = maxCuts' n where
     maxCuts' n
          | n == 0    = 0
          | n < 0     = -10000
          | otherwise = max (max amax bmax) cmax
            where 
               amax = 1 + maxCuts' (n - a)  
               bmax = 1 + maxCuts' (n - b) 
               cmax = 1 + maxCuts' (n - c)

Now you have a one-argument function you can memoize.
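
A minimal sketch of one way to memoize it, assuming the array package that ships with GHC and positive a, b and c: the results for 0..n are stored in a lazy Data.Array whose cells refer back to the same table. The names maxCutsMemo, table, lookupCuts and cuts are hypothetical.

```haskell
import Data.Array (listArray, (!))

-- Hypothetical name; keeps the question's argument order (n first).
-- Assumes a, b and c are all positive.
maxCutsMemo :: Int -> Int -> Int -> Int -> Int
maxCutsMemo n a b c = lookupCuts n
  where
    -- Lazy array of results for 0..n; each cell is a thunk that is
    -- evaluated at most once, the first time it is looked up.
    table = listArray (0, n) (map cuts [0 .. n])

    -- Negative lengths are outside the table, so return the sentinel.
    lookupCuts k
      | k < 0     = -10000
      | otherwise = table ! k

    -- The one-argument recursion, routed through the table.
    cuts 0 = 0
    cuts k = 1 + maximum [lookupCuts (k - a), lookupCuts (k - b), lookupCuts (k - c)]
```

Because array indexing is constant time and each cell is computed once, the work grows linearly with n; for example, maxCutsMemo 4000 1 1 1 evaluates to 4000 without the exponential blow-up of the plain recursion.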

Upvotes: 2
