If you are given a collection of m matrices, each of size n x n, is there a predefined function in PyTorch that embeds all of them along the diagonal of a larger matrix of dimension nm x nm?
To be concrete: given two 2 x 2 identity matrices, their diagonal embedding into a 4 x 4 matrix would be the 4 x 4 identity matrix.
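Spelled out with an explicit loop, the embedding described above can be sketched as follows. This assumes the m matrices are already stacked into one tensor of shape (m, n, n); `embed_blocks` is a hypothetical helper for illustration, not a PyTorch function.

```python
import torch

def embed_blocks(blocks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: place m stacked (n, n) matrices on the diagonal of an (n*m, n*m) matrix."""
    m, n, _ = blocks.shape
    # Zero matrix with the same dtype and device as the input
    out = blocks.new_zeros((n * m, n * m))
    for k in range(m):
        # Block k occupies rows and columns k*n .. (k+1)*n - 1
        out[k * n:(k + 1) * n, k * n:(k + 1) * n] = blocks[k]
    return out

# The concrete example: two 2 x 2 identity matrices
two_eyes = torch.stack([torch.eye(2), torch.eye(2)])  # shape (2, 2, 2)
result = embed_blocks(two_eyes)
print(torch.equal(result, torch.eye(4)))  # True
```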
Something like torch.block_diag, but that function expects each matrix to be passed as a separate argument.
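For reference, when the matrices are written out as separate positional arguments, `torch.block_diag` already reproduces the identity example from the question:

```python
import torch

# Each matrix is passed as its own positional argument
out = torch.block_diag(torch.eye(2), torch.eye(2))
print(out.shape)  # torch.Size([4, 4])
print(torch.equal(out, torch.eye(4)))  # True
```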
Your question doesn't specify how you obtain your m matrices. Suppose you have either
import torch
# a single 3-D tensor holding m = 4 matrices of size 2 x 2 (shape m x n x n)
a = torch.ones(4, 2, 2)
or
# a list of m = 4 tensors, each of size 2 x 2
a = [torch.ones(2, 2) for _ in range(4)]
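then you can unpack it into torch.block_diag with the * operator (iterating over a 3-D tensor yields its m matrices along the first dimension):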
>>> torch.block_diag(*a)
tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 1., 0., 0., 0., 0.],
[0., 0., 1., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1.]])
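If the matrices are already stacked as an (m, n, n) tensor and you would rather avoid both the unpacking and a Python loop, one possible vectorized sketch uses advanced indexing. The name `block_diag_stacked` is hypothetical, and this assumes all m matrices share the same size n x n (unlike `torch.block_diag`, which also accepts blocks of different shapes).

```python
import torch

def block_diag_stacked(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: block-diagonal embedding of an (m, n, n) tensor without a Python loop."""
    m, n, _ = a.shape
    # out4[i, j] holds the (n, n) block at block-row i, block-column j
    out4 = a.new_zeros((m, m, n, n))
    idx = torch.arange(m, device=a.device)
    # Fill only the diagonal blocks: out4[k, k] = a[k] for every k
    out4[idx, idx] = a
    # Reorder to (block-row, row, block-col, col) and flatten to (m*n, m*n)
    return out4.permute(0, 2, 1, 3).reshape(m * n, m * n)

a = torch.randn(4, 2, 2)
print(torch.equal(block_diag_stacked(a), torch.block_diag(*a)))  # True
```

The zero tensor has m^2 * n^2 entries either way, so this mainly saves the per-block Python overhead when m is large.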