dainichi

Reputation: 2231

Partial memoization in Haskell

I'm trying to find a good way to memoize a function for only part of its domain (non-negative integers) in Haskell, using Data.MemoCombinators.

import Data.MemoCombinators

--approach 1

partFib n | n < 0 = undefined
          | otherwise = integral fib n where
  fib 0 = 1
  fib 1 = 1
  fib k = partFib (k-1) + partFib (k-2)

--approach 2

partFib2 n | n < 0 = undefined
           | otherwise = fib n
fib = integral fib'
  where
    fib' 0 = 1
    fib' 1 = 1
    fib' n = partFib2 (n-1) + partFib2 (n-2)

Approach 1 is how I would like to do it; however, it doesn't seem to work. I assume this is because the fib function is "recreated" every time partFib is called, which throws away the memoization. fib doesn't depend on the argument of partFib, so you would expect the compiler to hoist it out, but apparently GHC doesn't work that way.

Approach 2 is how I ended up doing it. Eerk, a lot of ugly wiring.

Does anybody know of a better way to do this?
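To make the sharing issue described above concrete without pulling in the library, here is a rough sketch. It uses a hypothetical list-backed memoizer, natMemo, as a stand-in for integral (unlike integral, it only handles non-negative Ints). perCallFib has the shape of approach 1, where the memoized function is created inside the body of each call. sharedFib has the shape of approach 2, where the memoized function is a single top-level constant. Both return the same values, but only the second reliably reuses one table across calls.

```haskell
-- Hypothetical list-backed memoizer standing in for Data.MemoCombinators.integral;
-- only valid for non-negative arguments.
natMemo :: (Int -> r) -> (Int -> r)
natMemo f = (table !!)
  where
    table = map f [0 ..]

-- Shape of approach 1: the table is built inside the body, so each call may
-- build a fresh one and recursive calls get no guaranteed sharing.
perCallFib :: Int -> Integer
perCallFib n
  | n < 0 = error "perCallFib: negative argument"
  | otherwise = natMemo fib n
  where
    fib k
      | k < 2 = 1
      | otherwise = perCallFib (k - 1) + perCallFib (k - 2)

-- Shape of approach 2: the memoized function is a top-level constant (CAF),
-- so a single table is built and shared by every call.
sharedTable :: Int -> Integer
sharedTable = natMemo fib
  where
    fib k
      | k < 2 = 1
      | otherwise = sharedFib (k - 1) + sharedFib (k - 2)

-- Guard for the negative part of the domain, then defer to the shared table.
sharedFib :: Int -> Integer
sharedFib n
  | n < 0 = error "sharedFib: negative argument"
  | otherwise = sharedTable n
```

Both functions yield 1, 1, 2, 3, 5, ... (so 20 maps to 10946). In GHCi, where nothing is floated out, perCallFib 30 can be noticeably slower than sharedFib 30.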

Upvotes: 5

Views: 384

Answers (3)

luqui

Reputation: 60543

There is a combinator in the library for this purpose:

switch :: (a -> Bool) -> Memo a -> Memo a -> Memo a  

switch p a b uses the memo table a whenever p gives true and the memo table b whenever p gives false.

Recall that id is technically a memoizer (which does not memoize :-), so you can do:

import qualified Data.MemoCombinators as Memo

partFib = Memo.switch (< 0) id Memo.integral fib'
    where
    ...

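The snippet above leaves the where clause out. Below is a complete program of the same shape as a dependency-free sketch. switchMemo and natListMemo are hypothetical stand-ins for Memo.switch and Memo.integral, and natListMemo, unlike Memo.integral, only handles non-negative Ints. The Memo synonym copies the shape of the library's Memo type, which is why RankNTypes is needed.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Same shape as the Memo type synonym in Data.MemoCombinators.
type Memo a = forall r. (a -> r) -> (a -> r)

-- Hypothetical stand-in for Memo.integral: a lazily built list table,
-- valid only for non-negative arguments.
natListMemo :: Memo Int
natListMemo f = (table !!)
  where
    table = map f [0 ..]

-- Hypothetical stand-in for Memo.switch: arguments satisfying p go to
-- onTrue, all others to onFalse. Both memoized functions are built once
-- per application to f, so their tables are shared across lookups.
switchMemo :: (a -> Bool) -> Memo a -> Memo a -> Memo a
switchMemo p onTrue onFalse f = lookupIn
  where
    viaTrue = onTrue f
    viaFalse = onFalse f
    lookupIn x = if p x then viaTrue x else viaFalse x

-- Negative arguments go through id (no table), non-negative ones through
-- the list table. partFibSwitch is a top-level constant, so the table persists.
partFibSwitch :: Int -> Integer
partFibSwitch = switchMemo (< 0) id natListMemo fib'
  where
    fib' :: Int -> Integer
    fib' n
      | n < 0 = error "partFibSwitch: negative argument"
      | n < 2 = 1
      | otherwise = partFibSwitch (n - 1) + partFibSwitch (n - 2)
```

The monomorphic signature on partFibSwitch matters. It keeps the definition a constant rather than a function of a type-class dictionary, so the table survives between calls. With the real package, the body would be the answer's one-liner with the same where clause.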
Upvotes: 4

Ankur

Reputation: 33657

Hmm, what about separating things a bit:

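import qualified Data.MemoCombinators as Memo

fib :: Integer -> Integer
fib 0 = 0
fib 1 = 1
fib x = doFib (x-1) + doFib (x-2)

-- Top-level memo table, built once and shared by every call
memFib :: Integer -> Integer
memFib = Memo.integral fib

-- Negative arguments bypass the table; the rest go through memFib
doFib :: Integer -> Integer
doFib n | n < 0     = fib n
        | otherwise = memFib n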

Now call doFib instead of fib.

Upvotes: 4

JB.

Reputation: 42154

Not quite sure what's "ugly" to your eye, but you can have proper memoization while using only a single top-level identifier by lifting the memoization operation out of the function of n.

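partFib3 :: Integer -> Integer
partFib3 = \n -> if n < 0 then undefined else fib' n
    where fib 0 = 1
          fib 1 = 1
          fib k = partFib3 (k-1) + partFib3 (k-2)
          -- fib' is bound outside the lambda, so its memo table is shared by all calls
          fib' = integral fib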

Upvotes: 6
