Douglas Lewit

Reputation: 199

My "memoized" pascal function doesn't really work

import Data.List (intercalate)
import Control.Concurrent (threadDelay)
import Data.Maybe (fromJust)
import System.IO


-- I love how amazingly concise Haskell code can be.  This same program in C, C++ or Java
-- would be at least twice as long.


pascal :: Int -> Int -> Int
pascal row col | col >= 0 && col <= row =
                 if row == 0 || col == 0 || row == col
                 then 1
                 else pascal (row - 1) (col - 1) + pascal (row - 1) col
pascal _ _ = 0


pascalFast' :: [((Int, Int), Int)] -> Int -> Int -> Int
pascalFast' dict row col | col > row = 0
pascalFast' dict row col | row == 0 || col == 0 || row == col = 1
pascalFast' dict row col =
  let value1 = lookup (row - 1, col - 1) dict
      value2 = lookup (row - 1, col) dict
  in if not(value1 == Nothing || value2 == Nothing)
     then (fromJust value1) + (fromJust value2)
     else let dict'  = ((row - 1, col), pascalFast' dict (row - 1) col) : dict
              dict'' = ((row - 1, col - 1), pascalFast' dict' (row - 1) (col - 1)) : dict' 
          in (pascalFast' dict'' (row - 1) col) + (pascalFast' dict'' (row - 1) (col - 1))


pascalFast :: Int -> Int -> Int
pascalFast = pascalFast' []
                

pascalsTriangle :: Int -> [[Int]]
pascalsTriangle rows =
  [[pascal row col | col <- [0..row]] | row <- [0..rows]]


main :: IO ()
main = do
  putStrLn "" 
  putStr "Starting at row #0, how many rows of Pascal's Triangle do you want to print out? "
  hFlush stdout
  numRows <- (\s -> read s :: Int) <$> getLine
  let triangle = pascalsTriangle numRows
      longestStringLength = (length . show) $ foldl1 max $ flatten triangle
      triangleOfStrings = map (intercalate ", ") $ map (map (pad longestStringLength)) triangle
      lengthOfLastDiv2 = div ((length . last) triangleOfStrings) 2 
  putStrLn ""
  mapM_ (\s -> let spaces = replicate (lengthOfLastDiv2 - div (length s) 2) ' '
               in putStrLn (spaces ++ s) >> threadDelay 200000) triangleOfStrings
  putStrLn ""
  

flatten :: [[a]] -> [a]
flatten xs =
  [xss | ys <- xs, xss <- ys]


pad :: Int -> Int -> String
pad i k =
  [' ' | _ <- [1..n]] ++ m
  where m = show k
        n = i - length m

For the life of me, I do not understand why pascalFast isn't FAST! It type-checks and is mathematically correct, but my "pascalFast" function is just as slow as my "pascal" function. Any ideas? And no, this is not a homework assignment; it's something I just want to try for myself. Thanks for the feedback.

Best, Douglas Lewit

Upvotes: 1

Views: 124

Answers (1)

amalloy

Reputation: 91907

Your main doesn't actually call pascalFast at all, so it's not clear to me exactly what you did that led you to conclude it is slow. With some effort I can tell it is slow just from reading it, but some evidence in the question would be nice.

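As to why, two problems leap out at me. First, because you pass the dictionary "upwards" toward the base cases but never pass the updated dictionary back "downwards" to the caller or "sideways" to the sibling call, you are only caching results that you will never look at again: every entry you add is keyed on row - 1, but the recursive calls that receive it only look up keys in row - 2. Try evaluating pascalFast' [] 2 1 (that is, pascalFast 2 1) by hand, on paper, and see whether you ever get a cache hit.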
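To check this without paper, here is a hypothetical instrumented copy of the question's pascalFast' (the name pascalFastHits is made up for illustration). It follows the same control flow but also returns how many times both lookups succeed; according to the reasoning above, it should always report zero hits.

```haskell
-- Hypothetical instrumented copy of the question's pascalFast'.
-- Returns (value, number of cache hits); the control flow is unchanged.
pascalFastHits :: [((Int, Int), Int)] -> Int -> Int -> (Int, Int)
pascalFastHits dict row col
  | col > row                          = (0, 0)
  | row == 0 || col == 0 || row == col = (1, 0)
  | otherwise =
      case (lookup (row - 1, col - 1) dict, lookup (row - 1, col) dict) of
        -- Both parents found in the cache: count one hit.
        (Just a, Just b) -> (a + b, 1)
        -- Otherwise grow the dictionary exactly as the original does...
        _ ->
          let dict'   = ((row - 1, col), fst (pascalFastHits dict (row - 1) col)) : dict
              dict''  = ((row - 1, col - 1), fst (pascalFastHits dict' (row - 1) (col - 1))) : dict'
              -- ...and recurse; these calls look up row - 2 keys, never present.
              (x, hx) = pascalFastHits dict'' (row - 1) col
              (y, hy) = pascalFastHits dict'' (row - 1) (col - 1)
          in (x + y, hx + hy)

main :: IO ()
main = print (pascalFastHits [] 16 8)  -- prints (12870,0)
```

Because no lookup ever succeeds, the values stored in the dictionary are lazy thunks that are never forced; the only extra work compared with the plain pascal is the failed lookups.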

Secondly, even if you were caching correctly, lookup on an association list takes time linear in the length of the list, so your runtime would be at least quadratic in the number of entries generated: for each new entry, a failed lookup scans every entry already in the list. To cache efficiently you need a real data structure, such as a Map from Data.Map, whose lookups take logarithmic time.
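A minimal sketch of both fixes together, assuming the Data.Map.Strict module from the containers package that ships with GHC. The hypothetical pascalMemo returns the updated cache alongside each value, so the second recursive call sees everything the first one stored, and each lookup is logarithmic rather than linear. The names pascalMemo, pascalFastMemo and Cache are illustrative, not from the question.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical cache type: (row, col) -> entry of Pascal's triangle.
type Cache = Map.Map (Int, Int) Int

-- Returns the entry together with the cache as it stands afterwards.
pascalMemo :: Cache -> Int -> Int -> (Int, Cache)
pascalMemo cache row col
  | col < 0 || col > row   = (0, cache)
  | col == 0 || col == row = (1, cache)
  | otherwise =
      case Map.lookup (row, col) cache of
        Just v  -> (v, cache)  -- cache hit, O(log n)
        Nothing ->
          let (a, cache1) = pascalMemo cache (row - 1) (col - 1)
              -- The right parent is computed with the cache the left parent
              -- returned, so shared ancestors are found instead of recomputed.
              (b, cache2) = pascalMemo cache1 (row - 1) col
              v           = a + b
          in (v, Map.insert (row, col) v cache2)

-- Convenience wrapper starting from an empty cache.
pascalFastMemo :: Int -> Int -> Int
pascalFastMemo row col = fst (pascalMemo Map.empty row col)

main :: IO ()
main = print (pascalFastMemo 30 15)  -- prints 155117520
```

Threading the cache by hand like this is exactly what the State monad would automate; the explicit version just makes the data flow visible.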

But separate from the question of how to memoize effectively, it is often better not to memoize at all: start from the base cases and build up, rather than reaching down from the final result. Something like this is pretty classic for Pascal's Triangle:

triangle :: [[Int]]
triangle = iterate nextRow [1]
  where nextRow xs = 1 : zipWith (+) xs (tail xs) ++ [1]

main :: IO ()
main = print $ take 5 triangle
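To use this bottom-up triangle in the question's program, pascalsTriangle can simply take a prefix of it. In the sketch below, pascalAt is a hypothetical helper that mirrors the question's pascal, including returning 0 outside the triangle.

```haskell
-- Bottom-up triangle: each row is built from the previous one.
triangle :: [[Int]]
triangle = iterate nextRow [1]
  where nextRow xs = 1 : zipWith (+) xs (tail xs) ++ [1]

-- Drop-in replacement for the question's pascalsTriangle: rows 0..rows inclusive.
pascalsTriangle :: Int -> [[Int]]
pascalsTriangle rows = take (rows + 1) triangle

-- Hypothetical single-entry helper; 0 outside the triangle, like pascal.
pascalAt :: Int -> Int -> Int
pascalAt row col
  | row < 0 || col < 0 || col > row = 0
  | otherwise                       = triangle !! row !! col

main :: IO ()
main = mapM_ print (pascalsTriangle 4)
```

In GHC, a top-level value like triangle is typically shared across calls, so rows that have already been computed are reused rather than recomputed, at the cost of keeping them in memory.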

Upvotes: 2
