Armin Amini

Reputation: 23

Select index of a 2D Tensor with exact values

I'm sorry for asking such a trivial question, but I'm new to TensorFlow. I have two tensors:

y_true = [[1,0], [0,1], [1,0], [1,0], [0,1], [0,1], [1,0], [0,1], [1,0], [0,1]]
y_pred = [[0.6,0.4], [0.3,0.7], [0.8,0.2], [0.8,0.2], [0.3,0.7], [0.1,0.9], [0.9,0.1], [0.4,0.6], [0.6,0.4], [0.2,0.8]]

I want to filter y_pred according to whether the corresponding row of y_true is [1,0] or [0,1].

I came up with the following approach, but I don't think it is very efficient. For instance, when filtering on [1,0]:

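import tensorflow as tf
# Compare each row with [1, 0] element-wise, then require both columns to match
ind_zero   = tf.math.equal(y_true, [1, 0])
index_zero = tf.math.logical_and(ind_zero[:, 0], ind_zero[:, 1])
zeros      = tf.gather_nd(y_pred, tf.where(index_zero))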
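A self-contained version of the approach above might look like the sketch below. It assumes TensorFlow 2.x eager execution and that both lists are wrapped in tf.constant. tf.reduce_all over axis 1 replaces the manual logical_and on each column, and tf.boolean_mask replaces the where/gather_nd pair. Both are standard TensorFlow ops, but this is a tidier formulation, not a benchmarked speedup.

```python
import tensorflow as tf

y_true = tf.constant([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1],
                      [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]])
y_pred = tf.constant([[0.6, 0.4], [0.3, 0.7], [0.8, 0.2], [0.8, 0.2], [0.3, 0.7],
                      [0.1, 0.9], [0.9, 0.1], [0.4, 0.6], [0.6, 0.4], [0.2, 0.8]])

# True for each row of y_true that equals [1, 0] in every column (shape: (10,))
row_is_zero = tf.reduce_all(tf.equal(y_true, [1, 0]), axis=1)

# Keep only the rows of y_pred whose label row matched
zeros = tf.boolean_mask(y_pred, row_is_zero)
print(zeros.numpy())
```

The mask selects rows 0, 2, 3, 6 and 8, so zeros has shape (5, 2). Because reduce_all works along the whole row, the same line also works for labels with more than two columns.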

Is there a more efficient way to do this? Thanks in advance.

Upvotes: 0

Views: 69

Answers (1)

delirium78

Reputation: 614

To select the rows of y_pred whose y_true label is [1,0], you could use argmin:

zeros = tf.gather_nd(y_pred, tf.where(tf.argmin(y_true, axis=1)))
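To see why this works, the sketch below prints the intermediate values for the tensors in the question. It assumes TensorFlow 2.x eager execution and that both lists are wrapped in tf.constant. For [1,0], argmin is 1. For [0,1], it is 0. tf.where with a single argument returns the coordinates of the non-zero entries, which are exactly the [1,0] rows.

```python
import tensorflow as tf

y_true = tf.constant([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1],
                      [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]])
y_pred = tf.constant([[0.6, 0.4], [0.3, 0.7], [0.8, 0.2], [0.8, 0.2], [0.3, 0.7],
                      [0.1, 0.9], [0.9, 0.1], [0.4, 0.6], [0.6, 0.4], [0.2, 0.8]])

# Column index of the smallest value in each row: [1, 0] -> 1, [0, 1] -> 0
idx = tf.argmin(y_true, axis=1)
print(idx.numpy())           # [1 0 1 1 0 0 1 0 1 0]

# Coordinates of the non-zero entries of idx, shape (num_matches, 1)
rows = tf.where(idx)
print(rows.numpy().ravel())  # [0 2 3 6 8]

# Each coordinate picks one whole row of y_pred
zeros = tf.gather_nd(y_pred, rows)
print(zeros.numpy())
```

This trick relies on every label being a two-class one-hot row. An all-zero row such as [0,0] is a tie, and TensorFlow does not guarantee which index argmin returns in that case.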

For [0,1], use argmax instead of argmin:

ones = tf.gather_nd(y_pred, tf.where(tf.argmax(y_true, axis=1)))
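The argmin/argmax pair only covers two classes. One way to generalize it is to compare the argmax label with each class index and pass the result to tf.boolean_mask. This is a minimal sketch assuming TensorFlow 2.x and one-hot labels with a static number of columns. The helper name split_by_class is hypothetical.

```python
import tensorflow as tf

def split_by_class(y_true, y_pred):
    """Hypothetical helper: return a list whose k-th entry holds the rows of
    y_pred whose one-hot label in y_true is class k."""
    labels = tf.argmax(y_true, axis=1)  # class index per row (int64)
    num_classes = y_true.shape[1]       # static column count, e.g. 2
    return [tf.boolean_mask(y_pred, tf.equal(labels, tf.constant(k, dtype=labels.dtype)))
            for k in range(num_classes)]

y_true = tf.constant([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1],
                      [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]])
y_pred = tf.constant([[0.6, 0.4], [0.3, 0.7], [0.8, 0.2], [0.8, 0.2], [0.3, 0.7],
                      [0.1, 0.9], [0.9, 0.1], [0.4, 0.6], [0.6, 0.4], [0.2, 0.8]])

# Class 0 is [1, 0], class 1 is [0, 1]
zeros, ones = split_by_class(y_true, y_pred)
print(zeros.numpy())
print(ones.numpy())
```

Here ones holds rows 1, 4, 5, 7 and 9 of y_pred. The loop runs once per class, so it suits a small, fixed number of classes.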

Upvotes: 1
