I have an array and want to create groups from it based on an object-type index array (which can take two different forms, as shown by the type comments under ids in the following code), e.g.:
import numpy as np

poss = np.array([[-0.09060061464812283, -0.0686586855826326, -0.02647255439666547],
                 [0.08835853305001438, -0.06574591502086484, -0.05487544905792754],
                 [-0.06629393005890775, 0.00666283311876477, 0.06629393005890775],
                 [-0.06387415159186867, -0.1541029899066564, -0.03157952415698592],
                 [-0.0764878092204729, -0.17648780922047289, -0.04332722135612625]])
ids = np.array([[1, 3, 2], [4, 3, 1, 0]], dtype=object)  # ragged rows, so dtype=object is needed
# type(ids), type(ids[0]), type(ids[0][0]) can be either of:
# <class 'numpy.ndarray'> <class 'list'> <class 'int'>
# <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.int64'>
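The two forms of ids described by the type comments can be reproduced as follows. This is a sketch assuming NumPy 1.24 or later, where passing a ragged nested list to np.array raises ValueError unless dtype=object is given; ids_lists and ids_arrays are hypothetical names.

```python
import numpy as np

# Form 1: object array whose elements are plain Python lists of ints.
# dtype=object is required because the rows have different lengths.
ids_lists = np.array([[1, 3, 2], [4, 3, 1, 0]], dtype=object)

# Form 2: object array whose elements are int64 NumPy arrays.
# Filling a pre-allocated object array stores each row as-is.
ids_arrays = np.empty(2, dtype=object)
ids_arrays[0] = np.array([1, 3, 2], dtype=np.int64)
ids_arrays[1] = np.array([4, 3, 1, 0], dtype=np.int64)

# Print the container, row, and element types for each form
for form in (ids_lists, ids_arrays):
    print(type(form), type(form[0]), type(form[0][0]))
```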
poss and ids are larger in practice; the arrays above are just an example. I know that if the index array were rectangular (e.g. of shape (2, 4)), this could be achieved with poss[ids, :]. But when the rows of the index array have different lengths, what is the best way to create these groups using NumPy alone (preferably without looping)?
Expected list/array:
# [[[ 0.08835853305001438 -0.06574591502086484 -0.05487544905792754]
# [-0.06387415159186867 -0.1541029899066564 -0.03157952415698592]
# [-0.06629393005890775 0.00666283311876477 0.06629393005890775]]
#
# [[-0.0764878092204729 -0.17648780922047289 -0.04332722135612625]
# [-0.06387415159186867 -0.1541029899066564 -0.03157952415698592]
# [ 0.08835853305001438 -0.06574591502086484 -0.05487544905792754]
# [-0.09060061464812283 -0.0686586855826326 -0.02647255439666547]]]
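For comparison with the ragged case, when every row of the index array has the same length, ids is an ordinary 2-D integer array and one fancy-indexing call returns a 3-D block of shape (n_groups, group_len, 3). The rectangular ids_rect below is a hypothetical example, not taken from the question.

```python
import numpy as np

poss = np.array([[-0.09060061464812283, -0.0686586855826326, -0.02647255439666547],
                 [0.08835853305001438, -0.06574591502086484, -0.05487544905792754],
                 [-0.06629393005890775, 0.00666283311876477, 0.06629393005890775],
                 [-0.06387415159186867, -0.1541029899066564, -0.03157952415698592],
                 [-0.0764878092204729, -0.17648780922047289, -0.04332722135612625]])

# Hypothetical rectangular index array: 2 groups of 4 row indices each
ids_rect = np.array([[1, 3, 2, 0], [4, 3, 1, 0]])

groups = poss[ids_rect, :]  # equivalent to poss[ids_rect]
print(groups.shape)         # (2, 4, 3)
```

This only works because ids_rect is rectangular; with ragged rows there is no single 3-D shape to return.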
This is not possible without at least looping over ids, as you have a ragged array (groups of different sizes).
You can use:
out = np.split(poss[np.hstack(ids)], np.cumsum(list(map(len, ids)))[:-1])
Output:
[array([[ 0.08835853, -0.06574592, -0.05487545],
[-0.06387415, -0.15410299, -0.03157952],
[-0.06629393, 0.00666283, 0.06629393]]),
array([[-0.07648781, -0.17648781, -0.04332722],
[-0.06387415, -0.15410299, -0.03157952],
[ 0.08835853, -0.06574592, -0.05487545],
[-0.09060061, -0.06865869, -0.02647255]])]
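As a sanity check, the split result can be compared with an explicit per-group loop (a list comprehension that indexes poss once per group). This sketch redefines the question's data with dtype=object, assuming NumPy 1.24 or later; out_loop is a hypothetical name, and the comprehension is a baseline for comparison rather than part of the answer.

```python
import numpy as np

poss = np.array([[-0.09060061464812283, -0.0686586855826326, -0.02647255439666547],
                 [0.08835853305001438, -0.06574591502086484, -0.05487544905792754],
                 [-0.06629393005890775, 0.00666283311876477, 0.06629393005890775],
                 [-0.06387415159186867, -0.1541029899066564, -0.03157952415698592],
                 [-0.0764878092204729, -0.17648780922047289, -0.04332722135612625]])
ids = np.array([[1, 3, 2], [4, 3, 1, 0]], dtype=object)

# Answer's approach: one gather over all indices, then split by group sizes
out = np.split(poss[np.hstack(ids)], np.cumsum(list(map(len, ids)))[:-1])

# Baseline: one fancy-indexing call per group
out_loop = [poss[np.asarray(g)] for g in ids]

print(all(np.array_equal(a, b) for a, b in zip(out, out_loop)))  # True
```

Both versions iterate over the groups in Python at some point (np.split builds its sub-arrays in a loop), which is consistent with the point that a ragged result cannot be produced without some looping.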
How it works:
- concatenate ids into a single index array (np.hstack) and use it to slice poss in one step
- split this array at the cumulative lengths of ids (here 3 and 7; the last split point is dropped because it equals the total length)
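The two steps can be traced with named intermediates. This is a sketch using a hypothetical stand-in for poss (np.arange reshaped to 5 x 3, so each row's values identify which row it is) and ids built with dtype=object; the intermediate names are illustrative only.

```python
import numpy as np

poss = np.arange(15, dtype=float).reshape(5, 3)  # hypothetical stand-in data
ids = np.array([[1, 3, 2], [4, 3, 1, 0]], dtype=object)

flat = np.hstack(ids)            # all indices in one 1-D array: [1 3 2 4 3 1 0]
lengths = [len(g) for g in ids]  # group sizes: [3, 4]
bounds = np.cumsum(lengths)      # cumulative ends: [3 7]
split_points = bounds[:-1]       # drop the total length: [3]

gathered = poss[flat]            # shape (7, 3), a single fancy-indexing call
out = np.split(gathered, split_points)
print([g.shape for g in out])    # [(3, 3), (4, 3)]
```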