dspyz

Reputation: 5524

Slow array access in Haskell?

I'm doing the Car Game problem on Kattis (https://open.kattis.com/problems/cargame). There's a five-second time limit, but on the last test instance my code takes longer than that. I'm fairly sure my approach is right from a big-O standpoint, so now I need to optimize it somehow. I downloaded the test data from: http://challenge.csc.kth.se/2013/challenge-2013.tar.bz2

From profiling, it seems like most of the running time is spent in containsSub, which is nothing more than an array access together with a tail-recursive call. Furthermore, it's only called about 100M times, so why does it take 6.5 seconds to run? (That's 6.5 s on my laptop; I've found Kattis is generally about twice as slow, so probably more like 13 seconds there.) On the statistics page, some of the C++ solutions run in under a second, and even some Python solutions just barely make it under the 5-second bar.
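Taking the quoted figures at face value (about 10^8 calls to containsSub in 6.5 s on the author's laptop), the average cost per call works out as follows, which is a useful number to keep in mind when reading the profile:

```latex
\frac{6.5\ \text{s}}{10^{8}\ \text{calls}} = 6.5 \times 10^{-8}\ \text{s per call} \approx 65\ \text{ns per call}
```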

module Main where

import           Control.Monad
import           Data.Array            (Array, (!), (//))
import qualified Data.Array            as Array
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import           Data.Char
import           Data.List
import           Data.Maybe

main::IO()
main = do
  [n, m] <- readIntsLn
  dictWords <- replicateM n BS.getLine
  let suffixChains = map (\w -> (w, buildChain w)) dictWords
  replicateM_ m $ findChain suffixChains

noWordMsg :: ByteString
noWordMsg = BS.pack "No valid word"

findChain :: [(ByteString, WordChain)] -> IO ()
findChain suffixChains = do
  chrs <- liftM (BS.map toLower) BS.getLine
  BS.putStrLn
    (
      case find (containsSub chrs . snd) suffixChains of
        Nothing -> noWordMsg
        Just (w, _) -> w
    )

readAsInt :: BS.ByteString -> Int
readAsInt = fst . fromJust . BS.readInt

readIntsLn :: IO [Int]
readIntsLn = liftM (map readAsInt . BS.words) BS.getLine

data WordChain = None | Rest (Array Char WordChain)

emptyChars :: WordChain
emptyChars = Rest . Array.listArray ('a', 'z') $ repeat None

buildChain :: ByteString -> WordChain
buildChain s =
  case BS.uncons s of
    Nothing -> emptyChars
    Just (hd, tl) ->
      let wc@(Rest m) = buildChain tl in
      Rest $ m // [(hd, wc)]

containsSub :: ByteString -> WordChain -> Bool
containsSub _ None = False
containsSub s (Rest m) =
  case BS.uncons s of
    Nothing -> True
    Just (hd, tl) -> containsSub tl (m ! hd)
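For comparison, the same "next occurrence" idea that buildChain encodes as a chain of boxed arrays can also be stored as one flat unboxed table per word, so that a lookup is Int arithmetic rather than pointer chasing. This is a sketch of a swapped-in representation, not the code above: nextTable and containsSubT are hypothetical names, and it assumes both the word and the query contain only lowercase ASCII letters. Whether it is actually faster would need to be measured.

```haskell
module Main where

import Data.Array.Unboxed (UArray, listArray, (!))
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import Data.Char (ord)

-- Hypothetical flat "next occurrence" table for one dictionary word w.
-- Entry (i, c) is one past the position of the first occurrence of letter c
-- at or after position i in w, or -1 if there is none. Row (length w) is all -1.
nextTable :: ByteString -> UArray (Int, Int) Int
nextTable w = listArray ((0, 0), (n, 25)) (concat rows)
  where
    n = BS.length w
    -- rows !! i describes the suffix of w starting at i; built right to left.
    rows = scanr step (replicate 26 (-1)) [0 .. n - 1]
    step i next =
      [ if c == ord (BS.index w i) - ord 'a' then i + 1 else nxt
      | (c, nxt) <- zip [0 .. 25] next ]

-- Same question as containsSub: is the lowercase pattern s a subsequence of w?
-- Assumes s contains only 'a'..'z'; any other character is an out-of-bounds index.
containsSubT :: UArray (Int, Int) Int -> ByteString -> Bool
containsSubT t s = go 0 0
  where
    len = BS.length s
    go pos i
      | i >= len = True
      | otherwise =
          let nxt = t ! (pos, ord (BS.index s i) - ord 'a')
           in nxt >= 0 && go nxt (i + 1)

main :: IO ()
main = do
  let t = nextTable (BS.pack "xaybzc")
  print (containsSubT t (BS.pack "abc"))
  print (containsSubT t (BS.pack "cba"))
```

The table has (length w + 1) * 26 entries per word, roughly the same number of slots as the boxed chain, but each slot is an unboxed Int instead of a pointer to another array.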

EDIT: TAKE 2:

I tried building a lazy trie to avoid repeating searches I've already done. For instance, if I've already encountered a triplet beginning with 'a', then in the future I can skip any word that doesn't contain an 'a'. If I've already searched a triplet beginning with 'ab', I can skip any word that doesn't contain 'ab'. And if I've already searched the exact triplet 'abc', I can just return the same result as last time. In theory, this should give a significant speedup. In practice, the running time is identical.
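The lazy trie relies on a lazily built array whose elements are thunks that are evaluated at most once and then shared. A generic version of that memoisation pattern, independent of the Car Game specifics, might look like the sketch below; memoTrie and lookupTrie are hypothetical names, and length stands in for an expensive search.

```haskell
module Main where

import Data.Array (Array, listArray, (!))

-- A lazily built, conceptually infinite trie over lowercase strings. Each node
-- holds the (lazily computed) result for its prefix and one child per letter.
data Trie a = Trie a (Array Char (Trie a))

-- Hypothetical helper: memoise f over strings of the letters 'a'..'z'.
-- Boxed array elements are lazy, so nothing is computed until a lookup demands it.
memoTrie :: (String -> a) -> Trie a
memoTrie f = Trie (f []) (listArray ('a', 'z') [memoTrie (f . (c :)) | c <- ['a' .. 'z']])

-- Walk down one child per character and return the value stored at that node.
lookupTrie :: Trie a -> String -> a
lookupTrie (Trie x _) [] = x
lookupTrie (Trie _ kids) (c : cs) = lookupTrie (kids ! c) cs

main :: IO ()
main = do
  let t = memoTrie length -- stand-in for an expensive search
  print (lookupTrie t "abc")
  print (lookupTrie t "abc") -- reuses the node already evaluated by the first lookup
```

The memoisation only pays off if the same Trie value is actually reused across queries; if it is rebuilt for each query, every lookup starts from scratch and the running time stays the same.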

Furthermore, without the seqs, profiling takes forever and gives bogus results (I can't guess why). With the seqs, profiling says the bulk of the time is spent in forLetter, which is where the array accesses have moved, so again it looks like array access is the slow part.
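One detail worth keeping in mind about the seqs in main: seq only evaluates to weak head normal form (WHNF), so `foldr seq xs xs` forces each (word, chain) pair's constructor but not necessarily the chain inside it. The sketch below illustrates the difference; forceSnds is a hypothetical helper, not part of the code in the question.

```haskell
module Main where

-- Forces every element of a list to WHNF and returns the list;
-- this is what `foldr seq xs xs` does in main.
forceElems :: [a] -> [a]
forceElems xs = foldr seq xs xs

-- Hypothetical helper: forces the second component of each pair to WHNF instead.
forceSnds :: [(a, b)] -> [(a, b)]
forceSnds xs = foldr (\(_, c) rest -> c `seq` rest) xs xs

main :: IO ()
main = do
  -- Prints 1: forcing a pair to WHNF does not touch its components,
  -- so the undefined second component is never evaluated.
  print (length (forceElems [(1 :: Int, undefined :: Int)]))
  -- Prints [2]; forceSnds [(1, undefined)] would throw instead.
  print (map snd (forceSnds [(1 :: Int, 2 :: Int)]))
```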

{-# LANGUAGE TupleSections #-}

module Main where

import           Control.Monad
import           Data.Array            (Array, (!), (//))
import qualified Data.Array            as Array
import qualified Data.Array.Base       as Base
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import           Data.Char
import           Data.Functor
import           Data.Maybe

main::IO()
main = do
  [n, m] <- readIntsLn
  dictWords <- replicateM n BS.getLine
  let suffixChainsL = map (\w -> (w, buildChain w)) dictWords
  let suffixChains = foldr seq suffixChainsL suffixChainsL
  suffixChains `seq` doProbs m suffixChains

noWordMsg :: ByteString
noWordMsg = BS.pack "No valid word"

doProbs :: Int -> [(ByteString, WordChain)] -> IO ()
doProbs m chains = replicateM_ m doProb
  where
    cf = findChain chains
    doProb =
      do
        chrs <- liftM (map toLower) getLine
        BS.putStrLn . fromMaybe noWordMsg $ cf chrs

findChain :: [(ByteString, WordChain)] -> String -> Maybe ByteString
findChain [] = const Nothing
findChain suffixChains@(shd : _) = doFind
  where
    letterMap :: Array Char (String -> Maybe ByteString)
    letterMap =
      Array.listArray ('a','z')
        [findChain (mapMaybe (forLetter hd) suffixChains) | hd <- [0..25]]
    endRes = Just $ fst shd
    doFind :: String -> Maybe ByteString
    doFind [] = endRes
    doFind (hd : tl) = (letterMap ! hd) tl
    forLetter :: Int -> (ByteString, WordChain) -> Maybe (ByteString, WordChain)
    forLetter c (s, WC wc) = (s,) <$> wc `Base.unsafeAt` c

readAsInt :: BS.ByteString -> Int
readAsInt = fst . fromJust . BS.readInt

readIntsLn :: IO [Int]
readIntsLn = liftM (map readAsInt . BS.words) BS.getLine

newtype WordChain = WC (Array Char (Maybe WordChain))

emptyChars :: WordChain
emptyChars = WC . Array.listArray ('a', 'z') $ repeat Nothing

buildChain :: ByteString -> WordChain
buildChain = BS.foldr helper emptyChars
  where
    helper :: Char -> WordChain -> WordChain
    helper hd wc@(WC m) = m `seq` WC (m // [(hd, Just wc)])

Upvotes: 1


Answers (2)

dspyz

Reputation: 5524

After much discussion on the #haskell and #ghc IRC channels, I found that the problem was related to this GHC bug: https://ghc.haskell.org/trac/ghc/ticket/1168

The solution was simply to change the definition of doProbs:

doProbs m chains = cf `seq` replicateM_ m doProb
...

Alternatively, compile with -fno-state-hack.

GHC's state-hack optimization was causing cf (and the associated letterMap) to be recomputed unnecessarily on every call.

So the slowdown had nothing to do with array access.
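A minimal, self-contained sketch of the same fix pattern: build the query-independent structure once, outside the per-query IO action, and force it with seq before the loop starts. buildTable and runQueries are hypothetical names; buildTable stands in for the expensive cf/letterMap, and reading queries from a list instead of stdin is an assumption made to keep the example runnable.

```haskell
module Main where

import Control.Monad (forM_)
import Data.Array (Array, listArray, (!))

-- Hypothetical stand-in for an expensive, query-independent structure (cf / letterMap).
buildTable :: Int -> Array Int Int
buildTable n = listArray (0, n) [i * i | i <- [0 .. n]]

-- 'table' is bound once, outside the per-query action, and forced with seq
-- before the loop, mirroring `doProbs m chains = cf `seq` replicateM_ m doProb`.
-- Compiling with -fno-state-hack is the other workaround mentioned in the answer.
runQueries :: Int -> [Int] -> IO ()
runQueries n queries = table `seq` forM_ queries answer
  where
    table = buildTable n
    answer q = print (table ! q)

main :: IO ()
main = runQueries 1000 [3, 10, 999]
```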

Upvotes: 1

ErikR

Reputation: 52057

The uncons call in containsSub creates a new ByteString on every step. Try speeding it up by tracking the current offset into the string with an index instead, e.g.:

containsSub' :: ByteString -> WordChain -> Bool
containsSub' str wc = go 0 wc
  where len = BS.length str
        go _ None = False
        go i (Rest m) | i >= len = True
                      | otherwise = go (i+1) (m ! BS.index str i)
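A runnable version of the index-based idea, combined with the question's WordChain and buildChain, is sketched below. As a further, unmeasured tweak it uses unsafeIndex from Data.ByteString.Unsafe to skip the per-character bounds check; that is only safe because the `i >= len` guard runs first. containsSubIx is a hypothetical name.

```haskell
module Main where

import Data.Array (Array, listArray, (!), (//))
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Unsafe as BU
import Data.Char (chr)

data WordChain = None | Rest (Array Char WordChain)

emptyChars :: WordChain
emptyChars = Rest (listArray ('a', 'z') (repeat None))

-- Same construction as buildChain in the question.
buildChain :: ByteString -> WordChain
buildChain s = case BS.uncons s of
  Nothing -> emptyChars
  Just (hd, tl) ->
    let wc@(Rest m) = buildChain tl
     in Rest (m // [(hd, wc)])

-- Index-based traversal: no intermediate ByteString per step.
-- unsafeIndex returns a Word8, converted back to Char for the array lookup.
containsSubIx :: ByteString -> WordChain -> Bool
containsSubIx str = go 0
  where
    len = BS.length str
    go _ None = False
    go i (Rest m)
      | i >= len = True
      | otherwise = go (i + 1) (m ! chr (fromIntegral (BU.unsafeIndex str i)))

main :: IO ()
main = do
  let chain = buildChain (BS.pack "xaybzc")
  print (containsSubIx (BS.pack "abc") chain)
  print (containsSubIx (BS.pack "cba") chain)
```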

Upvotes: 2
