raaj

Reputation: 3291

Pytorch: Set Block-Diagonal Matrix Efficiently?

I have a tensor A of size [N x 3 x 3] and a matrix B of size [N*3 x N*3].

I want to copy the contents of A into B so that B becomes block-diagonal, and I want to do this efficiently.

B should end up looking like this:

[figure: B with the N 3x3 blocks of A placed along its diagonal]

That is, each block A[i] fills the i-th 3x3 block along the diagonal of B.

How do I do this as efficiently as possible? This is for a real-time application. I could write a CUDA kernel, but I would prefer to use a built-in PyTorch function.

Upvotes: 1

Views: 3375

Answers (2)

iacob

Reputation: 24231

Use torch.block_diag():

import torch

# Setup
A = torch.ones(3, 3, 3, dtype=int)

# Unpack blocks and apply
B = torch.block_diag(*A)
>>> B
tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1]])
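If the Python-level unpacking of `*A` turns out to be a bottleneck for large N, one fully vectorized alternative (a sketch, not part of the original answer) is to view B as an (N, 3, N, 3) tensor of 3x3 blocks and write A into the diagonal blocks with a single advanced-indexing assignment:

```python
import torch

N = 4
A = torch.arange(N * 9, dtype=torch.float32).reshape(N, 3, 3)

# View B as an (N, 3, N, 3) grid of 3x3 blocks, then assign A to the
# (i, :, i, :) diagonal blocks in one advanced-indexing operation.
B = torch.zeros(N, 3, N, 3)
idx = torch.arange(N)
B[idx, :, idx, :] = A
B = B.reshape(N * 3, N * 3)

# Sanity check against torch.block_diag
assert torch.equal(B, torch.block_diag(*A))
```

Because the assignment is a single indexing op, it stays on the GPU if A and B are CUDA tensors, with no per-block Python loop.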

Upvotes: 3

trsvchn

Reputation: 8981

Here is a simple (in-place) example; I am not sure about the performance for really big tensors:

Code:

import torch

# Create some tensors
N = 3
A = torch.ones(N, 3, 3)
A[1] *= 2
A[2] *= 3
B = torch.zeros(N*3, N*3)


def diagonalizer(A, B):
    """Copy each 3x3 block of A onto the diagonal of B (modifies B in place)."""
    N = A.shape[0]
    for t in range(N):
        # NOTE! this is an in-place operation on B
        B[3*t:3*t + 3, 3*t:3*t + 3] = A[t]


print('before:\n', B, sep='')

diagonalizer(A, B)

print('after:\n', B, sep='')

Output:

before:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])
after:
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.]])
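Since the question is about real-time use, a quick timing sketch (illustrative only; actual numbers depend on N and hardware) with Python's `timeit` can compare this loop against `torch.block_diag`:

```python
import timeit

import torch

N = 100
A = torch.ones(N, 3, 3)

def with_block_diag():
    # Unpack the N blocks and let PyTorch assemble the block-diagonal matrix
    return torch.block_diag(*A)

def with_loop():
    # Copy each block into a preallocated B, as in the answer above
    B = torch.zeros(N * 3, N * 3)
    for t in range(N):
        B[3 * t:3 * t + 3, 3 * t:3 * t + 3] = A[t]
    return B

# Sanity check: both approaches produce the same matrix
assert torch.equal(with_block_diag(), with_loop())

print("block_diag:", timeit.timeit(with_block_diag, number=100))
print("loop:      ", timeit.timeit(with_loop, number=100))
```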

Upvotes: -1

Related Questions