RunOrVeith

Index pytorch tensor with different dimension index array

I have the following function, which does what I want for a numpy.array, but breaks with indexing errors when fed a torch.Tensor.

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
        num_dims = len(output.shape)
        print(arr[mesh].shape)  # <-- This is wrong/different to numpy!
        output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
        return output

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

The problem is that arr[mesh] produces results with different dimensions for numpy and torch. Apparently, PyTorch does not support indexing with an index array whose dimensionality differs from that of the tensor being indexed. Ideally, the following should work:

features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())

But the dimensions are different:

(2, 3, 3, 3)
torch.Size([3, 3])

Which results in an error (logically):

Traceback (most recent call last):
  File "/home/XXX/util.py", line 226, in <module>
    torch_combs = combination_matrix(features)
  File "/home/XXX/util.py", line 218, in combination_matrix
    return torch_combination_matrix()
  File "/home/XXX/util.py", line 212, in torch_combination_matrix
    output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute
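
The mismatch can be reproduced in isolation with just the indexing step (a minimal sketch, reusing the same mesh as above):

import numpy as np
import torch

arr_np = np.arange(9).reshape(3, 3)
arr_t = torch.from_numpy(arr_np)
idxs = np.arange(3)
mesh = np.stack(np.meshgrid(idxs, idxs))  # shape (2, 3, 3)

print(arr_np[mesh].shape)  # (2, 3, 3, 3): numpy indexes axis 0 with the whole mesh
print(arr_t[mesh].shape)   # reported above as torch.Size([3, 3]); the exact shape may vary by torch version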

How do I make the torch behavior match numpy? I've read various questions on the torch forums (e.g. this one, which deals with only one dimension), but could not figure out how to apply them here. Similarly, index_select only works for one dimension, but I need it to work for at least two.
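
For concreteness, the desired result stacks every pair of rows, so output[i, j] should hold (features[i], features[j]). A quick check against the numpy branch (my own sketch, reusing combination_matrix from above):

features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)

print(np_combs.shape)  # (3, 3, 2, 3): one stacked pair of rows per (i, j)
assert np.array_equal(np_combs[1, 2], np.stack([features[1], features[2]]))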

Upvotes: 2

Views: 3945

Answers (1)

RunOrVeith

It is actually embarrassingly easy. You just need to flatten the indices, then reshape and permute the dimensions. This is the full working version:

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output_shape = (2, len(arr), len(arr), *arr.shape[1:])  # Note that this is different to numpy!
        return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()
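
As a sanity check, tracing the intermediate shapes for the 3x3 example shows why flattening the index and reshaping recovers the numpy layout (a minimal sketch, assuming the same mesh as above):

features = torch.from_numpy(np.arange(9).reshape(3, 3))
idxs = np.arange(3)
mesh = np.stack(np.meshgrid(idxs, idxs))   # (2, 3, 3)

flat = features[mesh.flatten()]            # (18, 3): one row of `features` per flattened index
stacked = flat.reshape(2, 3, 3, 3)         # same layout as numpy's features[mesh]
result = stacked.permute(2, 1, 0, 3)       # (3, 3, 2, 3), matching the numpy branch
print(flat.shape, stacked.shape, result.shape)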

I used pytest to run this on random arrays with varying numbers of dimensions, and it seems to work in all cases:

import pytest

@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
    dim_size = np.random.randint(1, 40, size=random_dims)
    elements = np.random.random(size=dim_size)
    np_combs = combination_matrix(elements)
    features = torch.from_numpy(elements)
    torch_combs = combination_matrix(features)

    assert np.array_equal(np_combs, torch_combs.numpy())

if __name__ == '__main__':
    pytest.main(['-x', __file__])

Upvotes: 3
