Ryan

Reputation: 8241

Partition training data by class in NumPy

I have a 50000 x 784 data matrix (50000 samples and 784 features) and the corresponding 50000 x 1 class vector (classes are integers 0-9). I'm looking for an efficient way to group the data matrix into 10 data matrices and class vectors that each have only the data for a particular class 0-9.

I can't seem to find an elegant way to do this, aside from just looping through the data matrix and constructing the 10 other matrices that way.

Does anyone know if there is a clean way to do this with something in scipy, numpy, or sklearn?

Upvotes: 4

Views: 1321

Answers (3)

mprat

Reputation: 2471

If your data and labels matrices are in numpy format, you can do:

data_class_3 = data[labels == 3, :]

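If they aren't, convert them first. Also flatten the labels: with a 50000 x 1 column vector, labels == 3 is a 2-D mask and the row selection fails.

import numpy as np
data = np.asarray(data)
labels = np.asarray(labels).ravel()  # shape (50000, 1) -> (50000,)
data_class_3 = data[labels == 3, :]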

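You can loop over all labels to do this automatically. Keep the result as a plain list, because the classes usually have different sizes and recent NumPy versions raise an error when asked to build one array from pieces of different lengths:

import numpy as np
# One (n_i, 784) array per class i; a list, since n_i differs between classes
split_classes = [data[labels == i, :] for i in range(10)]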
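A minimal sketch of the same masking idea, generalized so it does not assume the labels are exactly 0-9 and accepts the question's 50000 x 1 column vector. The helper name split_by_label and the toy data are hypothetical, not part of NumPy or the original answer.

```python
import numpy as np

def split_by_label(data, labels):
    """Map each distinct label to the rows of `data` carrying that label.

    Assumes one sample per row and one label per sample, given either
    as shape (n,) or as an (n, 1) column vector.
    """
    data = np.asarray(data)
    labels = np.asarray(labels).ravel()  # (n, 1) -> (n,) so the mask selects rows
    if labels.shape[0] != data.shape[0]:
        raise ValueError("data and labels must have the same number of rows")
    # One boolean mask per label that actually occurs in `labels`
    return {label: data[labels == label] for label in np.unique(labels)}

# Made-up example: 6 samples, 2 features, labels as a column vector
data = np.arange(12).reshape(6, 2)
labels = np.array([[1], [0], [1], [2], [0], [1]])
groups = split_by_label(data, labels)
```

Each mask scans all n labels, so this costs O(n·k) for k classes. That is fine for 10 classes, but the sorting approach in Jaime's answer scales better when there are many classes.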

Upvotes: 3

B. M.

Reputation: 18628

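In addition to @Jaime's optimal NumPy answer, I suggest pandas, which specializes in data manipulation:

import numpy as np
import pandas
# Flatten in case classes is a 50000 x 1 column vector; the index must be 1-D
df = pandas.DataFrame(data, index=np.ravel(classes)).sort_index()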

Then df.loc[i] gives the rows of class i.

If you want a list of arrays instead:

metadata = [df.loc[i].values for i in range(10)]

so metadata[i] is the subset you want. (Older pandas versions also offered a Panel for this, but Panel was removed in pandas 0.25.) Everything here is built on NumPy arrays, so efficiency is preserved.
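One caveat with df.loc[i]: if a class has exactly one sample it may come back as a Series rather than a DataFrame, so .values would be 1-D for that class, and a class with no samples raises a KeyError. A sketch that sidesteps both uses groupby on the index, which always yields a DataFrame per group and simply omits absent classes. This is a swapped-in alternative, not the original answer's code, and the toy data is made up.

```python
import numpy as np
import pandas as pd

# Made-up toy data: 6 samples, 2 features, labels given as a column vector
data = np.arange(12).reshape(6, 2)
classes = np.array([[1], [0], [1], [2], [0], [1]])

# Use the flattened labels as the index, then group on that index level;
# each group is a DataFrame, so .to_numpy() is always 2-D
df = pd.DataFrame(data, index=np.ravel(classes))
groups = {label: frame.to_numpy() for label, frame in df.groupby(level=0)}
```

groupby keeps the original row order within each group, so no stable-sort concerns arise here.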

Upvotes: 1

Jaime

Reputation: 67427

Probably the cleanest way of doing this in numpy, especially if you have many classes, is through sorting:

import numpy as np

SAMPLES = 50000
FEATURES = 784
CLASSES = 10
data = np.random.rand(SAMPLES, FEATURES)
classes = np.random.randint(CLASSES, size=SAMPLES)

sorter = np.argsort(classes)
classes_sorted = classes[sorter]
splitter, = np.where(classes_sorted[:-1] != classes_sorted[1:])
data_splitted = np.split(data[sorter], splitter + 1)
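A variant of the same sort-and-split idea that also reports which label each piece belongs to. Instead of comparing neighbouring sorted labels, it takes the split points from np.unique(..., return_counts=True), whose labels come back in the same ascending order as the sorted data. The function name split_sorted and the example data are hypothetical.

```python
import numpy as np

def split_sorted(data, classes):
    """Group rows of `data` by `classes` with a single stable sort.

    Returns (labels, groups) where groups[j] holds the rows whose class is labels[j].
    """
    classes = np.asarray(classes).ravel()
    sorter = np.argsort(classes, kind="stable")  # keeps original order within a class
    labels, counts = np.unique(classes, return_counts=True)  # labels are sorted
    # Boundaries are the running totals of the class sizes, minus the final total
    groups = np.split(np.asarray(data)[sorter], np.cumsum(counts)[:-1])
    return labels, groups

# Made-up example: 6 samples, 2 features
labels, groups = split_sorted(np.arange(12).reshape(6, 2), [1, 0, 1, 2, 0, 1])
```

Returning the labels alongside the pieces matters when some class in 0-9 is absent, because then position j in the list no longer equals class j.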

data_splitted is a list of arrays, one for each class present in classes, in ascending class order. Running the above code with SAMPLES = 10, FEATURES = 2 and CLASSES = 3 I get:

>>> data
array([[ 0.45813694,  0.47942962],
       [ 0.96587082,  0.73260743],
       [ 0.70539842,  0.76376921],
       [ 0.01031978,  0.93660231],
       [ 0.45434223,  0.03778273],
       [ 0.01985781,  0.04272293],
       [ 0.93026735,  0.40216376],
       [ 0.39089845,  0.01891637],
       [ 0.70937483,  0.16077439],
       [ 0.45383099,  0.82074859]])

>>> classes
array([1, 1, 2, 1, 1, 2, 0, 2, 0, 1])

>>> data_splitted 
[array([[ 0.93026735,  0.40216376],
        [ 0.70937483,  0.16077439]]),
 array([[ 0.45813694,  0.47942962],
        [ 0.96587082,  0.73260743],
        [ 0.01031978,  0.93660231],
        [ 0.45434223,  0.03778273],
        [ 0.45383099,  0.82074859]]),
 array([[ 0.70539842,  0.76376921],
        [ 0.01985781,  0.04272293],
        [ 0.39089845,  0.01891637]])]

If you want to make sure the sort is stable, i.e. that data points in the same class remain in the same relative order after sorting, you will need to specify sorter = np.argsort(classes, kind='mergesort').
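A quick check of that stability property on made-up labels. It uses kind="stable", which NumPy 1.15 and later accept as another way to request a stable sort; the seed and number of labels are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
classes = rng.integers(3, size=20)  # made-up labels with many repeats per class

# kind="stable" and kind="mergesort" both request a stable sort
sorter = np.argsort(classes, kind="stable")

# Within each class, the original row indices must come out in increasing order
for c in np.unique(classes):
    idx = sorter[classes[sorter] == c]
    assert np.all(np.diff(idx) > 0)
```

Because a stable sort's output permutation is uniquely determined, both spellings give identical sorter arrays.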

Upvotes: 3
