Simd

Reputation: 21274

Simple loop with good performance in Haskell

I am just starting out with Haskell and am interested in how to match the performance of simple code I would normally write in C or Python. Consider the following problem.

You are given a long string of 1s and 0s of length n. We want to output, for each substring of length m, the number of 1s in that window. That is, the output consists of n-m+1 values, each between 0 and m inclusive.

In C this is very simple to do in time proportional to n and using extra space (on top of the space needed to store the input) proportional to m bits. You just count the number of 1s in the first window of length m, then maintain two pointers, one to the start of the window and one to the end, and advance them together, incrementing or decrementing the count depending on whether the element entering the window is a 1 and the one leaving is a 0, or the opposite.

Is it possible to get the same theoretical performance in a purely functional way in Haskell?

Some terrible code:

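-- naive: re-sums every window from scratch, O(n*m) time
chunkBits :: Int -> [Int] -> [Int]
chunkBits m = helper
  where helper xs
          | length (take m xs) < m = []    -- fewer than m elements left
          | otherwise = sum (take m xs) : helper (drop 1 xs)

main = print $ chunkBits 5 [0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,0,0,1]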

Upvotes: 2

Views: 296

Answers (3)

Will Ness

Reputation: 71065

At the most basic level, you can re-implement the nice HOF-based (higher-order function) algorithms as hand-written recursive functions that express the loops directly.

Bang patterns mark arguments as strict, so simple values like the running count are computed right away instead of piling up as unevaluated thunks (scanl' takes care of this implicitly, for example). The code also shows that the "pointers" are just names:

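{-# LANGUAGE BangPatterns #-}

-- assumes xs has only 0s and 1s
counts :: Int -> [Int] -> [Int]
counts m xs = g 0 m xs
  where
    -- g: sum the first m elements; ys then points just past the first window
    g !c    0      ys  = h c ys xs
    g !c    _      []  = []                  -- m > |xs|
    g !c    m   (y:ys) = g (c+y) (m-1) ys
    -- h: emit the count, then add the entering element y and subtract the leaving x
    h !c    []     _   = [c]
    h !c (y:ys) (x:xs) = c : h (c+y-x) ys xs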

Testing:

> counts 2 [1,1,0,0,1,1,0,1]
[2,1,0,1,2,1,1]
> counts 3 [1,1,0,0,1,1,1,1]
[2,1,1,2,3,3]
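For comparison, here is a sketch of the higher-order-function formulation that the hand-written loop re-implements. It assumes, as before, a list of only 0s and 1s; `countsHOF` is a hypothetical name. At each step the window count changes by (entering element - leaving element), so a strict left scan over those differences gives the same counts:

```haskell
import Data.List (scanl')

-- Hypothetical HOF counterpart of `counts`; assumes xs holds only 0s and 1s.
-- drop m xs plays the role of the "front" pointer, xs itself the "back" pointer.
countsHOF :: Int -> [Int] -> [Int]
countsHOF m xs
  | length (take m xs) < m = []                      -- m > |xs|: no window
  | otherwise = scanl' (+) (sum (take m xs))         -- count in the first window
                       (zipWith (-) (drop m xs) xs)  -- entering minus leaving
```

Here `zipWith` stops at the shorter list, so it produces exactly n-m differences and the scan yields n-m+1 counts.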

Upvotes: 1

Zeta

Reputation: 105876

C Code

Here is the C code you've described:

int sliding_window(const char * const str, const int n, const int m, int * result){
  const char * back  = str;
  const char * front = str + m;
  int sum = 0;
  int i;

  for(i = 0; i < m; ++i){
     sum += str[i] == '1';
  }

  *result++ = sum;

  for(; i < n; ++i){
    sum += *front++ == '1';
    sum -= *back++  == '1';
    *result++ = sum;
  }
  return n - m + 1;
}

Algorithm

The code above is apparently O(n), since we have n iterations. But let's take a step back and look at the underlying algorithm:

  1. Sum the first m elements. Keep this as sum. O(m)
  2. Our first window has sum 1s. O(1)
  3. Until we've exhausted our original string: O(n)
    1. "Slide" the window. O(1)
      • add 1 to sum if we gain a '1' by sliding O(1)
      • subtract 1 from sum if we lose a '1' by sliding O(1)
    2. Push sum onto the results. O(1)

Since n ≥ m (otherwise there is no window), the O(m) initial sum plus n - m O(1) slides gives O(n) overall.
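Written out, the cost of the three steps above (the sliding loop runs n - m times, as in the C code's second for loop) adds up as:

```latex
T(n, m) = \underbrace{O(m)}_{\text{step 1}} + \underbrace{O(1)}_{\text{step 2}} + \underbrace{(n - m)\cdot O(1)}_{\text{step 3}} = O(n), \qquad 0 < m \le n
```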

Moulding a Haskell variant

That's basically a left scan (scanl) over the list of changes from step 3.1. So all we need is a way to slide the window and produce those changes:

slide :: Int -> [Char] -> [Int]
slide m xs = zipWith f xs (drop m xs)
  where
    f '1' '0' = -1  -- we lose a one
    f '0' '1' =  1  -- we gain a one
    f  _   _  =  0  -- nothing :/

That's O(n), where n is the length of our list.

slidingWindow :: Int -> [Char] -> [Int]
slidingWindow m xs = scanl (+) start (slide m xs)
 where
   start = length (filter (== '1') (take m xs))

That's O(n), same as in C, since both use the same algorithm.
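One way to gain confidence that the scan agrees with the definition in the question is to compare it against a brute-force version on a small input. This is a sketch: `naiveWindow` is a hypothetical helper that recounts every window from scratch in O(n·m), and the comparison assumes m ≤ length xs (for larger m, `slidingWindow` returns a single count while `naiveWindow` returns none).

```haskell
import Data.List (tails)

slide :: Int -> [Char] -> [Int]
slide m xs = zipWith f xs (drop m xs)
  where
    f '1' '0' = -1  -- we lose a one
    f '0' '1' =  1  -- we gain a one
    f  _   _  =  0  -- nothing changes

slidingWindow :: Int -> [Char] -> [Int]
slidingWindow m xs = scanl (+) start (slide m xs)
  where
    start = length (filter (== '1') (take m xs))

-- Hypothetical reference implementation: count the 1s in every length-m window.
naiveWindow :: Int -> [Char] -> [Int]
naiveWindow m xs =
  [ length (filter (== '1') w)
  | w <- map (take m) (take (length xs - m + 1) (tails xs)) ]

main :: IO ()
main = do
  let input = "1100110"
  print (slidingWindow 2 input)                        -- [2,1,0,1,2,1]
  print (slidingWindow 2 input == naiveWindow 2 input) -- True
```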

Caveats

In a real-life application you would usually use Text or ByteString instead of String, since String is a list of Char with considerable overhead. Since your input only contains '1' and '0', you can use ByteString:

import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import           Data.List (scanl')

slide :: Int -> ByteString -> [Int]
slide m xs = BS.zipWith f xs (BS.drop m xs)
  where
    f '1' '0' = -1
    f '0' '1' =  1
    f  _   _  =  0

slidingWindow :: Int -> ByteString -> [Int]
slidingWindow m xs = scanl' (+) start (slide m xs)
 where
   start = BS.count '1' (BS.take m xs)

Upvotes: 5

ErikR

Reputation: 52029

Update

After reading the question more carefully, I noticed that the C program reads its input from an array.

So here is an equivalent "pure" Haskell function that performs the task.

import qualified Data.Vector as V
import Data.List (scanl')

-- count m v: the number of 1s in each length-m window of v
count :: Int -> V.Vector Int -> [Int]
count m v =
  let c0 = V.sum (V.take m v)            -- count for the first window
      n = V.length v
      results = scanl' go c0 [0..n-m-1]
        where go r i = r - (v V.! i) + (v V.! (i+m))  -- drop v!i, add v!(i+m)
  in results

test1 :: IO ()
test1 = let v = V.fromList [0,0,1,1,1,1,1,0,0,0,0]
        in print $ count 3 v

Even though count returns a list, the list is generated lazily. Moreover, if it is consumed by another list operation, it could be optimized via one of the various fusion techniques.
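If you would rather get the counts back as an unboxed vector than as a lazy list, the same recurrence can be written with `Data.Vector.Unboxed`'s `scanl'`, `zipWith`, `drop`, `take` and `sum`. This is a sketch under the same 0/1-input assumption; `countU` is a hypothetical name:

```haskell
import qualified Data.Vector.Unboxed as U

-- Hypothetical unboxed variant: returns the n-m+1 window counts as a vector.
-- Assumes v contains only 0s and 1s and 0 < m <= U.length v.
countU :: Int -> U.Vector Int -> U.Vector Int
countU m v = U.scanl' (+) c0 deltas
  where
    c0     = U.sum (U.take m v)            -- count in the first window
    deltas = U.zipWith (-) (U.drop m v) v  -- entering minus leaving element

main :: IO ()
main = print (U.toList (countU 3 (U.fromList [0,0,1,1,1,1,1,0,0,0,0])))
```

As in the list version, `U.zipWith` truncates to the shorter vector, so `deltas` has n-m elements and the scan returns n-m+1 counts.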

Original Answer

This is a good exercise, but why does it have to be "purely functional" (and what does that mean anyway)?

You can write the C algorithm in Haskell - it's not as terse, but it will generate essentially the same code.

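import qualified Data.Vector.Unboxed.Mutable as V
import System.IO (isEOF)

-- reads '0'/'1' chars from stdin, keeping the last m in a ring buffer v
count :: Int -> IO ()
count m = do
  v <- V.replicate m '0'
  let toInt ch = if ch == '1' then 1 else 0
  let loop c i = do
        eof <- isEOF
        if eof then return () else do   -- stop cleanly at end of input
          ch <- getChar
          oldch <- V.read v i           -- char leaving the window
          let c' = c + toInt ch - toInt oldch
          V.write v i ch
          let i' = mod (i+1) m
          putStrLn $ show c
          loop c' i'
  loop 0 0

main :: IO ()
main = count 3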

(For simplicity this generates n results.)

If you benchmark this, note that you are also measuring the cost of getChar, putStrLn and show, so it might be difficult to make a fair comparison with a C program. However, it runs in O(n) time and uses memory proportional to m (constant with respect to n), which is what I think you're asking for.
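The same mutable ring-buffer loop can also sit behind a pure interface by running it in the ST monad instead of IO. This is a sketch assuming the input is already an in-memory list of 0/1 Ints and m ≥ 1; `countsST` is a hypothetical name. Unlike the IO version it consumes a finite list and returns the whole result at once rather than streaming it, and it only emits a count once a full window has been seen, so it produces n-m+1 results:

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (runST)
import qualified Data.Vector.Unboxed.Mutable as MV

-- Hypothetical pure wrapper around the ring-buffer loop.
-- Assumes m >= 1 and that the input contains only 0s and 1s.
countsST :: Int -> [Int] -> [Int]
countsST m input = runST $ do
  buf <- MV.replicate m (0 :: Int)   -- ring buffer holding the last m inputs
  let loop !_ !_ !_ acc [] = return (reverse acc)
      loop !c !i !seen acc (x:xs) = do
        old <- MV.read buf i         -- element leaving the window
        MV.write buf i x             -- element entering the window
        let c'    = c + x - old
            seen' = seen + 1
            acc'  = if seen' >= m then c' : acc else acc  -- full window seen
        loop c' ((i + 1) `mod` m) seen' acc' xs
  loop 0 0 (0 :: Int) [] input
```

Because `runST` hides the mutation, callers see an ordinary function of type `Int -> [Int] -> [Int]`.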

Upvotes: 1
