Reputation: 3291
I have a tensor A of size [N x 3 x 3] and a matrix B of size [N*3 x N*3].
I want to copy the contents of A into B so that the diagonal blocks of B are filled in, and I want to do this efficiently.
That is, each 3x3 slice A[i] should fill the i-th 3x3 block along B's diagonal.
How do I do this as efficiently as possible? This is for a real-time application. I could write a CUDA kernel, but I would prefer to use some built-in PyTorch function.
Upvotes: 1
Views: 3375
Reputation: 24231
Use torch.block_diag():
# Setup
import torch
A = torch.ones(3, 3, 3, dtype=int)
# Unpack blocks and apply
B = torch.block_diag(*A)
>>> B
tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1, 1]])
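If the overhead of unpacking A into N separate blocks with the `*A` splat matters for large N, one alternative (a sketch, not benchmarked) is to preallocate B and write all the diagonal blocks in a single advanced-indexing assignment through a reshaped view:

```python
import torch

N = 3
A = torch.ones(N, 3, 3, dtype=int)

# Preallocate B and view it as an [N x 3 x N x 3] grid of 3x3 blocks;
# assigning through the view writes directly into B's storage.
B = torch.zeros(N * 3, N * 3, dtype=int)
idx = torch.arange(N)
# Select block (i, i) for every i at once and fill it with A[i]
B.view(N, 3, N, 3)[idx, :, idx, :] = A
```

This avoids the Python-level iteration over blocks and lets you reuse the same preallocated B across frames.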
Upvotes: 3
Reputation: 8981
Here is a simple (in-place) example; I'm not sure about the performance for really big tensors:
Code:
import torch

# Create some tensors
N = 3
A = torch.ones(N, 3, 3)
A[1] *= 2
A[2] *= 3
B = torch.zeros(N * 3, N * 3)

def diagonalizer(A, B):
    N = A.shape[0]
    i_min = 0
    j_min = 0
    i_max = 3
    j_max = 3
    for t in range(N):
        B[i_min:i_max, j_min:j_max] = A[t]  # NOTE! this is an in-place operation
        # advance to the next diagonal block:
        i_min += 3
        j_min += 3
        i_max += 3
        j_max += 3

print('before:\n', B, sep='')
diagonalizer(A, B)
print('after:\n', B, sep='')
Output:
before:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]])
after:
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 2., 2., 2., 0., 0., 0.],
[0., 0., 0., 2., 2., 2., 0., 0., 0.],
[0., 0., 0., 2., 2., 2., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 3., 3., 3.],
[0., 0., 0., 0., 0., 0., 3., 3., 3.],
[0., 0., 0., 0., 0., 0., 3., 3., 3.]])
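For real-time use, the Python loop over blocks can be avoided entirely: the diagonal blocks of B can be exposed as one writable strided view with `as_strided`, so the whole copy is a single in-place op. This is a sketch; `as_strided` views are unchecked, so the shape/stride arithmetic must be correct:

```python
import torch

N = 3
A = torch.ones(N, 3, 3)
A[1] *= 2
A[2] *= 3
B = torch.zeros(N * 3, N * 3)

# Element B[3t+i, 3t+j] lives at flat offset t*(9N+3) + i*3N + j,
# so a strided [N x 3 x 3] view covers exactly the diagonal blocks.
blocks = B.as_strided((N, 3, 3), (9 * N + 3, 3 * N, 1))
blocks.copy_(A)  # one in-place copy, no Python loop
```

Because the view aliases B's storage, `copy_` fills the diagonal blocks in place, which fits the question's preallocated-B, real-time setting.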
Upvotes: -1