Dulguun Otgon

Reputation: 1949

Optimizing code that does heavy mutable-array manipulation

I've been trying to complete this exercise on HackerRank within the time limit, but my Haskell solution below fails test cases 13 to 15 with a timeout.

My Haskell solution

import           Data.Vector(Vector(..),fromList,(!),(//),toList)
import           Data.Vector.Mutable
import qualified Data.Vector as V 
import           Data.ByteString.Lazy.Char8 (ByteString(..))
import qualified Data.ByteString.Lazy.Char8 as L
import Data.ByteString.Builder
import Data.Maybe
import Control.Applicative
import Data.Monoid
import Prelude hiding (length)

readInt' = fst . fromJust . L.readInt 
toB []     = mempty
toB (x:xs) = string8 (show x) <> string8 " " <> toB xs

main = do 
  [firstLine, secondLine] <- L.lines <$> L.getContents
  let [n,k] = map readInt' $ L.words firstLine
  let xs = largestPermutation n k $ fromList $ map readInt' $ Prelude.take n $ L.words secondLine
  L.putStrLn $ toLazyByteString $ toB $ toList xs


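-- n is the value that belongs at index i = l - n; k is the remaining swap budget.
largestPermutation n k v
  | i >= l || k == 0 = v                                -- done, or out of swaps
  | n == x           = largestPermutation (n-1) k v     -- already in place
  | otherwise        = largestPermutation (n-1) (k-1) (replaceOne n x (i+1) (V.modify (\v' -> write v' i n) v))
        where l = V.length v 
              i = l - n
              x = v!i

-- Scan right from index i for the value n and overwrite it with x (completing the swap).
replaceOne n x i v
  | n == h = V.modify (\v' -> write v' i x ) v
  | otherwise = replaceOne n x (i+1) v
    where h = v!i 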

The fastest solution I've found updates two arrays as it goes: one holds the permutation itself, and the other maps each value to its current index for fast lookups.

Better Java solution

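import java.util.Scanner;

public class Solution {
  public static void main(String[] args) {
    Scanner input = new Scanner(System.in);
    int n = input.nextInt();
    int k = input.nextInt();
    int[] a = new int[n];          // the permutation
    int[] index = new int[n + 1];  // index[v] = current position of value v
    for (int i = 0; i < n; i++) {
        a[i] = input.nextInt();
        index[a[i]] = i;
    }
    // Greedily put n, n-1, ... at positions 0, 1, ... while swaps remain.
    for (int i = 0; i < n && k > 0; i++) {
        if (a[i] == n - i) {
            continue; // already in place, no swap needed
        }
        // Swap a[i] with the element at the position currently holding n - i.
        a[index[n - i]] = a[i];
        index[a[i]] = index[n - i];
        a[i] = n - i;
        index[n - i] = i;
        k--;
    }
    for (int i = 0; i < n; i++) {
        System.out.print(a[i] + " ");
    }
  }
}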

My questions are:

  1. What is an elegant and fast implementation of this algorithm in Haskell?
  2. Is there a faster way to solve this problem than the Java solution?
  3. In general, how should I handle heavy array updates in Haskell both elegantly and efficiently?

Upvotes: 2

Views: 257

Answers (2)

behzad.nouri

Reputation: 77941

One optimization for mutable-array code is not to use mutable arrays at all. In particular, the problem you linked to has a right-fold solution.

The idea is to fold over the list, greedily swapping each item with the largest value to its right, and to record the swaps already made in a Data.Map:

import qualified Data.Map as M
import Data.Map (empty, insert)

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = foldr go (\_ _ _ -> []) xs n empty k
    where
    go x run i m k
        -- out of budget to do a swap or no swap necessary
        | k == 0 || y == i = y : run (pred i) m k
        -- make a swap and record the swap made in the map
        | otherwise        = i : run (pred i) (insert i y m) (k - 1)
        where
        -- find the value currently at this position, following earlier swaps
        y = find x
        find v = case M.lookup v m of
            Just a  -> find a
            Nothing -> v

In the code above, run is a function that, given the reverse index i, the current mapping m and the remaining swap budget k, solves the rest of the list. By reverse index I mean the indices of the list counted in the reverse direction: n, n - 1, ..., 1.

The folding function go builds the run function at each step, updating the values of i, m and k that are passed to the next step. At the end, we call the resulting function with the initial parameters i = n, m = empty and the initial swap budget k.
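To see this trick in isolation, here is a minimal sketch (not from the answer) of the same pattern on a simpler problem: foldr builds a function that takes the state as an argument, so the state flows left to right even though the fold is a right fold. The name runningSums is illustrative only.

```haskell
-- foldr builds a function from the initial state to the output list;
-- each step receives the current element x, the function `run` that
-- handles the rest of the list, and the state accumulated so far.
runningSums :: [Int] -> [Int]
runningSums xs = foldr go (\_ -> []) xs 0
  where
    go :: Int -> (Int -> [Int]) -> Int -> [Int]
    go x run acc = let acc' = acc + x in acc' : run acc'
```

For example, runningSums [1, 2, 3] evaluates to [1, 3, 6]. Because each step emits its element before calling run, the output is produced incrementally, so take 3 (runningSums [1 ..]) also terminates.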

The recursive search in find could be eliminated by maintaining a reverse map, but even as is, this performs much faster than the Java code you posted.
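One possible reading of the reverse-map suggestion is sketched below; it is an interpretation, not the answerer's code. Positions are identified by the value that originally sat there; one map (the hypothetical cur) records which value now occupies each moved position, and a reverse map (the hypothetical owner) records which position each displaced value now occupies, so every lookup is a single Data.IntMap access instead of a chain walk. It uses direct recursion rather than foldr.

```haskell
import qualified Data.IntMap.Strict as IM

-- Greedy largest permutation with at most k0 swaps, without chain lookups.
--   cur   : original value x -> value currently at x's original position
--   owner : value v -> original value of the position v currently occupies
-- A missing key means "not moved". i is the value that belongs at the
-- current position (n, n-1, ..., 1).
solveRev :: Int -> Int -> [Int] -> [Int]
solveRev n k0 xs = go xs n k0 IM.empty IM.empty
  where
    go [] _ _ _ _ = []
    go (x : rest) i k cur owner
      | k == 0 || y == i = y : go rest (i - 1) k cur owner
      | otherwise        = i : go rest (i - 1) (k - 1) cur' owner'
      where
        y = IM.findWithDefault x x cur    -- value currently at this position
        w = IM.findWithDefault i i owner  -- position (by key) currently holding i
        -- swap: i comes here and y goes to where i was; the current
        -- position is never visited again, so it needs no entry
        cur'   = IM.insert w y cur
        owner' = IM.insert y w owner
```

For example, solveRev 5 1 [4, 2, 3, 5, 1] gives [5, 2, 3, 4, 1]. IntMap operations are O(min(n, W)), so this removes the chain walks but still does not reach the constant-time access of an array.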


Edit: The above solution still pays a logarithmic cost for each tree access. Here is an alternative solution using a mutable STUArray and the monadic fold foldM_, which in fact performs faster than the one above:

import Control.Monad.ST (ST)
import Control.Monad (foldM_)
import Data.Array.Unboxed (UArray, elems, listArray, array)
import Data.Array.ST (STUArray, readArray, writeArray, runSTUArray, thaw)

-- the first 3 args are fixed for the whole run; foldM_ receives swap partially applied to them
swap :: STUArray s Int Int -> STUArray s Int Int -> Int
     -> Int -> Int -> ST s Int
swap   _   _ _ 0 _ = return 0  -- out of budget to make a swap
swap arr rev n k i = do
    xi <- readArray arr i
    if xi + i == n + 1
    then return k -- no swap necessary
    else do -- make a swap, and reduce budget
        j <- readArray rev (n + 1 - i)
        writeArray rev xi j
        writeArray arr j  xi
        writeArray arr i (n + 1 - i)
        return $ pred k

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = elems $ runSTUArray $ do
    arr <- thaw (listArray (1, n) xs :: UArray Int Int)
    rev <- thaw (array (1, n) (zip xs [1..]) :: UArray Int Int)
    foldM_ (swap arr rev n) k [1..n]
    return arr
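Whichever solve is used, reading and printing about n integers can be a noticeable part of the run time, so it may be worth doing with a strict ByteString and a Builder. The sketch below is not part of the answer; runWith is a hypothetical helper that accepts any solver with the same type as solve above, and it assumes the input is n and k followed by the n values, all whitespace-separated.

```haskell
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Builder as BB
import Data.Char (isSpace)
import Data.List (intersperse, unfoldr)
import System.IO (stdout)

-- Parse every whitespace-separated integer in the input.
readInts :: B.ByteString -> [Int]
readInts = unfoldr (B.readInt . B.dropWhile isSpace)

-- Render the result as space-separated integers followed by a newline.
render :: [Int] -> BB.Builder
render xs = mconcat (intersperse (BB.char7 ' ') (map BB.intDec xs)) <> BB.char7 '\n'

-- Hypothetical driver: read "n k" and the permutation, run the solver, print.
runWith :: (Int -> Int -> [Int] -> [Int]) -> IO ()
runWith solver = do
  input <- B.getContents
  case readInts input of
    (n : k : rest) -> BB.hPutBuilder stdout (render (solver n k (take n rest)))
    _              -> error "expected n, k and the permutation"
```

With the answer's solve in scope, the program would then be main = runWith solve.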

Upvotes: 7

hilberts_drinking_problem

Reputation: 11602

Not exactly an answer to #2, but there is a left-fold solution that needs to hold at most about K values in memory at a time.

Because the problem deals with permutations, we know that 1 through N will appear in the output. If K > 0, at least the first K terms are going to be N, N-1, ..., N-K+1, because K swaps are enough to put each of those values in place. In addition, some elements may already be in their optimal position and cost no swap (for a random input, we expect about K/N of the first K positions to be correct already).

This suggests an algorithm:

Initialize a map (dictionary) and scan the input xs as zip xs [n, n-1..]. For every pair (x, i) with x /= i, we 'decrement' K and update our dictionary so that dct[i] = x. This procedure stops when K == 0 (out of swaps) or when we run out of input (in which case the output is {N, N-1, ..., 1}).
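The scan in the previous paragraph could be written in Haskell roughly as follows. This is a sketch, assuming Data.IntMap stands in for the dictionary; firstPass is a hypothetical name, and it covers one pass of the scan only, not the output or the cycle handling.

```haskell
import qualified Data.IntMap.Strict as IM

-- One pass of the scan over zip xs [n, n-1 ..].
-- k is the remaining swap budget; dct maps each placed target i to the
-- value x that originally sat at its position.
-- Returns the emitted prefix, the updated dictionary and the unread pairs.
firstPass :: Int -> IM.IntMap Int -> [(Int, Int)]
          -> ([Int], IM.IntMap Int, [(Int, Int)])
firstPass 0 dct rest = ([], dct, rest)   -- out of swaps
firstPass _ dct []   = ([], dct, [])     -- out of input
firstPass k dct ((x, i) : rest) = (i : out, dct', rest')
  where
    (k', d) | x == i    = (k, dct)                   -- already in place: free
            | otherwise = (k - 1, IM.insert i x dct)  -- spend one swap
    (out, dct', rest') = firstPass k' d rest
```

For example, firstPass 1 IM.empty (zip [4, 2, 3, 5, 1] [5, 4 ..]) returns ([5], IM.fromList [(5, 4)], [(2, 4), (3, 3), (5, 2), (1, 1)]).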

Next, for each remaining x in xs, we print dct[x] if x is in our dictionary, and x otherwise.

The above algorithm can fail to produce an optimal permutation only if our dictionary contains a cycle. In that case, we charged |cycle| swaps for rearranging the elements of the cycle among themselves, but |cycle| - 1 swaps are enough: once all but one of them are in place, the last one is in place too. So we can always save one swap per cycle (i.e. increment K).
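A sketch of this cycle handling in Haskell follows; it is not the answer's code. It assumes the dictionary is a Data.IntMap in which no value occurs twice (each value came from a distinct input position), so following dct from any key either leaves the map or returns to that key. The names cycleFrom, removeCycles and lookupFinal are hypothetical.

```haskell
import Data.List (foldl')
import qualified Data.IntMap.Strict as IM

-- Follow dct from key k. Returns the members if the walk returns to k
-- (a cycle), or Nothing if it leaves the map (a chain).
-- Assumes no value occurs twice in dct, so no other loop is possible.
cycleFrom :: IM.IntMap Int -> Int -> Maybe [Int]
cycleFrom dct k = walk [k] (dct IM.! k)
  where
    walk acc v
      | v == k    = Just acc
      | otherwise = case IM.lookup v dct of
          Nothing -> Nothing
          Just v' -> walk (v : acc) v'

-- Count the cycles and delete their entries; each cycle refunds one swap.
removeCycles :: IM.IntMap Int -> (Int, IM.IntMap Int)
removeCycles dct0 = foldl' step (0, dct0) (IM.keys dct0)
  where
    step (c, m) k
      | not (IM.member k m) = (c, m)  -- already deleted with an earlier cycle
      | otherwise = case cycleFrom m k of
          Just members -> (c + 1, foldr IM.delete m members)
          Nothing      -> (c, m)

-- Resolve a remaining input value by following its chain to the end.
-- Only call this after removeCycles, otherwise a cycle would loop forever.
lookupFinal :: IM.IntMap Int -> Int -> Int
lookupFinal dct x = maybe x (lookupFinal dct) (IM.lookup x dct)
```

For example, with input 2 3 1 and K = 2 the scan records dct = {3: 2, 2: 3}, which is a 2-cycle, so removeCycles refunds one swap; indeed a single swap already turns 2 3 1 into 3 2 1.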

This gives the following memory-efficient algorithm.

Step 0: Read N and K.

Step 1: Read the input, updating the dictionary as described above, until K swaps have been spent or the input is exhausted; this outputs {N, N-1, ..., N-K-E+1}. Then set N <- N - K - E and K <- 0, where E is the number of elements X equal to N - (index of X), i.e. those that were already in place.

Step 2: Remove and count the cycles in the dictionary. If there are any, set K <- (number of cycles) and go back to step 1; otherwise go to step 3. We can make this step more efficient by optimizing the dictionary.

Step 3: Output the rest of the input, replacing each value that is in the dictionary by its entry.

The following Python code implements the idea; it can be made quite fast if better cycle detection is used. Ideally the data would also be read in chunks, unlike below.

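from collections import deque

n, t = map(int, input().split())

xs = deque(map(int, input().split()))

dct = {}  # dct[i] = value that originally sat where i was placed

cycles = True
while cycles:
    # Step 1: place n, n-1, ... greedily while swaps remain
    while t > 0 and xs:
        x = xs.popleft()
        if x != n:
            dct[n] = x
            t -= 1
        print(n, end=" ")
        n -= 1

    # Step 2: refund one swap per cycle and compress chains
    cycles = False
    for k, v in list(dct.items()):  # iterate over a copy: entries are deleted below
        if k not in dct:
            continue  # already removed as part of an earlier cycle
        visited = set()
        cycle = False
        while v in dct:
            if v in visited:
                cycle = True
                break
            visited.add(v)
            v = dct[v]
        if cycle:
            cycles = True
            for i in visited:
                del dct[i]
            t += 1
        else:
            dct[k] = v  # point k directly at the end of its chain

# Step 3: copy the rest, replacing values that were moved
while xs:
    x = xs.popleft()
    print(dct.get(x, x), end=" ")
print()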

Upvotes: 1
