Does anybody have an explanation for the unexpected numpy slicing results displayed below?
import torch
import numpy as np
some_array = np.zeros((1, 3, 42))
chooser_mask = np.zeros((42))
# mask will pick 2 values
chooser_mask[13] = 1
chooser_mask[14] = 1
out_1 = some_array[0, :, chooser_mask == 1]
print(out_1.shape) # shape is 2x3 (unexpected!)
# Dim "3" was in front, and I expect it to stay in front
tmp = some_array[0]
out_2 = tmp[:, chooser_mask == 1]
print(out_2.shape) # shape is 3x2 (expected)
some_array = torch.from_numpy(some_array)
chooser_mask = torch.from_numpy(chooser_mask)
out_1 = some_array[0, :, chooser_mask == 1]
print(out_1.shape) # shape 3x2 (expected)
tmp = some_array[0]
out_2 = tmp[:, chooser_mask == 1]
print(out_2.shape) # shape is 3x2 (expected)
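A likely explanation is NumPy's documented rule for combining advanced and basic indexing. In `some_array[0, :, chooser_mask == 1]`, the scalar `0` and the boolean mask are both treated as advanced indices. Because the basic slice `:` sits between them, NumPy moves the broadcast advanced-index dimension (the 2 selected positions) to the front of the result, and the sliced dimension of size 3 comes after it. `tmp[:, chooser_mask == 1]` has only one advanced index, so no axes move. The PyTorch results above suggest that PyTorch applies the integer index before the mask, which matches the two-step NumPy form. The sketch below only uses NumPy. It fills the array with `np.arange` values so the results can be compared, and the names `mixed`, `two_step` and `via_slice` are illustrative:

```python
import numpy as np

# Same shapes as in the question, but with distinct values
# so the results can be compared element by element.
some_array = np.arange(1 * 3 * 42).reshape(1, 3, 42)
chooser_mask = np.zeros(42)
chooser_mask[[13, 14]] = 1
mask = chooser_mask == 1

# The scalar 0 and the boolean mask are both advanced indices. The slice ':'
# separates them, so NumPy puts the advanced-index dimension (2) first and
# the sliced dimension (3) after it.
mixed = some_array[0, :, mask]
print(mixed.shape)  # (2, 3)

# Indexing in two steps leaves only one advanced index,
# so the sliced dimension keeps its position.
two_step = some_array[0][:, mask]
print(two_step.shape)  # (3, 2)

# Alternative: use the slice 0:1 instead of the scalar 0, then drop that axis.
via_slice = some_array[0:1, :, mask][0]
print(via_slice.shape)  # (3, 2)

# The selected values are identical; only the axis order differs.
print(np.array_equal(mixed.T, two_step))  # True
```

Either of the last two forms avoids mixing a scalar index and a mask across a slice. That makes the axis order predictable. Indexing in two steps is also the form that gave 3x2 in both libraries in the question.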