Fnord

Reputation: 5895

Grouping elements of a numpy array using an array of group counts

Given two arrays, one representing a stream of data, and another representing group counts, such as:

import numpy as np

# given group counts:       3            4           3         2
# given flattened data:[ 0, 1, 2,   3, 4, 5, 6,   7, 8, 9,   10, 11 ]

group_counts = np.array([3,4,3,2])
data = np.arange(group_counts.sum()) # placeholder data, real life application will be a very large array

I want to generate matrices based on the group counts for the streamed data, such as:

target_count = 3 # I want to make a matrix of all data items whose group count == target_count
# Expected result
# [[ 0  1  2]
#  [ 7  8  9]]

To do this I wrote the following:

# Find all matches
match = np.where(group_counts == target_count)[0]
i1 = np.cumsum(group_counts)[match] # end index for slicing (exclusive)
i0 = i1 - group_counts[match] # start index for slicing

# Prep the blank matrix and fill with results
matched_matrix = np.empty((match.size, target_count), dtype=data.dtype)

# Is it possible to get rid of this loop?
for i in range(match.size):
    matched_matrix[i] = data[i0[i]:i1[i]]

matched_matrix
# Result: array([[ 0,  1,  2],
#                [ 7,  8,  9]])

This works, but I would like to get rid of the loop and I can't figure out how.

Doing some research I did find numpy.split and numpy.array_split:

match = np.where(group_counts == target_count)[0]
match = np.array(np.split(data, np.cumsum(group_counts)[:-1]), dtype=object)[match]
# Result: array([array([0, 1, 2]), array([7, 8, 9])], dtype=object)

But numpy.split returns a list of arrays, and wrapping that list in np.array gives an array of dtype=object that I still have to convert.

Is there an elegant way to produce the desired result without a loop?

Upvotes: 0

Views: 1522

Answers (1)

akuiper

Reputation: 214967

You can repeat group_counts so it has the same size as data, then filter and reshape based on the target:

group_counts = np.array([3,4,3,2])
data = np.arange(group_counts.sum())

target = 3
data[np.repeat(group_counts, group_counts) == target].reshape(-1, target)

#array([[0, 1, 2],
#       [7, 8, 9]])
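
To make the intermediate step visible, here is a minimal sketch of the same approach wrapped in a function. The name rows_for_group_size is a hypothetical helper introduced only for illustration, and the sketch assumes group_counts.sum() equals data.size, as in the question:

```python
import numpy as np

def rows_for_group_size(data, group_counts, target):
    """Hypothetical helper: stack every group of length `target` as one row."""
    # One label per data element: each count is repeated `count` times,
    # so every element is tagged with the size of the group it belongs to.
    labels = np.repeat(group_counts, group_counts)
    # Keep only elements from groups of the target size, one row per group.
    return data[labels == target].reshape(-1, target)

group_counts = np.array([3, 4, 3, 2])
data = np.arange(group_counts.sum())

labels = np.repeat(group_counts, group_counts)
print(labels)  # [3 3 3 4 4 4 4 3 3 3 2 2]
print(rows_for_group_size(data, group_counts, 3))
```

The boolean mask keeps elements in their original order, and every selected group contributes exactly target consecutive elements, so the reshape lines the rows up with the groups even when two matching groups sit next to each other.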

Upvotes: 1
