Reputation: 131
Is there a way to extract the diagonal of a square matrix in TensorFlow? That is, for a matrix like this:
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]
]
I want to fetch the elements: [0, 4, 8]
In NumPy, this is straightforward with np.diag.
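For reference, here is a minimal NumPy sketch using the matrix from the question. Note that np.diag is overloaded: given a 2-D array it returns the diagonal, but given a 1-D array it builds a diagonal matrix instead. That second behavior is the one the TensorFlow diag function has.

```python
import numpy as np

m = np.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# With a 2-D argument, np.diag extracts the main diagonal.
d = np.diag(m)

# With a 1-D argument, np.diag builds a 3 x 3 matrix with d on its diagonal.
rebuilt = np.diag(d)

print(d)  # [0 4 8]
```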
In TensorFlow, there is a diag function, but it only builds a new matrix with the argument's elements on its diagonal, which is not what I want.
I can imagine doing this with striding, but I don't see striding for tensors in TensorFlow.
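The striding idea does work on a flattened tensor. A minimal sketch, assuming TensorFlow 2.x with eager execution: in row-major order, element (i, i) of an n x n matrix sits at flat index i * n + i = i * (n + 1), so the diagonal is every (n + 1)-th element of the flattened matrix. Python slice syntax on a tensor is translated to tf.strided_slice, which can also be called directly.

```python
import tensorflow as tf

x = tf.reshape(tf.range(9), [3, 3])  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
n = x.shape[0]  # static size; a Python int in TF 2.x

# Flatten row-major, then take every (n + 1)-th element.
flat = tf.reshape(x, [-1])
diag = flat[:: n + 1]

# The same thing written as an explicit strided slice.
diag_strided = tf.strided_slice(flat, begin=[0], end=[n * n], strides=[n + 1])

print(diag.numpy())  # [0 4 8]
```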
Upvotes: 11
Views: 16994
Reputation: 222591
It is now possible to extract the diagonal elements with tf.diag_part. Here is the example from its documentation:
# 'input' is [[1, 0, 0, 0],
#             [0, 2, 0, 0],
#             [0, 0, 3, 0],
#             [0, 0, 0, 4]]
tf.diag_part(input)  # ==> [1, 2, 3, 4]
Old answer (from when diag_part was not available; still relevant if you want to do something that is not available now):
After looking through the math operations and tensor transformations, it does not look like such an operation exists. Even if you could extract this data with matrix multiplications, it would not be efficient (getting the diagonal is O(n), while a dense n × n matrix multiplication is O(n^3)).
You have three approaches, from easiest to hardest:
tf.shape
Upvotes: 5
Reputation: 2187
With TensorFlow 0.8 it's possible to extract the diagonal elements with tf.diag_part() (see documentation).
UPDATE
For TensorFlow >= r1.12 it's tf.linalg.tensor_diag_part (see documentation).
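A minimal sketch of the newer names, assuming TensorFlow 2.x with eager execution. For a rank-2 input, tf.linalg.tensor_diag_part returns the main diagonal. The related tf.linalg.diag_part works on the last two dimensions, so it also handles a batch of matrices in one call.

```python
import tensorflow as tf

x = tf.constant([[1, 0, 0, 0],
                 [0, 2, 0, 0],
                 [0, 0, 3, 0],
                 [0, 0, 0, 4]])

# Rank-2 input: returns the main diagonal as a rank-1 tensor.
d1 = tf.linalg.tensor_diag_part(x)  # [1, 2, 3, 4]

# diag_part takes the diagonal of the innermost two dimensions.
d2 = tf.linalg.diag_part(x)  # [1, 2, 3, 4]

# A batch of two 4 x 4 matrices, shape [2, 4, 4] -> diagonals of shape [2, 4].
batch = tf.stack([x, 2 * x])
d3 = tf.linalg.diag_part(batch)  # [[1, 2, 3, 4], [2, 4, 6, 8]]
```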
Upvotes: 12
Reputation: 295
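Use tf.diag_part() (TensorFlow 1.x graph mode):
import tensorflow as tf

x = tf.ones(shape=[3, 3])
x_diag = tf.diag_part(x)  # 1-D tensor holding the 3 diagonal entries

with tf.Session() as sess:
    print(sess.run(x_diag))  # [1. 1. 1.]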
Upvotes: 3
Reputation: 827
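Depending on the context, a mask can be a nice way to 'cancel' the off-diagonal elements of the matrix, especially if you plan on reducing it anyway:
import tensorflow as tf
# n: matrix size, y: the n x n matrix (both defined elsewhere)
mask = tf.linalg.diag(tf.ones([n]))  # n x n identity (older code uses tf.diag)
y = tf.multiply(mask, y)  # zero the off-diagonal elements (tf.mul before TF 1.0)
cost = -tf.reduce_sum(y)  # minus the sum of the diagonal, i.e. minus the trace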
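A self-contained sketch of the mask idea, assuming TensorFlow 2.x with eager execution and a concrete 3 x 3 matrix chosen for illustration. tf.eye(n) gives the same identity mask as tf.linalg.diag(tf.ones([n])). Reducing along axis 1 instead of over everything recovers the diagonal itself, because each row keeps only its (i, i) entry. Keep in mind this touches all n^2 elements, so it is O(n^2) rather than O(n).

```python
import tensorflow as tf

y = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
                 [7.0, 8.0, 9.0]])
n = 3

mask = tf.eye(n)    # n x n identity, dtype float32
masked = mask * y   # off-diagonal entries become 0

cost = -tf.reduce_sum(masked)          # -(1 + 5 + 9) = -15, minus the trace
diag = tf.reduce_sum(masked, axis=1)   # row i keeps only y[i, i]
```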
Upvotes: 0
Reputation: 966
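Use the gather operation.
import tensorflow as tf
x = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x_flat = tf.reshape(x, [-1])  # flatten the matrix row-major: [1, 2, ..., 9]
x_diag = tf.gather(x_flat, [0, 4, 8])  # flat index of (i, i) is i * 3 + i -> [1, 5, 9]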
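The hard-coded indices only work for a 3 x 3 matrix. A minimal sketch that computes them, assuming TensorFlow 2.x with eager execution: it reads the size with tf.shape so it also works when the size is only known at run time. The second variant uses tf.gather_nd with explicit (i, i) index pairs, which avoids flattening altogether.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
n = tf.shape(x)[0]  # dynamic size as a scalar int32 tensor

# Variant 1: gather from the flattened matrix at flat indices i * (n + 1).
flat = tf.reshape(x, [-1])
diag_flat = tf.gather(flat, tf.range(n) * (n + 1))

# Variant 2: gather_nd with index pairs [[0, 0], [1, 1], [2, 2]].
idx = tf.stack([tf.range(n), tf.range(n)], axis=1)
diag_nd = tf.gather_nd(x, idx)
```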
Upvotes: 0
Reputation: 2001
This is probably a workaround, but it works (TensorFlow 1.x):
>>> import tensorflow
>>> sess = tensorflow.InteractiveSession()
>>> x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
>>> x.initializer.run()
>>> z = tensorflow.stack([x[i, i] for i in range(3)])  # tf.pack before TF 1.0
>>> z.eval()
array([1, 5, 9], dtype=int32)
Upvotes: 2