n0tis

Numpy slicing gives unexpected result

Does anybody have an explanation for the unexpected NumPy slicing results displayed below?

Unexpected behavior demo

import torch
import numpy as np

some_array = np.zeros((1, 3, 42))
chooser_mask = np.zeros(42)
# mask will pick 2 values
chooser_mask[13] = 1
chooser_mask[14] = 1

out_1 = some_array[0, :, chooser_mask == 1]
print(out_1.shape)  # shape 2x3 (unexpected !!!)

Dim "3" was in front, so I expect it to stay in front (i.e. a 3x2 result).
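If I read the NumPy indexing documentation on combining advanced and basic indexing correctly, the integer `0` and the boolean mask in `some_array[0, :, chooser_mask == 1]` both count as advanced indices. Because a slice sits between them, the dimension they produce (length 2) is placed at the front of the result. The minimal NumPy-only check below compares that case with variants where the rule does not apply; the names `mask`, `out_separated`, `out_mask_only` and `out_adjacent` are illustrative, not from the original code.

```python
import numpy as np

some_array = np.zeros((1, 3, 42))
chooser_mask = np.zeros(42)
chooser_mask[[13, 14]] = 1  # mask will pick 2 values
mask = chooser_mask == 1

# 0 and the mask are both advanced indices, separated by ':'
# -> the advanced-index dimension (length 2) goes to the front.
out_separated = some_array[0, :, mask]
print(out_separated.shape)  # (2, 3)

# Only one advanced index (the mask): its dimension stays in place.
out_mask_only = some_array[:, :, mask]
print(out_mask_only.shape)  # (1, 3, 2)

# Adjacent advanced indices (integer next to mask): result replaces them in place.
out_adjacent = np.zeros((3, 1, 42))[:, 0, mask]
print(out_adjacent.shape)  # (3, 2)
```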

Workaround for the weird behavior

tmp = some_array[0]
out_2 = tmp[:, chooser_mask == 1]
print(out_2.shape)  # shape is 3x2 (expected)
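Other single-expression variants that should also give the 3x2 layout, sketched under the same reading of the indexing rules. The names `out_a`, `out_b` and `out_c` are illustrative, and `np.arange` replaces `np.zeros` only so the values (not just the shapes) can be compared against the workaround.

```python
import numpy as np

some_array = np.arange(126, dtype=float).reshape(1, 3, 42)
chooser_mask = np.zeros(42)
chooser_mask[[13, 14]] = 1
mask = chooser_mask == 1

# Option A: apply the mask first (single advanced index), then drop axis 0.
out_a = some_array[:, :, mask][0]

# Option B: keep the original expression and transpose the (2, 3) result back.
out_b = some_array[0, :, mask].T

# Option C: np.compress selects along one axis using a boolean condition.
out_c = np.compress(mask, some_array[0], axis=1)

print(out_a.shape, out_b.shape, out_c.shape)  # (3, 2) (3, 2) (3, 2)
```

Option A avoids the reordering because the mask is then the only advanced index; option B simply undoes it after the fact.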

The PyTorch version does not show the unexpected behavior

some_array = torch.from_numpy(some_array)
chooser_mask = torch.from_numpy(chooser_mask)
out_1 = some_array[0, :, chooser_mask == 1]
print(out_1.shape)  # shape 3x2 (expected)

tmp = some_array[0]
out_2 = tmp[:, chooser_mask == 1]
print(out_2.shape)  # shape is 3x2 (expected)
