nomicflux

Reputation: 45

Haskell Optimizations for List Processing stymied by Lazy Evaluation

I'm trying to improve the efficiency of the following code. I want to count all occurrences of a symbol before a given point (as part of pattern matching using a Burrows-Wheeler transform). There's some overlap in how I'm counting symbols, but when I try to implement what looks like it should be more efficient code, it turns out to be less efficient, and I'm assuming that lazy evaluation and my poor understanding of it are to blame.

My first attempt at a counting function went like this:

count :: Ord a => [a] -> a -> Int -> Int
count list sym pos = length . filter (== sym) . take pos $ list

Then in the body of the matching function itself:

matching str refCol pattern = match 0 (n - 1) (reverse pattern)
  where n = length str
        refFstOcc sym = length $ takeWhile (/= sym) refCol
        match top bottom [] = bottom - top + 1
        match top bottom (sym : syms) =
          let topCt = count str sym top
              bottomCt = count str sym (bottom + 1)
              middleCt = bottomCt - topCt
              refCt = refFstOcc sym
          in if middleCt > 0
               then match (refCt + topCt) (refCt + bottomCt - 1) syms
               else 0

(Stripped down for brevity - I'm memoizing first occurrences of symbols in refCol through a Map, along with a couple of other details.)

Edit: Sample use would be:

matching "AT$TCTAGT" "$AACGTTTT" "TCG"

which should be 1 (assuming I didn't mistype anything).

Now, I'm recounting everything in the middle between the top and bottom pointers twice, which adds up when I'm counting over a million-character DNA string with only 4 possible characters (and profiling tells me that this is the big bottleneck, taking 48% of my time for bottomCt and around 38% for topCt). For reference, when calculating this for a million-character string and trying to match 50 patterns (each between 1 and 1000 characters long), the program takes about 8.5 to 9.5 seconds to run.

However, if I try to implement the following function:

countBetween :: Ord a => [a] -> a -> Int -> Int -> (Int, Int)
countBetween list sym top bottom =
  let (topList, bottomList) = splitAt top list
      midList = take (bottom - top) bottomList
      getSyms = length . filter (== sym)
  in (getSyms topList, getSyms midList)

(with changes made to the matching function to compensate), the program takes between 18 and 22 seconds to run.

I've also tried passing in a Map which can keep track of previous calls, but that also takes about 20 seconds to run and runs up the memory usage.
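Roughly, the memoizing attempt was a lookup keyed on the symbol and position, something along these lines (a sketch - the exact key and the way the Map is threaded here are only illustrative):

import qualified Data.Map as M

-- Sketch of the memoizing idea: look up (sym, pos) before recounting.
countMemo :: Ord a => [a] -> a -> Int -> M.Map (a, Int) Int -> (Int, M.Map (a, Int) Int)
countMemo list sym pos memo =
  case M.lookup (sym, pos) memo of
    Just n  -> (n, memo)
    Nothing -> let n = length . filter (== sym) . take pos $ list
               in (n, M.insert (sym, pos) n memo)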

Similarly, I've shortened length . filter (== sym) to a fold, but again - 20 seconds for foldr, and 14-15 seconds for foldl.
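For reference, a sketch of such a fold-based count (the strict foldl' shown here is just one way to write it, not necessarily the exact fold I used):

import Data.List (foldl')

-- Count occurrences of sym in the first pos elements with a strict left fold.
countFold :: Eq a => [a] -> a -> Int -> Int
countFold list sym pos =
  foldl' (\acc x -> if x == sym then acc + 1 else acc) 0 (take pos list)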

So what would be a proper Haskell way to optimize this code through rewriting it? (Specifically, I'm looking for something that doesn't involve precomputation - I may not be reusing strings very much - and which explains something of why this is happening).

Edit: More clearly, what I am looking for is the following:

a) Why does this behaviour happen in Haskell? How does lazy evaluation play a role, what optimizations is the compiler making to rewrite the count and countBetween functions, and what other factors may be involved?

b) What is a simple code rewrite which would address this issue so that I don't traverse the lists multiple times? I'm looking specifically for something which addresses that issue, rather than a solution which sidesteps it. If the final answer is that count is the most efficient possible way to write the code, why is that?

Upvotes: 3

Views: 166

Answers (2)

ErikR

Reputation: 52057

I'm not sure lazy evaluation has much to do with the performance of the code. I think the main problem is the use of String - which is a linked list - instead of a more performant string type.

Note that this call in your countBetween function:

  let (topList, bottomList) = splitAt top list

will re-create the linked list corresponding to topList, meaning a lot more allocations.

A Criterion benchmark to compare splitAt versus using take n/drop n may be found here: http://lpaste.net/174526. The splitAt version is about 3 times slower and, of course, has a lot more allocations.
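A minimal benchmark in that spirit (not the one from the lpaste link - the list size and split point below are arbitrary placeholders) looks like this:

import Criterion.Main

xs :: [Int]
xs = [1 .. 1000000]

splitVersion, takeDropVersion :: Int -> Int
splitVersion n    = let (a, b) = splitAt n xs in length a + length b
takeDropVersion n = length (take n xs) + length (drop n xs)

main :: IO ()
main = defaultMain
  [ bench "splitAt"   $ whnf splitVersion 500000
  , bench "take/drop" $ whnf takeDropVersion 500000
  ]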

Even if you don't want to "pre-compute" the counts you can improve matters a great deal by simply switching to either ByteString or Text.

Define:

import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS

countSyms :: Char -> ByteString -> Int -> Int -> Int
countSyms sym str lo hi =
  length [ i | i <- [lo..hi], BS.index str i == sym ]

and then:

countBetween :: ByteString -> Char -> Int -> Int -> (Int,Int)
countBetween str sym top bottom = (a,b)
  where a = countSyms sym str 0 (top-1)
        b = countSyms sym str top (bottom-1)

Also, don't use reverse on large lists - it will reallocate the entire list. Just index into a ByteString / Text in reverse.
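For example (a sketch, assuming the pattern is also a ByteString), you can produce the symbols last-to-first by index:

-- Walk a pattern right-to-left by index instead of building a reversed copy.
reversedSyms :: BS.ByteString -> [Char]
reversedSyms pat = [ BS.index pat i | i <- [BS.length pat - 1, BS.length pat - 2 .. 0] ]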

Memoizing counts may or may not help. It all depends on how it's done.

Upvotes: 1

ErikR

Reputation: 52057

It seems that the main point of the match routine is to transform an interval (bottom, top) into another interval based on the current symbol sym. The formulas are basically:

ref_fst = index of sym in ref_col
  -- defined in an outer scope

match :: Char -> (Int,Int) -> (Int,Int)
match sym (bottom, top) | bottom > top = (bottom, top)  -- the interval is already empty
match sym (bottom, top) =
  let 
    top_count = count of sym in str from index 0 to top
    bot_count = count of sym in str from index 0 to bottom
    mid_count = top_count - bot_count
  in if mid_count > 0
         then (ref_fst + bot_count, ref_fst + top_count)
         else (1,0)  -- the empty interval

And then matching is just a fold over pattern using match with the initial interval (0, n-1).

Both top_count and bot_count can be computed efficiently using a precomputed lookup table, and below is code which does that.

If you run test1 you'll see a trace of how the interval is transformed via each symbol in the pattern.

Note: There may be off-by-1 errors, and I've hard coded ref_fst to be 0 - I'm not sure how this fits into the larger algorithm, but the basic idea should be sound.

Note that once the counts vector has been created there is no need to index into the original string anymore. Therefore, even though I use a ByteString here for the (larger) DNA sequence, it's not crucial, and the mkCounts routine should work just as well if passed a String instead.
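For instance, a list-based variant could build the same running counts with a scanl (a sketch using the same tuple layout as below; entry i holds the counts over the first i characters):

-- Sketch of a list-based equivalent of mkCounts.
mkCountsList :: String -> [(Int,Int,Int,Int)]
mkCountsList = scanl step (0,0,0,0)
  where
    step (a,t,c,g) 'A' = (a+1,t,c,g)
    step (a,t,c,g) 'T' = (a,t+1,c,g)
    step (a,t,c,g) 'C' = (a,t,c+1,g)
    step (a,t,c,g) 'G' = (a,t,c,g+1)
    step x          _  = x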

Code also available at http://lpaste.net/174288

{-# LANGUAGE OverloadedStrings #-}

import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Unboxed.Mutable as UVM
import qualified Data.ByteString.Char8 as BS
import Debug.Trace
import Text.Printf
import Data.List

mkCounts :: BS.ByteString -> UV.Vector (Int,Int,Int,Int)
mkCounts syms = UV.create $ do
  let n = BS.length syms
  v <- UVM.new (n+1)
  let loop x i | i >= n = return x
      loop x i = let s = BS.index syms i
                     (a,t,c,g) = x
                     x' = case s of
                            'A' -> (a+1,t,c,g)
                            'T' -> (a,t+1,c,g)
                            'C' -> (a,t,c+1,g)
                            'G' -> (a,t,c,g+1)
                            _   -> x
                 in do UVM.write v i x
                       loop x' (i+1) 
  x <- loop (0,0,0,0) 0
  UVM.write v n x
  return v

data DNA = A | C | T | G
  deriving (Show)

getter :: DNA -> (Int,Int,Int,Int) -> Int
getter A (a,_,_,_) = a
getter T (_,t,_,_) = t
getter C (_,_,c,_) = c
getter G (_,_,_,g) = g

-- narrow a window
narrow :: Int -> UV.Vector (Int,Int,Int,Int) -> DNA -> (Int,Int) ->  (Int,Int)

narrow refcol counts sym (lo,hi) | trace msg False = undefined
  where msg = printf "-- lo: %d  hi: %d  refcol: %d  sym: %s  top_cnt: %d  bot_count: %d" lo hi refcol (show sym) top_count bot_count
        top_count = getter sym (counts ! (hi+1))
        bot_count = getter sym (counts ! lo)

narrow refcol counts sym (lo,hi) =
  let top_count = getter sym (counts ! (hi+1))
      bot_count = getter sym (counts ! (lo+0))
      mid_count = top_count - bot_count
  in if mid_count > 0
       then ( refcol + bot_count, refcol + top_count-1 )
       else (lo+1,lo)  -- signal an empty window

findFirst :: DNA -> UV.Vector (Int,Int,Int,Int)  -> Int
findFirst sym v =
  let n = UV.length v
      loop i | i >= n = n
      loop i = if getter sym (v ! i) > 0
                 then i
                 else loop (i+1)
  in loop 0

toDNA :: String -> [DNA]
toDNA str = map charToDNA str

charToDNA :: Char -> DNA
charToDNA = go
  where go 'A' = A
        go 'C' = C
        go 'T' = T
        go 'G' = G

dnaToChar A = 'A'
dnaToChar C = 'C'
dnaToChar T = 'T'
dnaToChar G = 'G'

first :: DNA -> BS.ByteString -> Int
first sym str = maybe len id (BS.elemIndex (dnaToChar sym) str)
  where len = BS.length str

test2 = do
 -- matching "AT$TCTAGT" "$AACGTTTT" "TCG"
  let str    = "AT$TCTAGT"
      refcol = "$AACGTTTT"
      syms   = toDNA "TCG"

      -- hard coded for now
      -- may be computed and memoized
      refcol_G = 4
      refcol_C = 3
      refcol_T = 5

      counts = mkCounts str
      w0 = (0, BS.length str -1)

      w1 = narrow refcol_G counts G w0
      w2 = narrow refcol_C counts C w1
      w3 = narrow refcol_T counts T w2

      firsts = (first A refcol, first T refcol, first C refcol, first G refcol)

  putStrLn $ "firsts: " ++ show firsts

  putStrLn $ "w0: " ++ show w0
  putStrLn $ "w1: " ++ show w1
  putStrLn $ "w2: " ++ show w2
  putStrLn $ "w3: " ++ show w3
  let (lo,hi) = w3
      len = if lo <= hi then hi - lo + 1 else 0
  putStrLn $ "length: " ++ show len

matching :: BS.ByteString -> BS.ByteString -> String -> Int
matching  str refcol pattern = 
  let counts = mkCounts str
      n = BS.length str
      syms = toDNA (reverse pattern)
      firsts = (first A refcol, first T refcol, first C refcol, first G refcol)

      go (lo,hi) sym = narrow refcol counts sym (lo,hi)
        where refcol = getter sym firsts

      (lo, hi) = foldl' go (0,n-1) syms
      len = if lo <= hi then hi - lo + 1 else 0
  in len

test3 = matching "AT$TCTAGT" "$AACGTTTT" "TCG"

Upvotes: 1
