SergeantIdiot

What does this PyTorch code do, and how can I express it in TensorFlow?

I found some code that would solve my problem:

(self.conv_diag(input_tensor.diagonal(dim1=2, dim2=3))).diag_embed(dim1=2, dim2=3)

Here, self.conv_diag is a layer I defined earlier.

As far as I understand, it extracts the diagonal over dimensions 2 and 3 (0-indexed, i.e. the third and fourth axes), passes that diagonal through my layer, and then builds a new zero-filled tensor whose dimensions 2 and 3 carry the layer's output on their diagonal.
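To make the shapes concrete, here is a NumPy sketch of the same three steps (NumPy, not PyTorch itself). `np.diagonal(..., axis1=2, axis2=3)` behaves like `Tensor.diagonal(dim1=2, dim2=3)`: it removes the two axes and appends the diagonal as the last axis. The zero-filled embedding at the end mimics `diag_embed(dim1=2, dim2=3)`. The input shape and the `2.0 * diag` stand-in for `self.conv_diag` are hypothetical.

```python
import numpy as np

# Hypothetical input shaped (batch, channels, height, width) with height == width.
x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)

# Like input_tensor.diagonal(dim1=2, dim2=3):
# the diagonal over axes 2 and 3 becomes the last axis -> shape (2, 3, 4).
diag = np.diagonal(x, axis1=2, axis2=3)

# Hypothetical stand-in for self.conv_diag: any op returning shape (..., L).
processed = 2.0 * diag

# Like .diag_embed(dim1=2, dim2=3) on a 3-D input:
# place the last axis on the diagonal of a new (L, L) block, zeros elsewhere.
L = processed.shape[-1]
embedded = np.zeros(processed.shape + (L,), dtype=processed.dtype)
idx = np.arange(L)
embedded[..., idx, idx] = processed

print(diag.shape)      # (2, 3, 4)
print(embedded.shape)  # (2, 3, 4, 4)
```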

What I have found to extract the diagonal is

tf.math.reduce_diag(input_tensor)

but I cannot choose the axes, and I have not yet found an equivalent of torch.diag_embed().
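On the axis-selection problem: `tf.linalg.diag_part` always reads the diagonal of the last two dimensions and returns it as the new last dimension. A minimal sketch, assuming a statically known rank, is to transpose the two chosen axes to the end first. The helper name `diag_part_axes` and the NHWC example input are hypothetical.

```python
import tensorflow as tf

def diag_part_axes(x, axis1, axis2):
    """Hypothetical helper mimicking torch.diagonal(dim1=axis1, dim2=axis2)."""
    rank = len(x.shape)
    axis1 %= rank
    axis2 %= rank
    # Move the two chosen axes to the end, keeping the others in their order.
    others = [a for a in range(rank) if a not in (axis1, axis2)]
    x = tf.transpose(x, perm=others + [axis1, axis2])
    # diag_part takes the diagonal of the last two axes and returns it
    # as the new last axis, which matches torch.diagonal's output layout.
    return tf.linalg.diag_part(x)

x = tf.reshape(tf.range(2 * 4 * 4 * 3, dtype=tf.float32), (2, 4, 4, 3))  # NHWC
d = diag_part_axes(x, 1, 2)
print(d.shape)  # (2, 3, 4)
```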

How can I express it in TensorFlow?

Answers (1)

Mr. For Example

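This is probably what you are looking for as a counterpart to torch.diag_embed(): tf.linalg.diag builds matrices whose diagonal is taken from the last dimension of `diagonal`.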

tf.linalg.diag(
    diagonal, name='diag', k=0, num_rows=-1, num_cols=-1, padding_value=0,
    align='RIGHT_LEFT'
)
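Putting the pieces together, a sketch of the whole PyTorch line in TensorFlow might look like the following, assuming a PyTorch-style input of shape (batch, channels, n, n) so that dims 2 and 3 are already the last two axes. The `Conv1D` configuration of `conv_diag` is a hypothetical stand-in for the asker's layer. Note that Keras `Conv1D` expects channels-last input (batch, steps, channels), hence the transposes around it.

```python
import tensorflow as tf

# Hypothetical stand-in for self.conv_diag: keeps the sequence length ("same" padding).
conv_diag = tf.keras.layers.Conv1D(filters=8, kernel_size=3, padding="same")

def diag_conv_embed(x):
    """Sketch of conv_diag(x.diagonal(dim1=2, dim2=3)).diag_embed(dim1=2, dim2=3)
    for x shaped (batch, channels, n, n)."""
    d = tf.linalg.diag_part(x)      # (batch, channels, n): diagonal of the last two axes
    d = tf.transpose(d, [0, 2, 1])  # (batch, n, channels) for channels-last Conv1D
    d = conv_diag(d)                # (batch, n, filters)
    d = tf.transpose(d, [0, 2, 1])  # back to (batch, filters, n)
    return tf.linalg.diag(d)        # (batch, filters, n, n), zeros off the diagonal

x = tf.random.normal((2, 3, 5, 5))
y = diag_conv_embed(x)
print(y.shape)  # (2, 8, 5, 5)
```

If the diagonal axes are not the last two in your layout, transpose them to the end before `tf.linalg.diag_part` and undo that permutation after `tf.linalg.diag`.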
