I'm testing the speed of various memoizing methods. The code below compares two implementations of memoization with an array, tested on a recursive function. The complete code is below.
Running the program with stack test for memoweird 1000, memoweird 5000, etc. shows that IOArray is consistently faster than STArray by a couple seconds, and the difference seems to be constant (O(1)) in the input size. However, running the same program with stack test --profile reverses the result, and STArray becomes consistently faster by about one second.
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Array
import Data.Array.ST
import Control.Monad.ST
import Data.Array.IO
import GHC.IO
import Control.Monad
import Data.Time

memoST :: forall a b. (Ix a)
       => (a, a)      -- range of the argument memoized
       -> ((a -> b)   -- a recursive function, but uses its first argument for recursive calls instead
           -> a -> b)
       -> (a -> b)    -- memoized function
memoST r f = (runSTArray compute !)
  where
    compute :: ST s (STArray s a b)
    compute = do
      arr <- newArray_ r
      forM_ (range r) (\i -> do
        writeArray arr i $ f (memoST r f) i)
      return arr

memoArray :: forall a b. (Ix a)
          => (a, a)
          -> ((a -> b) -> a -> b)
          -> a -> b
memoArray r f = (unsafePerformIO compute !) -- safe!
  where
    compute :: IO (Array a b)
    compute = do
      arr <- newArray_ r :: IO (IOArray a b)
      forM_ (range r) (\i -> do
        writeArray arr i $ f (memoArray r f) i)
      freeze arr

weird :: (Int -> Int) -> Int -> Int
weird _ 0 = 0
weird _ 1 = 0
weird f i = f (i `div` 2) + f (i - 1) + 1

stweird :: Int -> Int
stweird n = memoST (0,n) weird n

arrayweird :: Int -> Int
arrayweird n = memoArray (0,n) weird n

main :: IO ()
main = do
  t0 <- getCurrentTime
  print (stweird 5000)
  t1 <- getCurrentTime
  print (arrayweird 5000)
  t2 <- getCurrentTime
  -- diffUTCTime a b is a - b, so both values below are negated durations;
  -- the printed value is (IOArray time - STArray time).
  let sttime = diffUTCTime t0 t1
  let artime = diffUTCTime t1 t2
  print (sttime - artime)
Is there a reason why the profiling overhead is so different (albeit small) on the two array types?
I'm using Stack Version 2.7.3, GHC version 8.10.4 on OS X.
Some data on my computer.
Running this a couple times:
Without Profiling:
-0.222663s -0.116007s -0.202765s -0.205319s -0.130202s
Avg -0.1754s
Std 0.0486s
With Profiling:
0.608895s -0.755541s -0.61222s -0.83613s 0.450045s
1.879662s -0.181789s 3.251379s 0.359211s 0.122721s
Avg 0.4286s
Std 1.2764s
Apparently, the random fluctuations of the profiler cover up the difference. The data here is not sufficient to confirm a difference.
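As a sanity check on the figures above, here is a minimal sketch that recomputes the average and standard deviation from the listed samples. The helper names mean and sampleStd are made up for this illustration, and the sketch assumes the sample (n - 1) form of the standard deviation, which is the form that appears to reproduce the Avg and Std values listed.

```haskell
module Main where

-- Arithmetic mean of a non-empty list.
mean :: [Double] -> Double
mean xs = sum xs / fromIntegral (length xs)

-- Sample standard deviation (n - 1 denominator); needs at least two values.
sampleStd :: [Double] -> Double
sampleStd xs =
  sqrt (sum [(x - m) ^ (2 :: Int) | x <- xs] / fromIntegral (length xs - 1))
  where
    m = mean xs

-- Timing differences copied from the runs listed above, in seconds.
withoutProfiling, withProfiling :: [Double]
withoutProfiling = [-0.222663, -0.116007, -0.202765, -0.205319, -0.130202]
withProfiling =
  [ 0.608895, -0.755541, -0.61222, -0.83613, 0.450045
  , 1.879662, -0.181789, 3.251379, 0.359211, 0.122721
  ]

main :: IO ()
main = do
  putStrLn ("without profiling (avg, std): "
            ++ show (mean withoutProfiling, sampleStd withoutProfiling))
  putStrLn ("with profiling (avg, std): "
            ++ show (mean withProfiling, sampleStd withProfiling))
```

With profiling, the standard deviation is roughly three times the size of the average, which is why these runs cannot confirm a difference in either direction.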
You really should use criterion for benchmarking.
benchmarking stweird
time 3.116 s (3.109 s .. 3.119 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 3.112 s (3.110 s .. 3.113 s)
std dev 2.220 ms (953.8 μs .. 2.807 ms)
variance introduced by outliers: 19% (moderately inflated)
benchmarking marrayweird
time 3.170 s (2.684 s .. 3.602 s)
0.997 R² (0.989 R² .. 1.000 R²)
mean 3.204 s (3.148 s .. 3.280 s)
std dev 72.66 ms (1.810 ms .. 88.94 ms)
variance introduced by outliers: 19% (moderately inflated)
My system is noisy, but it does appear that the standard deviations don't overlap. I don't actually care much about figuring out why, though, because the code is exceptionally slow. 3 seconds for memoizing 5000 operations? Something has gone horribly wrong.
The code as written is a super-exponential algorithm: there's no sharing of memoized functions in the memoization code. Each sub-evaluation could create an entirely new array and populate it. You're being saved from that situation by two things. First is laziness: most values are never evaluated. The upshot here is that the algorithm will actually terminate, instead of eagerly evaluating array entries forever. Second, and more importantly, GHC does some constant-lifting (its full-laziness, or let-floating, transformation), lifting the expression (memoST r f) (or its IOArray counterpart) out of the loop body. This creates sharing within each loop body so that the two sub-calls actually share memoization. It's not great, but it does provide some speedup. It's mostly accidental, though.
The traditional approach to this sort of memoization is to just let laziness do the necessary mutation:
memoArray
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoArray r f = fetch
  where
    fetch n = arr ! n
    arr = listArray r $ map (f fetch) (range r)
Note the knot-tying between fetch and arr internally. This ensures that the same array is used in every calculation. It benchmarks a bit better:
benchmarking arrayweird
time 212.0 μs (211.5 μs .. 212.6 μs)
1.000 R² (0.999 R² .. 1.000 R²)
mean 213.3 μs (212.4 μs .. 215.0 μs)
std dev 4.104 μs (2.469 μs .. 6.194 μs)
variance introduced by outliers: 12% (moderately inflated)
213 microseconds is much more what I'd expect from only 5000 iterations. Still, one might be curious whether doing explicit sharing could work with the other variants. And it can:
memoST'
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoST' r f = fetch
  where
    fetch n = arr ! n
    arr = runSTArray compute

    compute :: ST s (STArray s a b)
    compute = do
        a <- newArray_ r
        forM_ (range r) $ \i -> do
            writeArray a i $ f fetch i
        return a

memoMArray'
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoMArray' r f = fetch
  where
    fetch n = arr ! n
    arr = unsafePerformIO compute

    compute :: IO (Array a b)
    compute = do
        a <- newArray_ r :: IO (IOArray a b)
        forM_ (range r) $ \i -> do
            writeArray a i $ f fetch i
        freeze a
Those use explicit sharing to introduce the same sort of knot-tying, though significantly more indirectly.
benchmarking stweird'
time 168.1 μs (167.1 μs .. 169.9 μs)
1.000 R² (0.999 R² .. 1.000 R²)
mean 167.1 μs (166.7 μs .. 167.8 μs)
std dev 1.636 μs (832.3 ns .. 3.007 μs)
benchmarking marrayweird'
time 171.1 μs (170.7 μs .. 171.7 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 170.9 μs (170.5 μs .. 171.4 μs)
std dev 1.554 μs (1.076 μs .. 2.224 μs)
And those actually seem to beat the listArray variant. I really don't know what's up with that. listArray must be doing some surprising amount of extra work. Oh well.
In the end, I don't actually know what's leading to these small performance differences. But none of them are significant in comparison to actually using an efficient algorithm.
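To show that the knot-tied memoArray from above is not specific to weird, here is a minimal sketch that applies it to a Fibonacci function written in the same open-recursive style. The names fibOpen and fibMemo are hypothetical and exist only for this illustration; memoArray is repeated so the snippet stands on its own.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Array

-- Knot-tied memoizer, same shape as memoArray above: fetch reads from arr,
-- and every element of arr is a lazy thunk that calls back into fetch.
memoArray
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoArray r f = fetch
  where
    fetch n = arr ! n
    arr = listArray r $ map (f fetch) (range r)

-- Hypothetical example function: Fibonacci, with recursive calls routed
-- through the first argument instead of calling itself directly.
fibOpen :: (Int -> Integer) -> Int -> Integer
fibOpen _ 0 = 0
fibOpen _ 1 = 1
fibOpen self n = self (n - 1) + self (n - 2)

fibMemo :: Int -> Integer
fibMemo n = memoArray (0, n) fibOpen n

main :: IO ()
main = print (fibMemo 50)  -- prints 12586269025
```

Because each array element is evaluated at most once and then shared, the number of additions grows linearly with n rather than exponentially, as it would with naive recursion.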
Full code, for your perusal:
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Array
import Data.Array.ST
import Data.Array.Unsafe
import Control.Monad.ST
import Data.Array.IO
import GHC.IO.Unsafe
import Control.Monad
import Criterion.Main

memoST
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoST r f = (runSTArray compute !)
  where
    compute :: ST s (STArray s a b)
    compute = do
        arr <- newArray_ r
        forM_ (range r) $ \i -> do
            writeArray arr i $ f (memoST r f) i
        return arr

memoMArray
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoMArray r f = (unsafePerformIO compute !)
  where
    compute :: IO (Array a b)
    compute = do
        arr <- newArray_ r :: IO (IOArray a b)
        forM_ (range r) $ \i -> do
            writeArray arr i $ f (memoMArray r f) i
        freeze arr

memoArray
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoArray r f = fetch
  where
    fetch n = arr ! n
    arr = listArray r $ map (f fetch) (range r)

memoST'
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoST' r f = fetch
  where
    fetch n = arr ! n
    arr = runSTArray compute

    compute :: ST s (STArray s a b)
    compute = do
        a <- newArray_ r
        forM_ (range r) $ \i -> do
            writeArray a i $ f fetch i
        return a

memoMArray'
    :: forall a b. (Ix a)
    => (a, a)
    -> ((a -> b) -> a -> b)
    -> a -> b
memoMArray' r f = fetch
  where
    fetch n = arr ! n
    arr = unsafePerformIO compute

    compute :: IO (Array a b)
    compute = do
        a <- newArray_ r :: IO (IOArray a b)
        forM_ (range r) $ \i -> do
            writeArray a i $ f fetch i
        freeze a

weird :: (Int -> Int) -> Int -> Int
weird _ 0 = 0
weird _ 1 = 0
weird f i = f (i `div` 2) + f (i - 1) + 1

stweird :: Int -> Int
stweird n = memoST (0, n) weird n

marrayweird :: Int -> Int
marrayweird n = memoMArray (0, n) weird n

arrayweird :: Int -> Int
arrayweird n = memoArray (0, n) weird n

stweird' :: Int -> Int
stweird' n = memoST' (0, n) weird n

marrayweird' :: Int -> Int
marrayweird' n = memoMArray' (0, n) weird n

main :: IO ()
main = do
    let rounds = 5000
    print $ stweird rounds
    print $ marrayweird rounds
    print $ arrayweird rounds
    print $ stweird' rounds
    print $ marrayweird' rounds
    putStrLn ""
    defaultMain
        [ bench "stweird" $ whnf stweird rounds
        , bench "marrayweird" $ whnf marrayweird rounds
        , bench "arrayweird" $ whnf arrayweird rounds
        , bench "stweird'" $ whnf stweird' rounds
        , bench "marrayweird'" $ whnf marrayweird' rounds
        ]