Reputation: 4805
I have the following function, which does what I want when given a numpy.array, but breaks when given a torch.Tensor due to indexing errors.
import torch
import numpy as np
def combination_matrix(arr):
    """Build the matrix of all ordered pairs of entries of ``arr``.

    For a NumPy input of shape ``(n, ...)`` the result has shape
    ``(n, n, 2, ...)`` with ``result[i, j, 0] == arr[i]`` and
    ``result[i, j, 1] == arr[j]``.

    NOTE: this is the failing reproduction. In the reported run the torch
    branch raises ``RuntimeError: number of dims don't match in permute``
    because ``arr[mesh]`` had a different shape than it has in NumPy.
    Inputs that are neither ``np.ndarray`` nor ``torch.Tensor`` fall through
    both checks and implicitly return ``None``.
    """
    idxs = np.arange(len(arr))
    # Open-mesh index addressing every (i, j) cell of the leading two axes.
    idx = np.ix_(idxs, idxs)
    # Shape (2, n, n): mesh[0, j, i] == i and mesh[1, j, i] == j.
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        """NumPy path: gather with the index mesh, then reorder the axes."""
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        # arr[mesh] is (2, n, n, ...); swapping axes 0 and 2 yields (n, n, 2, ...).
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        """Torch path: direct port of the NumPy code (the failing one)."""
        output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
        num_dims = len(output.shape)
        print(arr[mesh].shape)  # <-- This is wrong/different to numpy!
        # Fails in the reported run: the indexed tensor does not have num_dims axes.
        output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
        return output

    # Dispatch on the container type of the input.
    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()
The problem is that arr[mesh] produces results with different dimensions in numpy and in torch. Apparently, PyTorch does not support indexing with index arrays whose dimensionality differs from that of the array being indexed. Ideally, the following should work:
features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
But the dimensions are different:
(2, 3, 3, 3)
torch.Size([3, 3])
Which results in an error (logically):
Traceback (most recent call last):
File "/home/XXX/util.py", line 226, in <module>
torch_combs = combination_matrix(features)
File "/home/XXX/util.py", line 218, in combination_matrix
return torch_combination_matrix()
File "/home/XXX/util.py", line 212, in torch_combination_matrix
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute
How do I match the torch behavior to numpy? I've read various questions on the torch forums (e.g. this one with only one dimension), but could not find how to apply them here. Similarly, index_select only works for one dimension, but I need it to work for at least 2 dimensions.
Upvotes: 2
Views: 3945
Reputation: 4805
It is actually embarrassingly easy. You just need to flatten the indices, then reshape and permute the dimensions. This is the full working version:
import torch
import numpy as np
def combination_matrix(arr):
    """Return all ordered pairs of entries of ``arr`` in a single array/tensor.

    Args:
        arr: A ``numpy.ndarray`` or ``torch.Tensor`` of shape ``(n, ...)``.

    Returns:
        An object of the same kind and dtype as ``arr`` with shape
        ``(n, n, 2, ...)``, where ``result[i, j, 0] == arr[i]`` and
        ``result[i, j, 1] == arr[j]``.

    Raises:
        TypeError: If ``arr`` is neither a ``numpy.ndarray`` nor a
            ``torch.Tensor`` (previously this case silently returned ``None``).
    """
    is_numpy = isinstance(arr, np.ndarray)
    if not is_numpy and not isinstance(arr, torch.Tensor):
        raise TypeError(
            "combination_matrix expects a numpy.ndarray or torch.Tensor, "
            f"got {type(arr).__name__}"
        )

    n = len(arr)
    idxs = np.arange(n)
    # Shape (2, n, n): mesh[0, j, i] == i and mesh[1, j, i] == j.
    mesh = np.stack(np.meshgrid(idxs, idxs))

    if is_numpy:
        # Fancy indexing with the full mesh gives (2, n, n, ...); swapping
        # axes 0 and 2 moves the pair axis behind the two "row" axes.
        gathered = arr[mesh]
        axes = (2, 1, 0, *range(3, gathered.ndim))
        # copy() materialises a fresh C-contiguous array, as the original
        # zero-filled output buffer did.
        return gathered.transpose(axes).copy()

    # Torch: index with the *flattened* mesh (a 1-D index is unambiguous),
    # then restore the mesh layout by reshaping. Note that this intermediate
    # shape has the pair axis first, unlike the final (n, n, 2, ...) layout.
    stacked_shape = (2, n, n, *arr.shape[1:])
    gathered = arr[mesh.flatten()].reshape(stacked_shape)
    return gathered.permute(2, 1, 0, *range(3, len(stacked_shape)))
I used pytest to run this on random arrays of different dimensions, and it seems to work in all cases:
import pytest
@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
    """Check that the numpy and torch code paths agree on random inputs.

    Args:
        random_dims: Number of dimensions of the randomly shaped input.
    """
    # Seed per parametrisation so a failing case can be reproduced exactly.
    rng = np.random.default_rng(random_dims)
    # Keep the sizes small: the result holds 2 * n0**2 * n1 * ... values, so
    # sizes up to 39 in four dimensions could need well over a gigabyte.
    dim_size = tuple(int(d) for d in rng.integers(1, 10, size=random_dims))
    elements = rng.random(size=dim_size)

    np_combs = combination_matrix(elements)
    # from_numpy shares memory with `elements`; both paths see the same data.
    features = torch.from_numpy(elements)
    torch_combs = combination_matrix(features)

    assert np.array_equal(np_combs, torch_combs.numpy())
# When executed as a script, run this file's tests through pytest;
# '-x' makes pytest stop at the first failing test.
if __name__ == '__main__':
    pytest.main(['-x', __file__])
Upvotes: 3