slaw

Reputation: 6899

Efficient Way to Repeatedly Split Large NumPy Array and Record Middle

I have a large NumPy array nodes = np.arange(100_000_000) and I need to rearrange this array by:

  1. Recording and then removing the middle value in the array
  2. Split the array into the left half and right half
  3. Repeat Steps 1-2 for each half
  4. Stop when all values are exhausted

So, for a smaller input example nodes = np.arange(10), the output would be:

[5 2 8 1 4 7 9 0 3 6]

(5 is the middle of the full range; the halves [0 1 2 3 4] and [6 7 8 9] then contribute their middles 2 and 8, and so on level by level until every value has been recorded.)

This was accomplished by naively doing:

import numpy as np

def split(node, out):
    # Record the middle value, then return the left and right halves
    mid = len(node) // 2
    out.append(node[mid])
    return node[:mid], node[mid+1:]


def reorder(a):
    nodes = [a.tolist()]
    out = []

    # Breadth-first pass: split every current piece and keep the non-empty halves
    while nodes:
        tmp = []
        for node in nodes:
            for n in split(node, out):
                if n:
                    tmp.append(n)
        nodes = tmp

    return np.array(out)

if __name__ == "__main__":
    nodes = np.arange(10)
    print(reorder(nodes))

However, this is way too slow for nodes = np.arange(100_000_000) and so I am looking for a much faster solution.

Upvotes: 1

Views: 719

Answers (2)

Jérôme Richard

Reputation: 50836

You can vectorize your function with Numpy by working on groups of slices.

Here is an implementation:

import numpy as np

# Similar to [e for tmp in zip(a, b) for e in tmp],
# but on Numpy arrays and much faster
def interleave(a, b):
    assert len(a) == len(b)
    return np.column_stack((a, b)).reshape(len(a) * 2)

# n is the length of the input range (len(a) in your example)
def fast_reorder(n):
    if n == 0:
        return np.empty(0, dtype=np.int32)

    startSlices = np.array([0], dtype=np.int32)
    endSlices = np.array([n], dtype=np.int32)
    allMidSlices = np.empty(n, dtype=np.int32)  # Similar to "out" in your implementation
    midInsertCount = 0                               # Actual size of allMidSlices

    # Generate a bunch of middle values as long as there are valid slices to split
    while midInsertCount < n:
        # Generate the new mid/left/right slices
        midSlices = (endSlices + startSlices) // 2

        # Computing the next slices is not needed for the last step
        if midInsertCount + len(midSlices) < n:
            # Generate the next slices (some of which may be invalid)
            newStartSlices = interleave(startSlices, midSlices+1)
            newEndSlices = interleave(midSlices, endSlices)

            # Discard invalid slices
            isValidSlices = newStartSlices < newEndSlices
            startSlices = newStartSlices[isValidSlices]
            endSlices = newEndSlices[isValidSlices]

        # Fast appending
        allMidSlices[midInsertCount:midInsertCount+len(midSlices)] = midSlices
        midInsertCount += len(midSlices)

    return allMidSlices[0:midInsertCount]
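
For example, a quick check against the naive reorder from the question on a small input (assuming both functions are defined in the same session) could look like this:

nodes = np.arange(10)
print(fast_reorder(len(nodes)))                                   # [5 2 8 1 4 7 9 0 3 6]
print(np.array_equal(fast_reorder(len(nodes)), reorder(nodes)))   # True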

On my machine, this is 89 times faster than your scalar implementation with the input np.arange(100_000_000), dropping from 2min35 to 1.75s. It also consumes far less memory (roughly 3 to 4 times less). Note that if you want faster code, then you probably need to use a native language like C or C++.

Upvotes: 1

David Oldford

Reputation: 1175

Edit: The question has been updated to use a much smaller input array, so I leave the text below for historical reasons. The original size was likely a typo, but we often get accustomed to computers handling insanely large numbers, and when memory is involved they can be a real problem.

There is already a numpy-based solution submitted by someone else that I think fits the bill.

Your original code requires an insane amount of RAM just to hold 100 billion 64-bit integers. Do you have 800GB of RAM? You then convert the numpy array to a list, which will be substantially larger than the array: each packed 64-bit int in the numpy array becomes a much less memory-efficient Python int object, and the list holds a pointer to every one of them. Then you make a lot of slices of the list, which do not duplicate the data but do duplicate the pointers to the data and use even more RAM.

You also append all the result values to a list one value at a time. Lists are generally very fast for adding items, but at such an extreme size this is not only slow: lists over-allocate as they fill up, so you end up reserving more RAM than you need and doing many allocations and likely copies.

What kind of machine are you running this on? There are ways to improve your code, but unless you're running it on a supercomputer I don't know that you will ever finish that calculation. I only (only?) have 32GB of RAM, and I'm not going to even try to create a 100-billion-element int64 numpy array, as I don't want to use up SSD write life on a mass of virtual memory.
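
To get a rough feel for the gap, here's a small illustrative comparison (exact numbers depend on the Python build and version) between the footprint of a numpy array and that of the equivalent Python list:

import sys
import numpy as np

n = 1_000_000
arr = np.arange(n)        # packed 64-bit integers, 8 bytes each
lst = arr.tolist()        # one full Python int object per element

array_bytes = arr.nbytes
# The list stores a pointer per element, plus each int object's own overhead
list_bytes = sys.getsizeof(lst) + sum(sys.getsizeof(x) for x in lst)

print(f"numpy array: {array_bytes / 1e6:.1f} MB")
print(f"python list: {list_bytes / 1e6:.1f} MB")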

As for improving your code: stick to numpy arrays and don't convert to a Python list, since that greatly increases the RAM you need. Preallocate a numpy array to put the answer in. Then you need a new algorithm. Anything recursive, or recursive-like (i.e. a loop splitting the input), requires tracking a lot of state; your nodes list is going to be extraordinarily gigantic and, again, use a lot of RAM. You could use len(a) to mark values that have been removed from your list and scan through the entire array each time to figure out what to do next, but that saves RAM in exchange for a tremendous amount of searching through a gigantic array. I feel like there is an algorithm to cut numbers from each end and place them in the output while just tracking the beginning and end, but I haven't figured it out, at least not yet.
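
As one possible sketch of that direction (purely illustrative; the function name is made up for the example): keep only (start, end) index pairs in a queue and write each middle into a preallocated numpy array, so the element data is never copied and no Python list of values is built:

from collections import deque
import numpy as np

def reorder_indices(n):
    out = np.empty(n, dtype=np.int64)   # preallocated result
    if n == 0:
        return out
    pos = 0
    queue = deque([(0, n)])             # half-open [start, end) ranges
    while queue:
        start, end = queue.popleft()
        mid = (start + end) // 2
        out[pos] = mid                  # for np.arange(n) the value equals its index
        pos += 1
        if start < mid:                 # left half still has elements
            queue.append((start, mid))
        if mid + 1 < end:               # right half still has elements
            queue.append((mid + 1, end))
    return out

The queue still grows to a large number of index pairs at the widest level, so this mainly illustrates the index-tracking idea; the vectorized answer above keeps the same bookkeeping in compact numpy arrays instead.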

I also think there is a simpler algorithm where you just track the number of splits you've done instead of building a giant list of slices and keeping it all in memory. Take the middle of the left half, then the middle of the right, then count up one; when you take the middle of the left half's left half you know you have to jump to the right half, and since the count is one you jump over to the original right half's left half, and so on. Based on the depth into the halves and the length of the input, you should be able to jump around without scanning or tracking all of those slices, though I haven't been able to dedicate much time to thinking this through.

With a problem of this nature, if you really need to push the limits, you should consider using C/C++ so you can be as efficient as possible with RAM usage, and because you're doing an insane number of tiny operations, which doesn't map well to Python performance.

Upvotes: 1
