Reputation: 3
I have a batch of matrices with shape (?, 600, 600). How can I retrieve the row and column indices of the maximum value in each matrix in the batch? The two results should each have shape (?): one holds the row index of the maximum for each example in the batch, and the other holds the column index.
Thank you!
Upvotes: 0
Views: 1119
Reputation: 5206
You can flatten each matrix with a reshape and then take the argmax. Something like:
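import tensorflow as tf
# Flatten each (600, 600) matrix into a vector of 360000 values.
x = tf.reshape(matrix, [tf.shape(matrix)[0], -1])
indices = tf.argmax(x, axis=1)  # flat indices in [0, 600*600 - 1]
# Row-major layout: flat index = row * 600 + col.
row_indices = indices // 600
col_indices = indices % 600
final_indices = tf.stack([row_indices, col_indices], axis=1)  # shape (?, 2)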
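To check the index arithmetic without TensorFlow, here is a minimal pure-Python sketch of the same idea on a small batch. The helper name `argmax_2d` and the nested-list input are assumptions made only for this illustration; it flattens each matrix in row-major order, finds the position of the first maximum, and uses `divmod` to recover the row and column.

```python
def argmax_2d(batch):
    """Return (rows, cols) of the maximum entry of each matrix in batch.

    batch is a list of equally sized 2-D lists. The name and signature are
    hypothetical, chosen only for this illustration.
    """
    rows, cols = [], []
    for matrix in batch:
        n_cols = len(matrix[0])
        # Flatten in row-major order, as a reshape to (batch, -1) would.
        flat = [value for row in matrix for value in row]
        # Position of the first maximum in the flattened list.
        flat_index = max(range(len(flat)), key=flat.__getitem__)
        # flat_index = row * n_cols + col, so divmod recovers both parts.
        row, col = divmod(flat_index, n_cols)
        rows.append(row)
        cols.append(col)
    return rows, cols


batch = [
    [[1, 9, 3],
     [4, 5, 6]],
    [[0, 2, 1],
     [7, 8, 3]],
]
rows, cols = argmax_2d(batch)
print(rows, cols)  # [0, 1] [1, 1]
```

The quotient is the row and the remainder is the column because the flattening walks each row from left to right before moving to the next one.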
Upvotes: 1