Reputation: 8699
I have a Haskell program which simulates the Ising model with the Metropolis algorithm. The main operation is a stencil operation that takes the sum of the nearest neighbors in 2D and multiplies it by the center element. The element may then be updated.
In C++, where I get decent performance, I use a 1D array and linearize the access to it with simple index arithmetic. In the past months I have picked up Haskell to broaden my horizon and tried to implement the Ising model there as well. The data structure is just a list of Bool:
type Spin = Bool
type Lattice = [Spin]
Then I have some fixed extent:
extent = 30
And a get function which retrieves a particular lattice site, including periodic boundary conditions:
-- Wrap a coordinate for periodic boundary conditions.
wrap :: Int -> Int
wrap = flip mod $ extent
-- Converts an unbounded (x,y) index into a linearized index with periodic
-- boundary conditions.
index :: Int -> Int -> Int
index x y = wrap x + wrap y * extent
-- Retrieve a single element from the lattice, automatically performing
-- periodic boundary conditions.
get :: Lattice -> Int -> Int -> Spin
get l x y = l !! index x y
I use the same thing in C++ and there it works fine, though I know that std::vector guarantees me fast random access.
While profiling, I found that the get function takes up a significant amount of computing time:
                                                                            individual     inherited
COST CENTRE                MODULE SRC                      no.   entries  %time %alloc  %time %alloc
get                        Main   ising.hs:36:1-26         153    899100    8.3    0.4    9.2    1.9
 index                     Main   ising.hs:31:1-36         154    899100    0.5    1.2    0.9    1.5
  wrap                     Main   ising.hs:26:1-24         155         0    0.4    0.4    0.4    0.4
neighborSum                Main   ising.hs:(40,1)-(43,56)  133    899100    4.9   16.6   46.6   25.3
 spin                      Main   ising.hs:(21,1)-(22,17)  135   3596400    0.5    0.4    0.5    0.4
 neighborSum.neighbors     Main   ising.hs:43:9-56         134    899100    0.9    0.7    0.9    0.7
 neighborSum.retriever     Main   ising.hs:42:9-40         136    899100    0.4    0.0   40.2    7.6
  neighborSum.retriever.\  Main   ising.hs:42:32-40        137   3596400    0.2    0.0   39.8    7.6
   get                     Main   ising.hs:36:1-26         138   3596400   33.7    1.4   39.6    7.6
    index                  Main   ising.hs:31:1-36         139   3596400    3.1    4.7    5.9    6.1
     wrap                  Main   ising.hs:26:1-24         141         0    2.7    1.4    2.7    1.4
I have read that Haskell lists are only efficient when elements are pushed or popped at the front, so they only perform well when used as a stack.
When I “update” the lattice, I use splitAt and then ++ to return a new list which has the one element changed.
Is there something relatively straightforward that I can do to improve the random access performance?
The full code is here:
-- Copyright © 2017 Martin Ueding <[email protected]>
-- Ising model with the Metropolis algorithm. Random choice of lattice site for
-- a spin flip.
import qualified Data.Text
import System.Random
type Spin = Bool
type Lattice = [Spin]
-- Lattice extent is fixed to a square.
extent = 30
volume = extent * extent
temperature :: Double
temperature = 0.0
-- Converts a `Spin` into `+1` or `-1`.
spin :: Spin -> Int
spin True = 1
spin False = (-1)
-- Wrap a coordinate for periodic boundary conditions.
wrap :: Int -> Int
wrap = flip mod $ extent
-- Converts an unbounded (x,y) index into a linearized index with periodic
-- boundary conditions.
index :: Int -> Int -> Int
index x y = wrap x + wrap y * extent
-- Retrieve a single element from the lattice, automatically performing
-- periodic boundary conditions.
get :: Lattice -> Int -> Int -> Spin
get l x y = l !! index x y
-- Computes the sum of neighboring spins.
neighborSum :: Lattice -> Int -> Int -> Int
neighborSum l x y = sum $ map spin $ map retriever neighbors
    where
        retriever = \(x, y) -> get l x y
        neighbors = [(x+1,y), (x-1,y), (x,y+1), (x,y-1)]
-- Computes the energy difference at a certain lattice site if it would be
-- flipped.
energy :: Lattice -> Int -> Int -> Int
energy l x y = 2 * neighborSum l x y * (spin (get l x y))
-- Converts a full lattice into a textual representation.
latticeToString l = unlines lines
    where
        spinToChar :: Spin -> String
        spinToChar True = "#"
        spinToChar False = "."
        line :: String
        line = concat $ map spinToChar l
        lines :: [String]
        lines = map Data.Text.unpack $ Data.Text.chunksOf extent $ Data.Text.pack line
-- Populates a lattice given a random seed.
initLattice :: Int -> (Lattice,StdGen)
initLattice s = (l,rng)
    where
        rng = mkStdGen s
        allRandom :: Lattice
        allRandom = randoms rng
        l = take volume allRandom
-- Performs a single Metropolis update at the given lattice site.
update (l,rng) x y
    | doUpdate = (l',rng')
    | otherwise = (l,rng')
    where
        shift = energy l x y
        r :: Double
        (r,rng') = random rng
        doUpdate :: Bool
        doUpdate = (shift < 0) || (exp (- fromIntegral shift / temperature) > r)
        i = index x y
        (a,b) = splitAt i l
        l' = a ++ [not $ head b] ++ tail b
-- A full sweep through the lattice.
doSweep (l,rng) = doSweep' (l,rng) (extent * extent)
-- Implementation that performs the given number of updates, each at a random
-- lattice site.
doSweep' (l,rng) 0 = (l,rng)
doSweep' (l,rng) i = doSweep' (update (l,rng'') x y) (i - 1)
    where
        x :: Int
        (x,rng') = random rng
        y :: Int
        (y,rng'') = random rng'
-- Creates an IO action that prints the lattice to the screen.
printLattice :: (Lattice,StdGen) -> IO ()
printLattice (l,rng) = do
    putStrLn ""
    putStr $ latticeToString l
dummy :: (Lattice,StdGen) -> IO ()
dummy (l,rng) = do
    putStr "."
-- Creates a random lattice and steps through 1000 lattice states (the initial
-- one plus 999 sweeps), printing a dot for each.
main = do
    let lrngs = iterate doSweep $ initLattice 2
    mapM_ dummy $ take 1000 lrngs
Upvotes: 4
Views: 1364
Reputation: 50819
With profiling turned off, your original version runs in about 5 seconds on my laptop.
Converting the code to use an immutable, unboxed vector (from Data.Vector.Unboxed) is a straightforward modification and reduces the run time to about 1.8 seconds. Profiling that version shows that the time is dominated by the very slow System.Random generator.
Using a custom generator based on the random-mersenne-pure64 package, I can get the run time down to about 0.32 seconds. Using a linear congruential generator brings the time down to 0.22 seconds.
Re-profiling, the bottleneck appears to be bounds checking on vector operations, so replacing those with their "unsafe" counterparts gets the run time down to about 0.17 seconds.
At this point, converting to a mutable, unboxed vector (which is a more involved modification than before) didn't appreciably improve performance, but I didn't work very hard on optimizing it. (I've seen other algorithms that benefited enormously from using mutable vectors.)
My final code for the LCG version follows. I tried to keep as much of your original code as was reasonable.
The one annoying bit is the necessity of specifying extentBits for the random index generation. Note that the algorithm will be most efficient if the extent is a power of two, because randIndex generates indexes using the given number of extentBits and re-tries until the index is less than extent.
Note that I decided to print the number of Trues in the final lattice instead of using a dummy call, since it's a little more reliable for benchmarking.
import Data.Bits ((.&.), shiftL)
import Data.Word
import qualified Data.Vector.Unboxed as V
type Spin = Bool
type Lattice = V.Vector Spin
-- Lattice extent is fixed to a square.
extent, extentBits, volume :: Int
extent = 30
extentBits = 5 -- no of bits s.t. 2**5 >= 30
volume = extent * extent
temperature :: Double
temperature = 0.0
-- Converts a `Spin` into `+1` or `-1`.
spin :: Spin -> Int
spin True = 1
spin False = (-1)
-- Wrap a coordinate for periodic boundary conditions.
wrap :: Int -> Int
wrap = flip mod $ extent
-- Converts an unbounded (x,y) index into a linearized index with periodic
-- boundary conditions.
index :: Int -> Int -> Int
index x y = wrap x + wrap y * extent
-- Retrieve a single element from the lattice, automatically performing
-- periodic boundary conditions.
get :: Lattice -> Int -> Int -> Spin
get l x y = l `V.unsafeIndex` index x y
-- Toggle the spin of an element
toggle :: Lattice -> Int -> Int -> Lattice
toggle l x y = l `V.unsafeUpd` [(i, not (l `V.unsafeIndex` i))] -- flip bit at index i
where i = index x y
-- Computes the sum of neighboring spins.
neighborSum :: Lattice -> Int -> Int -> Int
neighborSum l x y = sum $ map spin $ map (uncurry (get l)) neighbors
    where
        neighbors = [(x+1,y), (x-1,y), (x,y+1), (x,y-1)]
-- Computes the energy difference at a certain lattice site if it would be
-- flipped.
energy :: Lattice -> Int -> Int -> Int
energy l x y = 2 * neighborSum l x y * spin (get l x y)
-- Populates a lattice given a random seed.
initLattice :: Int -> (Lattice,MyGen)
initLattice s = (l, rng')
    where
        rng = newMyGen s
        (allRandom, rng') = go [] rng volume
        go out r 0 = (out, r)
        go out r n = let (a,r') = randBool r
                     in go (a:out) r' (n-1)
        l = V.fromList allRandom
-- Performs a single Metropolis update at the given lattice site.
update :: (Lattice, MyGen) -> Int -> Int -> (Lattice, MyGen)
update (l, rng) x y
    | doUpdate = (toggle l x y, rng')
    | otherwise = (l, rng')
    where
        doUpdate = (shift < 0) || (exp (- fromIntegral shift / temperature) > r)
        shift = energy l x y
        (r, rng') = randDouble rng
-- A full sweep through the lattice.
doSweep :: (Lattice, MyGen) -> (Lattice, MyGen)
doSweep (l, rng) = iterate updateRand (l, rng) !! (extent * extent)
updateRand :: (Lattice, MyGen) -> (Lattice, MyGen)
updateRand (l, rng)
    = let (x, rng') = randIndex rng
          (y, rng'') = randIndex rng'
      in update (l, rng'') x y
-- Creates a random lattice and performs 1000 sweeps.
main :: IO ()
main = do let lrngs = iterate doSweep (initLattice 2)
              l = fst (lrngs !! 1000)
          print $ V.length (V.filter id l) -- count the Trues
-- * Random number generation
data MyGen = MyGen Word32
newMyGen :: Int -> MyGen
newMyGen = MyGen . fromIntegral
-- | Get a (positive) integer with given number of bits.
randInt :: Int -> MyGen -> (Int, MyGen)
randInt bits (MyGen s) =
    let s' = 1664525 * s + 1013904223
        mask = (1 `shiftL` bits) - 1
    in (fromIntegral (s' .&. mask), MyGen s')
-- | Random Bool value
randBool :: MyGen -> (Bool, MyGen)
randBool g = let (i, g') = randInt 1 g
             in (i == 1, g')
-- | Random index
randIndex :: MyGen -> (Int, MyGen)
randIndex g = let (i, g') = randInt extentBits g
              in if i >= extent then randIndex g' else (i, g')
-- | Random Double in [0,1)
randDouble :: MyGen -> (Double, MyGen)
randDouble rng = let (ri, rng') = randInt 32 rng
                 in (fromIntegral ri / (2**32), rng')
If you prefer to use the MT generator, you can modify the imports and replace a few definitions as below. Note that I didn't work too hard on testing randInt, so I'm not 100% sure it's 100% correct with all the bit twiddling that's going on there.
import Data.Bits ((.|.), shiftL, shiftR, xor)
import Data.Word
import qualified Data.Vector.Unboxed as V
import System.Random.Mersenne.Pure64
-- replace these definitions:
-- | Mersenne-Twister generator w/ pool of bits
data MyGen = MyGen PureMT !Int !Word64 !Int !Word64
newMyGen :: Int -> MyGen
newMyGen seed = MyGen (pureMT (fromIntegral seed)) 0 0 0 0
-- | Split w into bottom n bits and rest
splitBits :: Int -> Word64 -> (Word64, Word64)
splitBits n w =
    let w2 = w `shiftR` n -- top 64-n bits
        w1 = (w2 `shiftL` n) `xor` w -- bottom n bits
    in (w1, w2)
-- | Get a (positive) integer with given number of bits.
randInt :: Int -> MyGen -> (Int, MyGen)
randInt bits (MyGen p lft1 w1 lft2 w2)
    -- generate at least 64 bits
    | let lft = lft1 + lft2, lft < 64
        = let w1' = w1 .|. (w2 `shiftL` lft1)
              (w2', p') = randomWord64 p
          in randInt bits (MyGen p' lft w1' 64 w2')
    | bits > 64 = error "randInt has max of 64 bits"
    -- if not enough bits in first word, get needed bits from second
    | bits > lft1
        = let needed = bits - lft1
              (bts, w2') = splitBits needed w2
              out = (w1 `shiftL` needed) .|. bts
          in (fromIntegral out, MyGen p (lft2 - needed) w2' 0 0)
    -- otherwise, just take enough bits from first word
    | otherwise
        = let (out, w1') = splitBits bits w1
          in (fromIntegral out, MyGen p (lft1 - bits) w1' lft2 w2)
Upvotes: 3
Reputation: 120711
You can always use Data.Vector.Unboxed, which is basically the same as std::vector. It has very fast random access, however it doesn't really allow purely-functional updates†. You can still do such updates by working in the ST monad, and indeed that's probably the solution that would give you the best performance, but it's not really Haskell-idiomatic.
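As a rough sketch of what the ST route could look like (assuming the vector package; the name flipSites is made up for this illustration and does not appear in the question), V.modify copies the vector once, runs the mutations on that copy inside ST, and hands back an immutable result:

```haskell
import Control.Monad (forM_)
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV

type Lattice = V.Vector Bool

-- Hypothetical helper: flip the spins at all given linear indices.
-- The input vector is copied once; every flip after that is an
-- in-place write on the copy.
flipSites :: [Int] -> Lattice -> Lattice
flipSites sites lattice =
  V.modify
    (\mv -> forM_ sites $ \i -> do
        s <- MV.read mv i      -- bounds-checked read
        MV.write mv i (not s)) -- in-place write
    lattice
```

The point of batching is that a whole sweep of updates then costs a single copy, rather than one copy per flipped spin.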
Better: use a functional structure that allows both lookup and update in log(n)-ish time; this is typical for tree-based structures. IntMap should work pretty well.
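A minimal sketch of the IntMap variant, assuming the containers package and reusing the question's linearized index with extent = 30. The names fromSpins and toggle are picked for this example only:

```haskell
import qualified Data.IntMap.Strict as IM

type Spin = Bool
type Lattice = IM.IntMap Spin

extent :: Int
extent = 30

-- Linearized index with periodic boundary conditions, as in the question.
index :: Int -> Int -> Int
index x y = (x `mod` extent) + (y `mod` extent) * extent

-- Build a lattice from a row-major list of spins, keyed 0, 1, 2, ...
fromSpins :: [Spin] -> Lattice
fromSpins = IM.fromList . zip [0 ..]

-- Lookup; partial, it fails if the key is missing from the map.
get :: Lattice -> Int -> Int -> Spin
get l x y = l IM.! index x y

-- Returns a new map with one spin flipped; the rest is shared with the old map.
toggle :: Lattice -> Int -> Int -> Lattice
toggle l x y = IM.adjust not (index x y) l
```

Unlike splitAt and ++ on a list, toggle only rebuilds the path to the changed key, so the old and new lattice share almost all of their structure.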
I wouldn't recommend that either though. Generally, in Haskell you want to avoid juggling any indices at all. As you say, algorithms like Metropolis are actually based on a stencil. The operation on each spin shouldn't ever need to see more than its direct neighbours, so it's best to structure your program accordingly.
Even on a simple list, it's easy to achieve efficient access to the direct neighbours: implement
neighboursInList :: [a] -> [(a, (Maybe a, Maybe a))]
The actual algorithm is then just a map over these local environments.
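One possible implementation of neighboursInList, given only as a sketch for the non-periodic case: pair every element with its predecessor and successor by zipping the list against shifted copies of itself.

```haskell
-- Pair each element with its left and right neighbour.
-- The ends of the list have no neighbour on one side, hence the Maybe.
neighboursInList :: [a] -> [(a, (Maybe a, Maybe a))]
neighboursInList xs = zip xs (zip prevs nexts)
  where
    prevs = Nothing : map Just xs              -- shifted right by one
    nexts = map Just (drop 1 xs) ++ [Nothing]  -- shifted left by one
```

This walks the list a constant number of times and never uses !!, so the cost is linear in the length of the list.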
For the periodic case, you should actually make it something like
data Lattice a = Lattice
    { latticeNodes :: [a]
    , latticeLength :: Int }
    deriving (Functor)
data NodeInLattice a = NodeInLattice
    { thisNode :: a
    , xPrev, xNext, yPrev, yNext :: a }
    deriving (Functor)
neighboursInLattice :: Lattice a -> Lattice (NodeInLattice a)
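Here is one way neighboursInLattice might be filled in. It is a sketch under two assumptions that the types above do not state: latticeNodes is stored row-major, and it holds exactly latticeLength * latticeLength entries with latticeLength > 0. The helpers rotate and chunks are invented for this example, and the data declarations are repeated so the block compiles on its own.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.List (zipWith5)

data Lattice a = Lattice
    { latticeNodes :: [a]    -- assumed row-major
    , latticeLength :: Int } -- assumed > 0
    deriving (Functor)

data NodeInLattice a = NodeInLattice
    { thisNode :: a
    , xPrev, xNext, yPrev, yNext :: a }
    deriving (Functor)

-- Cyclically shift a list left by k positions (k may be negative).
rotate :: Int -> [a] -> [a]
rotate _ [] = []
rotate k xs = b ++ a
  where (a, b) = splitAt (k `mod` length xs) xs

-- Split a list into consecutive chunks of length n (requires n > 0).
chunks :: Int -> [a] -> [[a]]
chunks _ [] = []
chunks n xs = let (a, b) = splitAt n xs in a : chunks n b

-- Periodic neighbours: x-neighbours come from rotating each row,
-- y-neighbours from rotating the whole list by one row length.
neighboursInLattice :: Lattice a -> Lattice (NodeInLattice a)
neighboursInLattice (Lattice nodes n) =
    Lattice (zipWith5 NodeInLattice nodes xp xn yp yn) n
  where
    rows = chunks n nodes
    xn = concatMap (rotate 1) rows
    xp = concatMap (rotate (-1)) rows
    yn = rotate n nodes
    yp = rotate (-n) nodes
```

With this in place, the neighbour sum for every site is a plain fmap over the result, for example summing xPrev, xNext, yPrev and yNext of each node, with no index arithmetic in the algorithm itself.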
Such an approach has many advantages:
†To pure-functionally update a vector, you need to make a complete copy.
Upvotes: 10