azal

Reputation: 1260

Replace diagonal elements with vector in PyTorch

I have been searching everywhere for a PyTorch equivalent of the following, but I cannot find anything.

import numpy as np

L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=0)
L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))

I guess there is no way to replace the diagonal elements in such an elegant way using PyTorch.

Upvotes: 9

Views: 13916

Answers (4)

tsveti_iko

Reputation: 7982

For simplicity, let's say you have a matrix L_1 and want to replace its diagonal with zeros. You can do this in multiple ways.

Using fill_diagonal_():

import torch

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
L_1 = L_1.fill_diagonal_(0.)
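
Note that fill_diagonal_() only accepts a single scalar fill value, so for the original question (writing a whole vector onto the diagonal) you need one of the indexing approaches below.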

Using advanced indexing:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector

Using scatter_():

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)  
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix) 

Note that all of the above solutions are in-place operations and will affect the backward pass, because the original values might be needed to compute it. If you want to keep the backward pass unaffected, i.e. "break the graph" by not recording the change, so that no gradients are computed for it in the backward pass, you can simply add .data when using advanced indexing or scatter_().
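
For instance, here is a minimal sketch (with an assumed 3x3 example) of what happens when a tensor that the backward pass needs is modified in place:

import torch

x = torch.randn(3, 3, requires_grad=True)
y = x.exp()                 # exp() saves its output for the backward pass
y[range(3), range(3)] = 0.  # in-place change bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)  # autograd reports that a needed variable was modified in place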

Using advanced indexing with .data:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector.data

Using scatter_() with .data:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)  
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix.data)

For reference, check out this discussion.

Upvotes: 1

iacob

Reputation: 24231

You can extract the diagonal elements with diagonal(), and then assign the transformed values in-place with copy_():

new_diags = L_1.diagonal().exp()
L_1.diagonal().copy_(new_diags)
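
Applied to the setup from the question, this becomes the following sketch (D is an assumed example dimension):

import torch

D = 4  # assumed example dimension
L_1 = torch.tril(torch.randn(D, D))
L_1.diagonal().copy_(L_1.diagonal().exp())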

Upvotes: 2

sgt pepper

Reputation: 604

There is an easier way to do it:

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

In effect, we generate the diagonal indices ourselves.

Usage example:

import torch

dest_matrix = torch.randint(10, (3, 3))
source_vector = torch.randint(100, 200, (len(dest_matrix), ))
print('dest_matrix:\n', dest_matrix)
print('source_vector:\n', source_vector)

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

print('result:\n', dest_matrix)

# dest_matrix:
#  tensor([[3, 2, 5],
#         [0, 3, 5],
#         [3, 1, 1]])
# source_vector:
#  tensor([182, 169, 147])
# result:
#  tensor([[182,   2,   5],
#         [  0, 169,   5],
#         [  3,   1, 147]])

In case dest_matrix is not square, you have to take min(dest_matrix.size()) instead of len(dest_matrix) in range().
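
For example, a sketch with an assumed 2x4 matrix:

dest_matrix = torch.randint(10, (2, 4))
n = min(dest_matrix.size())  # torch.Size is a tuple, so min() works
source_vector = torch.randint(100, 200, (n,))
dest_matrix[range(n), range(n)] = source_vector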

Not as elegant as numpy, but this doesn't require storing a new matrix of indices.

And yes, this preserves gradients.

Upvotes: 4

layog

Reputation: 4801

I do not think such functionality is implemented as of now. But you can implement the same functionality using a mask, as follows.

# Assuming v to be the vector and a be the tensor whose diagonal is to be replaced
mask = torch.diag(torch.ones_like(v))
out = mask*torch.diag(v) + (1. - mask)*a

So, your implementation will be something like:

L_1 = torch.tril(torch.randn((D, D)))
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
L_1 = mask*torch.diag(v) + (1. - mask)*L_1

Not as elegant as numpy, but not too bad either.
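
Since the mask is applied out of place rather than in place, gradients flow through the whole construction; here is a quick sketch (with an assumed D) to verify:

import torch

D = 3  # assumed example dimension
L_raw = torch.randn(D, D, requires_grad=True)
L_1 = torch.tril(L_raw)
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
out = mask * torch.diag(v) + (1. - mask) * L_1
out.sum().backward()
print(L_raw.grad)  # gradients reach the original parameters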

Upvotes: 3
