John

Reputation: 1947

How to build many diagonal matrices at once?

Consider this tensor:

import torch

scale = torch.tensor([[1.0824, 1.0296, 1.0065, 0.9395, 0.9424, 1.0260, 0.9805, 1.0509],
        [1.1002, 1.0358, 1.0112, 0.9466, 0.9454, 0.9942, 0.9891, 1.0485],
        [1.1060, 1.0157, 1.0216, 0.9544, 0.9378, 1.0160, 0.9671, 1.0240]])

which has shape:

scale.shape
torch.Size([3, 8])

I want a tensor of shape [3, 8, 8] that holds three diagonal matrices built from the values in scale. In other words, the first matrix should have scale[0] on its diagonal, the second scale[1], and the last scale[2].
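To make the target precise, here is a small checker that any candidate result can be tested against. It is a sketch; the helper name `is_batched_diagonal` is hypothetical and not part of PyTorch. It only uses `torch.diagonal`, `torch.eye` and boolean masking, and it requires `out[b, i, i] == scale[b, i]` with every off-diagonal entry equal to zero.

```python
import torch

def is_batched_diagonal(out: torch.Tensor, scale: torch.Tensor) -> bool:
    """Hypothetical helper: True if out[b] equals diag(scale[b]) for every b."""
    batch, n = scale.shape
    if out.shape != (batch, n, n):
        return False
    # Main diagonal of each n x n matrix, shape [batch, n]; must match scale exactly
    on_diag = torch.diagonal(out, dim1=-2, dim2=-1)
    # Boolean mask that is True everywhere except the main diagonal
    off_mask = ~torch.eye(n, dtype=torch.bool)
    # out[:, off_mask] has shape [batch, n * n - n]; all of it must be zero
    return torch.equal(on_diag, scale) and bool((out[:, off_mask] == 0).all())
```

For example, the loop below passes the check with `is_batched_diagonal(temp, scale)`.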

A naive loop does the job:

import torch
temp = torch.tensor([])
for i in range(0, 3):
    temp = torch.cat([temp, torch.diag(scale[i])])
temp = temp.view(3, 8, 8)
temp

But I'm wondering whether there is a more efficient way to do this.

Upvotes: 1

Views: 173

Answers (1)

Shai

Reputation: 114826

I think you are looking for diag_embed, which places the last dimension of its input on the diagonal of a batch of matrices:

temp = torch.diag_embed(scale)

For example:

scale = torch.arange(24).view(3,8)
torch.diag_embed(scale)
tensor([[[ 0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  1,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  2,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  3,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  4,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  5,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  6,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  7]],

        [[ 8,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  9,  0,  0,  0,  0,  0,  0],
         [ 0,  0, 10,  0,  0,  0,  0,  0],
         [ 0,  0,  0, 11,  0,  0,  0,  0],
         [ 0,  0,  0,  0, 12,  0,  0,  0],
         [ 0,  0,  0,  0,  0, 13,  0,  0],
         [ 0,  0,  0,  0,  0,  0, 14,  0],
         [ 0,  0,  0,  0,  0,  0,  0, 15]],

        [[16,  0,  0,  0,  0,  0,  0,  0],
         [ 0, 17,  0,  0,  0,  0,  0,  0],
         [ 0,  0, 18,  0,  0,  0,  0,  0],
         [ 0,  0,  0, 19,  0,  0,  0,  0],
         [ 0,  0,  0,  0, 20,  0,  0,  0],
         [ 0,  0,  0,  0,  0, 21,  0,  0],
         [ 0,  0,  0,  0,  0,  0, 22,  0],
         [ 0,  0,  0,  0,  0,  0,  0, 23]]])
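As a sanity check on the question's actual data, the sketch below compares `torch.diag_embed` with the loop approach (rewritten here as a single `torch.cat` over a list, so it does not need the empty seed tensor). It also shows the documented `offset` argument, which moves the values off the main diagonal and enlarges each matrix to `(8 + |offset|) x (8 + |offset|)`.

```python
import torch

scale = torch.tensor([[1.0824, 1.0296, 1.0065, 0.9395, 0.9424, 1.0260, 0.9805, 1.0509],
                      [1.1002, 1.0358, 1.0112, 0.9466, 0.9454, 0.9942, 0.9891, 1.0485],
                      [1.1060, 1.0157, 1.0216, 0.9544, 0.9378, 1.0160, 0.9671, 1.0240]])

# Loop version: three 8x8 diagonal blocks stacked to [24, 8], then reshaped to [3, 8, 8]
looped = torch.cat([torch.diag(scale[i]) for i in range(scale.shape[0])]).view(3, 8, 8)

# Vectorized version: each row of scale becomes the main diagonal of one 8x8 matrix
embedded = torch.diag_embed(scale)
print(torch.equal(looped, embedded))  # True

# offset=1 writes each row onto the first superdiagonal, so each matrix becomes 9x9
shifted = torch.diag_embed(scale, offset=1)
print(shifted.shape)  # torch.Size([3, 9, 9])
```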


If you insist on using a loop, a list comprehension with torch.stack avoids the repeated torch.cat calls and the final view:

temp = torch.stack([torch.diag(s_) for s_ in scale])
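For comparison, a third option (not from the answer above) is plain broadcasting against an identity matrix: `scale.unsqueeze(-1)` has shape [3, 8, 1], and multiplying it by `torch.eye(8)` gives `scale[b, i] * eye[i, j]`, which is `scale[b, i]` on the diagonal and zero elsewhere. The sketch below uses random stand-in data of the same shape and checks that all three approaches agree.

```python
import torch

torch.manual_seed(0)
scale = torch.rand(3, 8)  # stand-in data with the same shape as in the question

a = torch.diag_embed(scale)                            # single vectorized call
b = torch.stack([torch.diag(s_) for s_ in scale])      # one diag per row, one stack
c = scale.unsqueeze(-1) * torch.eye(scale.shape[-1])   # broadcast against the identity

print(torch.equal(a, b), torch.equal(a, c))  # True True
```

diag_embed is still the clearer choice: the broadcasting version materializes an identity matrix and does a full multiply, and if scale contains inf or nan, then 0 * inf = nan leaks into the off-diagonal entries, whereas diag_embed leaves them at zero.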

Upvotes: 1
