Gal R

Reputation: 31

Apply a function over all combinations of tensor rows in PyTorch

I want to write a function f1(arg_tensor) that takes a PyTorch tensor as its argument.

Inside this function I use another function, f2(tensor_row_1, tensor_row_2), which takes two rows of a PyTorch tensor as arguments and returns a scalar.

f2(..) should be applied to every ordered pair of the tensor's n rows, i.e. to the row indices [0,0], [0,1], [0,2], ..., [0,n-1], ..., [n-1,0], ..., [n-1,n-1].

The output of f1(..) should be an n x n tensor whose element [0,0] holds the output of f2(tensor_rows[0], tensor_rows[0]), and in general whose element [i,j] holds f2(tensor_rows[i], tensor_rows[j]).

Is there an efficient way to do this (i.e. without a double for loop)?
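For reference, the double-loop version the question wants to avoid could look like the sketch below. The f2 here (a dot product) and the name f1_loop are hypothetical stand-ins; any function of two rows that returns a scalar fits the same pattern.

```python
import torch


def f2(row_1: torch.Tensor, row_2: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for f2: dot product of two rows (a 0-dim tensor)
    return torch.dot(row_1, row_2)


def f1_loop(arg_tensor: torch.Tensor) -> torch.Tensor:
    # Element [i, j] of the output holds f2(arg_tensor[i], arg_tensor[j])
    n = arg_tensor.shape[0]
    out = torch.empty(n, n, dtype=arg_tensor.dtype, device=arg_tensor.device)
    for i in range(n):
        for j in range(n):
            out[i, j] = f2(arg_tensor[i], arg_tensor[j])
    return out


x = torch.arange(6, dtype=torch.float32).reshape(3, 2)
print(f1_loop(x))
```

For this particular f2 the loop is equivalent to the single vectorized call x @ x.T, which illustrates why rewriting f2 in terms of whole-tensor operations removes the need for the loops.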

Upvotes: 3

Views: 1622

Answers (1)

Szymon Maszke

Reputation: 24681

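Yes. If f2 can be written with broadcastable element-wise tensor operations, you can evaluate all pairs at once with a simple broadcasting trick. Here torch.nn.functional.kl_div serves as an example f2: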

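import torch


def f1(tensor):
    # tensor has shape (n, d); move the feature dimension first: (d, n)
    tensor = tensor.permute(1, 0)
    # (d, n, 1) against (d, 1, n) broadcasts to (d, n, n); element [k, i, j]
    # is the element-wise KL term with row i as input and row j as target
    return torch.nn.functional.kl_div(
        tensor.unsqueeze(dim=2), tensor.unsqueeze(dim=1), reduction="none"
    ).mean(dim=0)  # average over the d features -> (n, n)


def manual_f1(tensor):
    # Reference implementation: double loop over all ordered row pairs
    result = []
    for row1 in tensor:
        for row2 in tensor:
            # default reduction="mean" averages over the d features,
            # matching .mean(dim=0) in f1
            result.append(torch.nn.functional.kl_div(row1, row2))
    return torch.stack(result).reshape(tensor.shape[0], -1)


# kl_div expects non-negative targets; with torch.randn the log of negative
# entries can give NaN in recent PyTorch versions (and NaN != NaN),
# so this uses values in [0, 1) instead
data = torch.rand(5, 7)

result = f1(data)
manual_result = manual_f1(data)

print(torch.all(result == manual_result).item())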

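Please note that the exact == comparison may print False, especially for tensors with more rows, because the vectorized and looped versions can accumulate floating-point rounding errors differently. You can: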

  • print the values and inspect manually
  • use torch.isclose to check that the values agree within a tolerance

In the second case, the last print becomes:

print(torch.all(torch.isclose(result, manual_result)).item())
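The broadcasting trick depends on f2 being expressible with element-wise tensor operations, as kl_div is. If your f2 only knows how to handle two 1-D rows, one alternative worth trying (not part of the original answer, and assuming PyTorch 2.0 or later, where torch.func is available) is to nest torch.func.vmap so f2 is vectorized over both arguments. The cosine-similarity f2 and the names f1_vmap and f1_loop below are hypothetical examples.

```python
import torch
from torch.func import vmap


def f2(row_1: torch.Tensor, row_2: torch.Tensor) -> torch.Tensor:
    # Hypothetical f2: cosine similarity of two 1-D rows, as a 0-dim tensor
    return torch.dot(row_1, row_2) / (row_1.norm() * row_2.norm())


def f1_vmap(arg_tensor: torch.Tensor) -> torch.Tensor:
    # Inner map: keep row_1 fixed, map over the rows passed as row_2
    inner = vmap(f2, in_dims=(None, 0))
    # Outer map: map over the rows passed as row_1
    outer = vmap(inner, in_dims=(0, None))
    # Shape (n, n) with [i, j] = f2(arg_tensor[i], arg_tensor[j])
    return outer(arg_tensor, arg_tensor)


def f1_loop(arg_tensor: torch.Tensor) -> torch.Tensor:
    # Double-loop reference, used only to check the vmap version
    rows = list(arg_tensor)
    return torch.stack([torch.stack([f2(a, b) for b in rows]) for a in rows])


data = torch.rand(5, 7)
print(torch.allclose(f1_vmap(data), f1_loop(data)))
```

vmap traces f2 with batched tensors, so f2 must avoid calls such as .item() and Python branches on tensor values. It removes the Python loops, but a hand-written broadcast like the kl_div version above may still be faster for a given f2.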

Upvotes: 2
