Reputation: 1949
I've been trying to solve this exercise on HackerRank within the time limit, but my Haskell solution below fails test cases 13 to 15 with a timeout.
import Data.Vector(Vector(..),fromList,(!),(//),toList)
import Data.Vector.Mutable
import qualified Data.Vector as V
import Data.ByteString.Lazy.Char8 (ByteString(..))
import qualified Data.ByteString.Lazy.Char8 as L
import Data.ByteString.Lazy.Builder
import Data.Maybe
import Control.Applicative
import Data.Monoid
import Prelude hiding (length)

readInt' = fst . fromJust . L.readInt

toB [] = mempty
toB (x:xs) = string8 (show x) <> string8 " " <> toB xs

main = do
  [firstLine, secondLine] <- L.lines <$> L.getContents
  let [n,k] = map readInt' $ L.words firstLine
  let xs = largestPermutation n k $ fromList $ map readInt' $ Prelude.take n $ L.words secondLine
  L.putStrLn $ toLazyByteString $ toB $ toList xs

largestPermutation n k v
  | i >= l || k == 0 = v
  | n == x = largestPermutation (n-1) k v
  | otherwise = largestPermutation (n-1) (k-1) (replaceOne n x (i+1) (V.modify (\v' -> write v' i n) v))
  where l = V.length v
        i = l - n
        x = v!i

replaceOne n x i v
  | n == h = V.modify (\v' -> write v' i x ) v
  | otherwise = replaceOne n x (i+1) v
  where h = v!i

The fastest solution I've found (in Java) keeps two arrays in sync: one is the permutation itself, and the other maps each value to its current index for constant-time lookups.
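
import java.util.Scanner;

// The class name is arbitrary; it was not part of the original snippet.
public class Solution {
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);
        int n = input.nextInt();
        int k = input.nextInt();
        int[] a = new int[n];
        int[] index = new int[n + 1]; // index[v] = current position of value v
        for (int i = 0; i < n; i++) {
            a[i] = input.nextInt();
            index[a[i]] = i;
        }
        for (int i = 0; i < n && k > 0; i++) {
            if (a[i] == n - i) {
                continue; // position i already holds the largest remaining value
            }
            // swap a[i] with the position currently holding n - i
            a[index[n - i]] = a[i];
            index[a[i]] = index[n - i];
            a[i] = n - i;
            index[n - i] = i;
            k--;
        }
        for (int i = 0; i < n; i++) {
            System.out.print(a[i] + " ");
        }
    }
}
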
Upvotes: 2
Views: 257
Reputation: 77941
One optimization you can make to mutable arrays is not to use them at all. In particular, the problem you have linked to has a right fold solution.
The idea is to fold over the list, greedily swapping each item with the largest value to its right, while recording the swaps already made in a Data.Map:
import qualified Data.Map as M
import Data.Map (empty, insert)

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = foldr go (\_ _ _ -> []) xs n empty k
  where
    go x run i m k
      -- out of budget to do a swap or no swap necessary
      | k == 0 || y == i = y : run (pred i) m k
      -- make a swap and record the swap made in the map
      | otherwise = i : run (pred i) (insert i y m) (k - 1)
      where
        -- find the value current position is swapped with
        y = find x
        find k = case M.lookup k m of
          Just a  -> find a
          Nothing -> k

In the code above, run is a function which, given the reverse index i, the current mapping m and the remaining swap budget k, solves the rest of the list onwards. By reverse index I mean indices of the list in the reverse direction: n, n - 1, ..., 1.
The folding function go builds the run function at each step by updating the values of i, m and k which are passed to the next step. At the end we call this function with the initial parameters i = n, m = empty and the initial swap budget k.
The recursive search in find can be optimized out by maintaining a reverse map, but this already performs much faster than the Java code you have posted.
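For readers less used to folds, the same bookkeeping can be sketched as a plain loop. This is only a rough Python 3 transliteration of the idea, not the Haskell code itself: the names largest_permutation and moved are made up for illustration, and moved plays the role of the map m.

```python
def largest_permutation(n, k, xs):
    """One left-to-right pass; `moved` (hypothetical name) stands in for the map m."""
    moved = {}  # moved[i] = value that was displaced when i was swapped forward
    out = []

    def find(v):
        # Follow recorded swaps until we reach a value that was never displaced.
        while v in moved:
            v = moved[v]
        return v

    i = n  # reverse index: the value we would like at the current position
    for x in xs:
        y = find(x)  # the value currently sitting at this position
        if k == 0 or y == i:
            out.append(y)  # out of budget, or already optimal here
        else:
            out.append(i)  # swap i in and remember where y went
            moved[i] = y
            k -= 1
        i -= 1
    return out


print(largest_permutation(5, 1, [4, 2, 3, 5, 1]))
```

Each chain followed by find strictly decreases, because a displaced value is always smaller than the value that replaced it, so the lookup always terminates.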
Edit: The above solution still pays a logarithmic cost for tree access. Here is an alternative solution using a mutable STUArray and the monadic fold foldM_, which in fact performs faster than the above:
import Control.Monad.ST (ST)
import Control.Monad (foldM_)
import Data.Array.Unboxed (UArray, elems, listArray, array)
import Data.Array.ST (STUArray, readArray, writeArray, runSTUArray, thaw)

-- first 3 args are the scope, which will be curried
swap :: STUArray s Int Int -> STUArray s Int Int -> Int
     -> Int -> Int -> ST s Int
swap _   _   _ 0 _ = return 0  -- out of budget to make a swap
swap arr rev n k i = do
  xi <- readArray arr i
  if xi + i == n + 1
    then return k  -- no swap necessary
    else do        -- make a swap, and reduce budget
      j <- readArray rev (n + 1 - i)
      writeArray rev xi j
      writeArray arr j  xi
      writeArray arr i (n + 1 - i)
      return $ pred k

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = elems $ runSTUArray $ do
  arr <- thaw (listArray (1, n) xs :: UArray Int Int)
  rev <- thaw (array (1, n) (zip xs [1..]) :: UArray Int Int)
  foldM_ (swap arr rev n) k [1..n]
  return arr

Upvotes: 7
Reputation: 11602
Not exactly an answer to #2, but there is a left fold solution that requires loading at most ~K values in memory at a time.
Because the problem deals with permutations, we know that 1 through N will appear in the output. If K > 0, at least the first K terms are going to be N, N-1, ..., N-K+1, because we can afford at least K swaps. In addition, we expect roughly K/N of those elements to be in their optimal position already.
This suggests an algorithm:
Initialize a map / dictionary and scan the input xs as zip xs [n, n-1..]. For every (x, i), if x /= i, we 'decrement' K and update our dictionary s.t. dct[i] = x. This procedure terminates when K == 0 (out of swaps) or we run out of input (in which case we can output {N, N-1, ... 1}).
Next, if we have any more x <- xs, we look at each one and print x if x is not in our dictionary, or dct[x] otherwise.
The above algorithm can fail to produce an optimal permutation only if our dictionary contains a cycle. In that case, we moved around elements with absolute value >= K using |cycle| swaps. But this means that we moved one element to its original position! A cycle of length c can be resolved with only c - 1 swaps, so we can always save a swap on every cycle (i.e. increment K).
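A small Python 3 sketch can make the saved swap concrete. It compares the number of out-of-place positions (what the dictionary pass charges) with the number of swaps actually performed when the swaps are carried out with an index lookup. The helper names mismatches and greedy_swaps are made up for this illustration.

```python
def mismatches(xs):
    """Positions whose value differs from the descending target n, n-1, ..., 1."""
    n = len(xs)
    return sum(1 for i, x in enumerate(xs) if x != n - i)


def greedy_swaps(xs):
    """Swaps actually performed when each position is fixed in turn."""
    a = list(xs)
    n = len(a)
    index = {x: i for i, x in enumerate(a)}  # value -> current position
    swaps = 0
    for i in range(n):
        want = n - i
        if a[i] == want:
            continue  # an earlier swap may already have fixed this position
        j = index[want]
        a[i], a[j] = a[j], a[i]
        index[a[j]] = j
        index[want] = i
        swaps += 1
    return swaps


# [2, 3, 1] has two out-of-place positions, but they form one cycle.
print(mismatches([2, 3, 1]), greedy_swaps([2, 3, 1]))
```

For [2, 3, 1] the two misplaced values 2 and 3 form a single cycle, so one real swap fixes both positions and one unit of budget is left over.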
Finally, this gives the memory-efficient algorithm.
Step 0: Get N, K.
Step 1: Read the input permutation and output {N, N-1, ... N-K-E+1}, then set N <- N - K - E, K <- 0, and update dict as per above, where E = number of elements X equal to N - (index of X).
Step 2: Remove and count cycles from dict; let cycles = number of cycles; if cycles > 0, let K <- cycles and go to step 1, else go to step 3. We can make this step more efficient by optimizing the dict.
Step 3: Output the rest of the input as is.
The following Python 2 code implements the idea and can be made quite fast if better cycle detection is used. Of course, the data should be read in chunks, unlike below.

from collections import deque

n, t = map(int, raw_input().split())
xs = deque(map(int, raw_input().split()))
dct = {}

cycles = True
while cycles:
    while t > 0 and xs:
        x = xs.popleft()
        if x != n:
            dct[n] = x
            t -= 1
        print n,
        n -= 1

    cycles = False
    for k, v in dct.items():
        visited = set()
        cycle = False
        while v in dct:
            if v in visited:
                cycle = True
                break
            visited.add(v)
            v, buf = dct[v], v
            dct[buf] = v
        if cycle:
            cycles = True
            for i in visited:
                del dct[i]
            t += 1
        else:
            dct[k] = v

while xs:
    x = xs.popleft()
    print dct.get(x, x),

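As for the better cycle detection mentioned above, one possible approach is sketched below in Python 3. It assumes the dictionary is injective (no two keys share a value), which holds here because the input is a permutation. Under that assumption a single shared seen set is enough, and each key is visited once. The function name find_cycles is hypothetical.

```python
def find_cycles(dct):
    """Return the cycles of an injective dict, each as a list of keys."""
    seen = set()
    cycles = []
    for start in dct:
        if start in seen:
            continue
        path = []
        node = start
        # Walk the chain until it leaves the dict or reaches a visited key.
        while node in dct and node not in seen:
            seen.add(node)
            path.append(node)
            node = dct[node]
        # With an injective dict, a walk closes a cycle only by returning to its start.
        if node == start:
            cycles.append(path)
    return cycles


print(find_cycles({5: 4, 4: 1, 3: 2, 2: 3}))
```

Each cycle found this way could then be deleted from the dictionary and credited back as one swap, as in step 2.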
Upvotes: 1