ADA


Diagonal embedding of a batch of matrices in PyTorch?

Given a collection of m matrices, each of size n x n, is there a predefined function in PyTorch that embeds all of them along the diagonal of a larger nm x nm matrix?

To be concrete: given two 2 x 2 identity matrices, their diagonal embedding into a 4 x 4 matrix would be the 4 x 4 identity matrix.

Something like torch.block_diag would do this, but it expects each matrix to be passed as a separate argument rather than as a single batch.
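For the concrete example above, a quick check (assuming a recent PyTorch version) that passing the two identity matrices as separate arguments to torch.block_diag gives the 4 x 4 identity:

```python
import torch

# Two 2 x 2 identity matrices, passed to block_diag as separate arguments
blocks = [torch.eye(2), torch.eye(2)]
result = torch.block_diag(blocks[0], blocks[1])

# The diagonal embedding of two 2 x 2 identities is the 4 x 4 identity
print(torch.equal(result, torch.eye(4)))  # True
```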


Answers (1)

Quang Hoang

Your question doesn't say how your m matrices are stored. Suppose you have either

import torch

# a single tensor of shape (m, n, n) = (4, 2, 2), batch dimension first
a = torch.ones(4,2,2)

or

# a list of tensors
a = [torch.ones(2,2) for _ in range(4)]

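Either way, you can unpack `a` into block_diag with `*`. Iterating over a tensor yields its slices along the first dimension, so both forms give the same result: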

>>> torch.block_diag(*a)
tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])
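If you would rather avoid unpacking (for example, when the blocks already live in one (m, n, n) tensor), the same embedding can be built with a single einsum. This is a minimal sketch, not a PyTorch built-in: `block_diag_batched` is a hypothetical helper name, and it assumes all m blocks are square and the same size. It multiplies each block by an entry of an m x m identity, so block i only lands in block position (i, i), then reshapes to (m*n, m*n):

```python
import torch

def block_diag_batched(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: place the m blocks of `a` (shape (m, n, n))
    on the diagonal of an (m*n, m*n) matrix."""
    m, n, n2 = a.shape
    if n != n2:
        raise ValueError("expected a tensor of shape (m, n, n)")
    eye = torch.eye(m, dtype=a.dtype, device=a.device)
    # out[i, r, j, c] = eye[i, j] * a[i, r, c], so it is nonzero only when i == j
    out = torch.einsum("ij,irc->irjc", eye, a)
    # Merge (block row, row) and (block col, col) into the final 2-D indices
    return out.reshape(m * n, m * n)

a = torch.arange(16, dtype=torch.float32).reshape(4, 2, 2)
print(torch.equal(block_diag_batched(a), torch.block_diag(*a)))  # True
```

Because it uses only tensor operations, this sketch stays differentiable and keeps the input's device and dtype; it allocates the full (m*n, m*n) output just as block_diag does.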
