Ali_Sh

Reputation: 2816

How to create groups based on an index array that contains unequally shaped arrays

I have an array and want to create groups from it using an object-dtype index array (which can take two different forms, shown by the type comments under ids in the following code), e.g.:

poss = np.array([[-0.09060061464812283, -0.0686586855826326, -0.02647255439666547],
                 [0.08835853305001438, -0.06574591502086484, -0.05487544905792754],
                 [-0.06629393005890775, 0.00666283311876477, 0.06629393005890775],
                 [-0.06387415159186867, -0.1541029899066564, -0.03157952415698592],
                 [-0.0764878092204729, -0.17648780922047289, -0.04332722135612625]])

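# dtype=object is needed to build a ragged array in recent NumPy versions
ids = np.array([[1, 3, 2], [4, 3, 1, 0]], dtype=object)
# type(ids), type(ids[0]), type(ids[0][0]) take one of two forms:
# <class 'numpy.ndarray'> <class 'list'> <class 'int'>
# <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.int64'>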

poss and ids are larger in practice; the code above is just an example. I know that if the index array had a regular shape such as 2*4, this could be done with poss[ids, :]. But when the arrays inside the index array have different lengths, what is the best way to create these groups using only NumPy (preferably without looping)?
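
For reference, the regular (equal-length) case mentioned above can be sketched as follows. The names ids_rect and groups and the stand-in values of poss are assumptions made for illustration; they are not part of the real data.

```python
import numpy as np

# Stand-in data (assumption): 5 points with 3 coordinates each
poss = np.arange(15, dtype=float).reshape(5, 3)

# Hypothetical rectangular index array of shape (2, 4)
ids_rect = np.array([[1, 3, 2, 0], [4, 3, 1, 0]])

# Fancy indexing gathers the rows for every group in one step
groups = poss[ids_rect, :]
print(groups.shape)  # (2, 4, 3)
```

This works only because every row of ids_rect has the same length, so the result fits in a single rectangular array.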

expected list/array:

# [[[ 0.08835853305001438 -0.06574591502086484 -0.05487544905792754]
#   [-0.06387415159186867 -0.1541029899066564  -0.03157952415698592]
#   [-0.06629393005890775  0.00666283311876477  0.06629393005890775]]
#
#  [[-0.0764878092204729  -0.17648780922047289 -0.04332722135612625]
#   [-0.06387415159186867 -0.1541029899066564  -0.03157952415698592]
#   [ 0.08835853305001438 -0.06574591502086484 -0.05487544905792754]
#   [-0.09060061464812283 -0.0686586855826326  -0.02647255439666547]]]

Upvotes: 0

Views: 105

Answers (1)

mozway

Reputation: 262114

This is not possible without at least looping on ids, as you have a ragged array (different sizes).

You can use:

out = np.split(poss[np.hstack(ids)], np.cumsum(list(map(len, ids)))[:-1])

output:

[array([[ 0.08835853, -0.06574592, -0.05487545],
        [-0.06387415, -0.15410299, -0.03157952],
        [-0.06629393,  0.00666283,  0.06629393]]),
 array([[-0.07648781, -0.17648781, -0.04332722],
        [-0.06387415, -0.15410299, -0.03157952],
        [ 0.08835853, -0.06574592, -0.05487545],
        [-0.09060061, -0.06865869, -0.02647255]])]

How it works:

  • stack ids into a single flat array and use it to index poss
  • split the resulting array at the cumulative lengths of ids (here 3 and 7; the last split point, 7, is dropped because it is the total length)
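
The two steps above can be sketched with the intermediate values spelled out. This is a minimal illustration, assuming stand-in values for poss and a plain list of lists for ids; the names flat, lengths, cuts and expected are introduced here only for the example.

```python
import numpy as np

# Stand-in data (assumption): 5 points with 3 coordinates each
poss = np.arange(15, dtype=float).reshape(5, 3)
ids = [[1, 3, 2], [4, 3, 1, 0]]

# Step 1: flatten the ragged indices into one 1-D array
flat = np.hstack(ids)              # array([1, 3, 2, 4, 3, 1, 0])

# Step 2: cumulative group lengths give the split points;
# the last one (the total length) is dropped
lengths = [len(i) for i in ids]    # [3, 4]
cuts = np.cumsum(lengths)[:-1]     # array([3])

# Index once, then split back into groups
out = np.split(poss[flat], cuts)

# Cross-check against the explicit per-group loop
expected = [poss[np.asarray(i)] for i in ids]
assert all(np.array_equal(a, b) for a, b in zip(out, expected))
```

The result is a list of arrays rather than a single array, since the groups have different numbers of rows.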

Upvotes: 1
