Qinqing Liu

Reputation: 422

How to construct a 3D tensor where every 2D sub-tensor is a diagonal matrix in PyTorch?

Suppose I have a 2D tensor of shape index_in_batch * diag_ele. How can I get a 3D tensor of shape index_in_batch * Matrix, where each Matrix is a diagonal matrix constructed from diag_ele?

torch.diag() constructs a diagonal matrix only when its input is 1D; when the input is 2D, it returns the diagonal elements instead.
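To make the problem concrete, here is a small sketch of the torch.diag behaviour described above, plus a naive batched version that simply calls torch.diag on each row and stacks the results. The row-by-row loop is only an illustrative baseline (it is not one of the answers below) and runs a Python loop over the batch, so it will likely be slower than the vectorized approaches for large batches.

```python
import torch

# On a 1D input, torch.diag builds a 2D diagonal matrix...
v = torch.tensor([1.0, 2.0, 3.0])
m = torch.diag(v)        # shape (3, 3)

# ...but on a 2D input it extracts the main diagonal instead.
back = torch.diag(m)     # shape (3,), equal to v

# Naive batched version: apply torch.diag to each row of a
# (batch, n) tensor and stack the results along a new batch dimension.
a = torch.rand(2, 3)
batched = torch.stack([torch.diag(row) for row in a])  # shape (2, 3, 3)
print(batched.shape)
```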

Upvotes: 9

Views: 5966

Answers (3)

Zhi Zhang

Reputation: 101

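Use torch.diag_embed, which places the last dimension of its input on the diagonal of a new matrix for every leading (batch) index: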

>>> a = torch.randn(2, 3)
>>> torch.diag_embed(a)
tensor([[[ 1.5410,  0.0000,  0.0000],
         [ 0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -2.1788]],

        [[ 0.5684,  0.0000,  0.0000],
         [ 0.0000, -1.0845,  0.0000],
         [ 0.0000,  0.0000, -1.3986]]])
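A slightly longer sketch of the same call: it shows that torch.diagonal with matching dims recovers the input, and that the offset argument moves the values onto a super-diagonal, which makes each matrix one row and column larger. The seed and the shapes are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs give the same values
a = torch.randn(2, 3)

# Defaults are offset=0, dim1=-2, dim2=-1: the last dim of `a` becomes
# the main diagonal of the last two output dims.
d = torch.diag_embed(a)                              # shape (2, 3, 3)

# torch.diagonal over the same two dims undoes the embedding.
recovered = torch.diagonal(d, dim1=-2, dim2=-1)      # shape (2, 3)

# offset=1 puts the values on the first super-diagonal; the matrices
# grow to 4x4 so that diagonal can hold all 3 values.
d_upper = torch.diag_embed(a, offset=1)              # shape (2, 4, 4)
print(d.shape, d_upper.shape)
```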

Upvotes: 6

Qinqing Liu

Reputation: 422

A solution that supports autograd (backward) by wrapping the identity matrix in a Variable:

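import torch
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors work too

a = torch.rand(2, 3)
print(a)

n = a.size(1)
b = Variable(torch.eye(n))                              # (n, n) identity mask
c = a.unsqueeze(2).expand(*a.size(), n)                 # repeat each row: (batch, n, n)
b_expand = b.unsqueeze(0).expand(c.size(0), *b.size())  # one mask per batch item
d = torch.mul(c.double(), b_expand.double())            # keep only the diagonals

print(d)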
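Since PyTorch 0.4, Variable has been merged into Tensor, so gradients flow through this construction without any wrapper; setting requires_grad=True on the input is enough. The sketch below reuses the expand-and-mask approach above and checks the gradient against a hypothetical random weight matrix w (introduced only for this test): because only diagonal entries depend on a, the gradient of a[i, j] should be w[j, j].

```python
import torch

# requires_grad=True replaces the old Variable wrapper.
a = torch.rand(2, 3, requires_grad=True)
n = a.size(1)

# Same construction as above: repeat each row along a new last dim,
# then zero the off-diagonal entries with an identity mask.
eye = torch.eye(n)
c = a.unsqueeze(2).expand(*a.size(), n)             # (2, 3, 3)
d = c * eye.unsqueeze(0).expand(c.size(0), n, n)    # (2, 3, 3)

# Hypothetical weights, used only to get a non-trivial scalar loss.
w = torch.rand(n, n)
loss = (d * w).sum()
loss.backward()  # populates a.grad; expected a.grad[i, j] == w[j, j]
print(a.grad)
```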

Upvotes: 0

Wasi Ahmad

Reputation: 37711

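import torch

a = torch.rand(2, 3)
print(a)
b = torch.eye(a.size(1))                          # identity mask, shape (3, 3)
c = a.unsqueeze(2).expand(*a.size(), a.size(1))   # each row repeated: (2, 3, 3)
d = c * b                                         # broadcasting zeroes the off-diagonals
print(d)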

Output

 0.5938  0.5769  0.0555
 0.9629  0.5343  0.2576
[torch.FloatTensor of size 2x3]


(0 ,.,.) = 
  0.5938  0.0000  0.0000
  0.0000  0.5769  0.0000
  0.0000  0.0000  0.0555

(1 ,.,.) = 
  0.9629  0.0000  0.0000
  0.0000  0.5343  0.0000
  0.0000  0.0000  0.2576
[torch.FloatTensor of size 2x3x3]
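With broadcasting (available in current PyTorch), the explicit expand is not strictly needed: a tensor of shape (2, 3, 1) times one of shape (3, 3) already broadcasts to (2, 3, 3). The sketch below assumes a (batch, n) input and checks the result against torch.diag_embed.

```python
import torch

a = torch.rand(2, 3)
n = a.size(1)

# a.unsqueeze(2) is (2, 3, 1) and torch.eye(n) is (3, 3); trailing dims
# align, so the product broadcasts to (2, 3, 3) without .expand().
d = a.unsqueeze(2) * torch.eye(n)

# d[i, j, k] = a[i, j] * eye[j, k]: a[i, j] when j == k, otherwise 0.
print(d)
```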

Upvotes: 10
