I found some PyTorch code that would solve my problem:
(self.conv_diag(input_tensor.diagonal(dim1=2, dim2=3))).diag_embed(dim1=2, dim2=3)
where self.conv_diag is a layer I defined earlier.
As far as I understand it, this extracts the diagonal along dimensions 2 and 3, passes it through my layer, and then builds a new zero-filled tensor whose diagonal along those two dimensions holds the values computed by the layer.
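For a channels-first input of shape (batch, channels, N, N), dimensions 2 and 3 are the two innermost ones, so the behavior described above can be checked in TensorFlow with tf.linalg.diag_part and tf.linalg.diag. This is a minimal sketch assuming TensorFlow 2.x in eager mode; the tensor x is made-up example data and the layer is left out, so the round trip should simply return x with every off-diagonal entry set to zero:

```python
import tensorflow as tf

# Made-up channels-first input: (batch=2, channels=3, N=4, N=4)
x = tf.reshape(tf.range(2 * 3 * 4 * 4, dtype=tf.float32), (2, 3, 4, 4))

# Counterpart of x.diagonal(dim1=2, dim2=3): diagonal of the two innermost dims, shape (2, 3, 4)
diag = tf.linalg.diag_part(x)

# Counterpart of diag.diag_embed(dim1=2, dim2=3): zeros everywhere except the diagonal, shape (2, 3, 4, 4)
rebuilt = tf.linalg.diag(diag)
```

In the real model, the layer would be applied to diag between the two calls.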
The closest thing I have found for extracting the diagonal is
tf.math.reduce_diag(input_tensor)
but it does not let me choose the axes, and I have not yet found an equivalent of torch.diag_embed().
How can I express this in TensorFlow?
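Both tf.linalg.diag_part and tf.linalg.diag only operate on the innermost dimensions, but other axes can be handled by transposing first. The helpers diagonal and diag_embed below are hypothetical names, not TensorFlow APIs. They are a sketch assuming offset 0 and a tensor of known rank, and they mimic PyTorch's conventions: the extracted diagonal becomes the last dimension, and diag_embed places the two new square dimensions at dim1 and dim2 of the output.

```python
import tensorflow as tf


def diagonal(x, dim1, dim2):
    """Hypothetical helper mimicking torch.Tensor.diagonal(dim1=..., dim2=...) with offset 0."""
    rank = len(x.shape)
    dim1, dim2 = dim1 % rank, dim2 % rank
    # Keep the remaining axes in their original order and move dim1, dim2 to the end.
    others = [d for d in range(rank) if d not in (dim1, dim2)]
    x = tf.transpose(x, perm=others + [dim1, dim2])
    # diag_part reads the two innermost axes; the diagonal becomes the last axis.
    return tf.linalg.diag_part(x)


def diag_embed(x, dim1, dim2):
    """Hypothetical helper mimicking torch.diag_embed(x, dim1=..., dim2=...) with offset 0."""
    out_rank = len(x.shape) + 1
    dim1, dim2 = dim1 % out_rank, dim2 % out_rank
    # tf.linalg.diag puts the last axis of x on the diagonal of two new innermost axes.
    y = tf.linalg.diag(x)
    # Send those two innermost axes to positions dim1 and dim2,
    # filling the remaining positions with the leading axes in order.
    perm = [None] * out_rank
    perm[dim1] = out_rank - 2
    perm[dim2] = out_rank - 1
    leading = iter(range(out_rank - 2))
    perm = [p if p is not None else next(leading) for p in perm]
    return tf.transpose(y, perm=perm)


# Example: take the diagonal over axes 1 and 3 instead of the innermost pair.
x = tf.reshape(tf.range(2 * 4 * 3 * 4, dtype=tf.float32), (2, 4, 3, 4))
d = diagonal(x, dim1=1, dim2=3)    # shape (2, 3, 4)
e = diag_embed(d, dim1=1, dim2=3)  # shape (2, 4, 3, 4), zeros off the diagonal
```

Transposing costs a copy, so if the diagonal axes can be kept innermost in the model's layout, calling tf.linalg.diag_part and tf.linalg.diag directly is simpler.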
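For the extraction step, tf.linalg.diag_part plays the role of torch.diagonal over the two innermost dimensions. For the reverse step, this may be what you are looking for (it places the last dimension of diagonal on the diagonal of the two innermost output dimensions and fills the rest with padding_value, 0 by default):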
tf.linalg.diag(
diagonal, name='diag', k=0, num_rows=-1, num_cols=-1, padding_value=0,
align='RIGHT_LEFT'
)
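Putting the pieces together, one possible translation of the whole PyTorch expression is sketched below. It assumes TensorFlow 2.x, a channels-first input of shape (batch, channels, N, N), and a hypothetical Keras Conv1D (8 filters, kernel size 3) standing in for self.conv_diag, whose real definition is not shown in the question. Keras Conv1D expects channels-last input (batch, steps, channels), so the diagonal is transposed before and after the convolution.

```python
import tensorflow as tf

# Hypothetical stand-in for self.conv_diag; the real layer is not shown in the question.
conv_diag = tf.keras.layers.Conv1D(filters=8, kernel_size=3, padding="same")


def conv_on_diagonal(input_tensor):
    """Sketch of (conv_diag(x.diagonal(dim1=2, dim2=3))).diag_embed(dim1=2, dim2=3)
    for a channels-first input of shape (batch, channels, N, N)."""
    diag = tf.linalg.diag_part(input_tensor)   # (batch, channels, N)
    diag = tf.transpose(diag, perm=[0, 2, 1])  # (batch, N, channels) for Conv1D
    out = conv_diag(diag)                      # (batch, N, filters)
    out = tf.transpose(out, perm=[0, 2, 1])    # (batch, filters, N)
    return tf.linalg.diag(out)                 # (batch, filters, N, N), zeros off the diagonal


x = tf.random.normal((2, 3, 5, 5))
y = conv_on_diagonal(x)
```

If the model stores data channels-last, e.g. (batch, N, N, channels), the two diagonal dimensions are no longer innermost and would need to be transposed to the end before calling tf.linalg.diag_part.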