gus

Reputation: 71

Zero diagonal of a PyTorch tensor?

Is there a simple way to zero the diagonal of a PyTorch tensor?

For example, I have:

tensor([[2.7183, 0.4005, 2.7183, 0.5236],
        [0.4005, 2.7183, 0.4004, 1.3469],
        [2.7183, 0.4004, 2.7183, 0.5239],
        [0.5236, 1.3469, 0.5239, 2.7183]])

And I want to get:

tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])

Upvotes: 3

Views: 2188

Answers (5)

Shai

Reputation: 114826

Here's another way:

# assumes x is a contiguous square matrix, so flatten() returns a view of x
x.flatten()[::(x.shape[-1]+1)] = 0
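A minimal sketch of why the stride trick works and where it can go wrong. In row-major storage, element (i, i) of an n x n matrix sits at flat offset i * (n + 1). `flatten()` only returns a view when the layout allows it; on a transposed tensor it returns a copy, so the assignment lands in a temporary and the original is left unchanged. Using `view(-1)` instead raises an error in that case rather than failing silently. The helper name `zero_diag_flat` is hypothetical.

```python
import torch

def zero_diag_flat(x: torch.Tensor) -> torch.Tensor:
    """Zero the main diagonal of a square 2-D tensor in place (hypothetical helper).

    Element (i, i) of an n x n matrix is at flat offset i * (n + 1), so a
    step of n + 1 visits exactly the diagonal. view(-1) raises a RuntimeError
    instead of silently copying when x cannot be viewed as 1-D (e.g. x.t()).
    """
    n = x.shape[-1]
    x.view(-1)[:: n + 1] = 0
    return x

x = torch.arange(1.0, 10.0).reshape(3, 3)
zero_diag_flat(x)
print(x)

# Pitfall: a transposed tensor cannot be flattened without a copy, so
# flatten() returns a new tensor and y itself is not modified.
y = torch.ones(3, 3).t()
y.flatten()[::4] = 0
print(y.sum())  # still 9: every entry of y is still 1
```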

Upvotes: 1

iacob

Reputation: 24231

You can simply use:

x.fill_diagonal_(0)
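`fill_diagonal_` works in place. A short sketch of its behavior on non-square 2-D inputs, using made-up tensors of ones: with the default `wrap=False`, only the min(rows, cols) entries of the main diagonal are filled. For tensors with more than two dimensions, PyTorch requires all dimensions to have equal length, so this is not a per-matrix operation over a batch.

```python
import torch

wide = torch.ones(2, 4)
tall = torch.ones(4, 2)

# In place: in both shapes the main diagonal is just (0, 0) and (1, 1).
wide.fill_diagonal_(0)
tall.fill_diagonal_(0)

print(wide)
print(tall)
```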

Upvotes: 6

trialNerror

Reputation: 3563

I believe the simplest way is to use torch.diagonal:

import torch

z = torch.randn(4, 4)
torch.diagonal(z, 0).zero_()  # diagonal() returns a view, so zero_() writes into z
print(z)
# tensor([[ 0.0000, -0.6211,  0.1120,  0.8362],
#         [-0.1043,  0.0000,  0.1770,  0.4197],
#         [ 0.7211,  0.1138,  0.0000, -0.7486],
#         [-0.5434, -0.8265, -0.2436,  0.0000]])

This keeps the code explicit and delegates the work to PyTorch's built-in functions.
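Because `torch.diagonal` returns a view, the same approach extends to a batch of matrices by choosing which two dimensions form the diagonal. A sketch, assuming a tensor of shape (batch, n, n) in which every matrix should have its own diagonal zeroed:

```python
import torch

batch = torch.ones(2, 3, 3)

# dim1/dim2 select the last two dimensions, so the view has shape (2, 3):
# one diagonal per matrix in the batch.
diag = torch.diagonal(batch, offset=0, dim1=-2, dim2=-1)
diag.zero_()  # writes through the view into batch

print(batch)
```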

Upvotes: 6

Ivan

Reputation: 40678

As an alternative to indexing with two tensors separately, you could achieve this using a combination of Tensor.repeat and Tensor.split, taking advantage of the fact that the latter returns a tuple:

>>> x[torch.arange(len(x)).repeat(2).split(len(x))] = 0
>>> x
tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])
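To see what the index expression produces, it can be built step by step: `repeat(2)` concatenates two copies of `0..n-1`, and `split(n)` cuts the result back into a tuple of two equal tensors. Indexing with a tuple of two integer tensors pairs them elementwise, selecting (0, 0), (1, 1), and so on. A small sketch on a made-up 3 x 3 tensor:

```python
import torch

x = torch.arange(1.0, 10.0).reshape(3, 3)
n = len(x)

flat = torch.arange(n).repeat(2)  # tensor([0, 1, 2, 0, 1, 2])
rows, cols = flat.split(n)        # tuple of two chunks of length n
print(rows, cols)

x[rows, cols] = 0                 # same as x[torch.arange(n).repeat(2).split(n)] = 0
print(x)
```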

Upvotes: 1

Szymon Maszke

Reputation: 24701

Yes, there are a couple of ways to do that. The simplest is to index the diagonal directly:

import torch

tensor = torch.rand(4, 4)
tensor[torch.arange(tensor.shape[0]), torch.arange(tensor.shape[1])] = 0

This assigns 0 to every index pair (i, i), i.e. (0, 0), (1, 1), ..., (n - 1, n - 1), broadcasting the scalar across all of them.
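The two index tensors have to be paired elementwise, so the snippet above relies on the matrix being square: for a 3 x 5 tensor, `torch.arange(3)` and `torch.arange(5)` cannot be broadcast together and the assignment raises an error. A sketch that indexes only the first min(rows, cols) diagonal entries instead (`zero_diag_index` is a hypothetical helper name):

```python
import torch

def zero_diag_index(t: torch.Tensor) -> torch.Tensor:
    """Zero the main diagonal of a 2-D tensor in place (hypothetical helper)."""
    n = min(t.shape)  # length of the main diagonal, also for non-square input
    idx = torch.arange(n, device=t.device)
    t[idx, idx] = 0  # pairs (0, 0), (1, 1), ..., (n - 1, n - 1)
    return t

m = torch.ones(3, 5)
zero_diag_index(m)
print(m)
```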

Another way (readability is debatable) is to multiply by the logical negation of torch.eye, which is True everywhere except on the diagonal:

tensor = torch.rand(4, 4)
tensor *= ~(torch.eye(*tensor.shape).bool())

This creates an additional n x n matrix and performs an extra elementwise multiplication over the whole tensor, so I'd stick with the first version.
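A variant of the mask idea that avoids the multiplication, swapping in `Tensor.masked_fill` (not used in the answer above): build a boolean identity mask and fill wherever it is True. Besides skipping the arithmetic, this also zeroes diagonal entries that are `inf` or `NaN`, which multiplying by 0 cannot do, because `inf * 0` and `NaN * 0` are both `NaN`. `masked_fill` returns a new tensor; `masked_fill_` is the in-place form. The input values below are made up for illustration.

```python
import torch

t = torch.tensor([[float("inf"), 1.0],
                  [2.0, float("nan")]])

mask = torch.eye(*t.shape).bool()  # True on the main diagonal

by_mult = t * ~mask                # inf * 0 and nan * 0 both give nan
by_fill = t.masked_fill(mask, 0)   # new tensor with the diagonal set to 0

print(by_mult)
print(by_fill)
```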

Upvotes: 1
