Reputation: 113
I'm currently modifying the loss function for one of my object detection neural networks. I basically have two arrays:

- y_true: labels of predictions, a tf tensor of shape (x, y, z)
- y_pred: predicted values, a tf tensor of shape (x, y, z)

The x dimension is the batch size, the y dimension is the number of predicted objects in the image, and the z dimension contains a one-hot encoding of the classes as well as the bounding boxes of those classes.
Now to the real question: what I want to do is basically to multiply the first 5 z-values in y_pred with the first 5 z-values in y_true. All other values should remain unaffected. In numpy this is extremely straightforward:
y_pred[:,:,:5] *= y_true[:,:,:5]
I'm finding this very hard to do in TensorFlow, as I can't assign values to the original tensor while keeping all other values the same. How do I go about doing this in TensorFlow?
Upvotes: 1
Views: 1357
Reputation: 15119
Since v1.1, TensorFlow covers such NumPy-like indexing; see `Tensor.__getitem__`.
import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])

    print((y_pred[:,:,:5] * y_true[:,:,:5]).eval())
    # [[[   1    4    9   16   25]
    #   [ 100  400  900 1600 2500]]]
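If you are on TensorFlow 2.x, where eager execution is the default, the same slicing works without a session (a minimal sketch, assuming TF 2.x):

```python
import tensorflow as tf  # assumes TF 2.x (eager execution by default)

y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])

# Slice and multiply directly; .numpy() materializes the result.
product = y_pred[:, :, :5] * y_true[:, :, :5]
# product → [[[1, 4, 9, 16, 25]]]
```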
EDIT after comment:
Now, the problem is the `*=` part, i.e. item assignment. This isn't a straightforward operation in TensorFlow. In your case, however, it can easily be solved using tf.concat or tf.where (tf.dynamic_partition + tf.dynamic_stitch could be used for more complex cases).
Find below a quick implementation of the first two solutions.
Solution using `Tensor.__getitem__` and tf.concat:
import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # Multiply the first 5 z-values, keep the rest unchanged,
    # then concatenate the two parts back together:
    y_pred_edit = y_pred[:,:,:5] * y_true[:,:,:5]
    y_pred_rest = y_pred[:,:,5:]

    y_pred = tf.concat((y_pred_edit, y_pred_rest), axis=2)
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
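A variation on the concat idea is to build a single multiplier tensor that equals y_true on the first 5 channels and 1 everywhere else, so one element-wise multiply does the whole update (a sketch in TF 2.x eager style; the same graph works inside a 1.x session):

```python
import tensorflow as tf  # assumes TF 2.x (eager execution)

y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]], dtype=tf.float32)
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]], dtype=tf.float32)

# Multiplier: y_true values on the first 5 channels, ones on the remaining ones.
multiplier = tf.concat([y_true[:, :, :5], tf.ones_like(y_pred[:, :, 5:])], axis=2)
result = y_pred * multiplier
# result → [[[1., 4., 9., 16., 25., 6., 7., 8., 9., 10.]]]
```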
Solution using tf.where:
import tensorflow as tf

def select_n_first_indices(n, batch_size):
    """ Return a boolean tensor of length batch_size with the first n
    elements True and the rest False, i.e. [*[True] * n, *[False] * (batch_size - n)].
    """
    n_ones = tf.ones((n,))
    rest_zeros = tf.zeros((batch_size - n,))
    indices = tf.cast(tf.concat((n_ones, rest_zeros), axis=0), dtype=tf.bool)
    return indices

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # tf.where can't apply the condition to an arbitrary axis (see doc).
    # In your case (condition on the last axis), we need either to manually
    # broadcast the condition tensor, or to transpose the target tensors.
    # Here is a quick demonstration of the 2nd option:
    y_pred_transposed = tf.transpose(y_pred, [2, 0, 1])
    y_true_transposed = tf.transpose(y_true, [2, 0, 1])
    edit_indices = select_n_first_indices(5, tf.shape(y_pred_transposed)[0])

    y_pred_transposed = tf.where(condition=edit_indices,
                                 x=y_pred_transposed * y_true_transposed,
                                 y=y_pred_transposed)

    # Transpose back:
    y_pred = tf.transpose(y_pred_transposed, [1, 2, 0])
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
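For completeness, the tf.dynamic_partition + tf.dynamic_stitch route mentioned above could look like this. It is overkill here but generalizes to irregular masks. This is a sketch written for TF 2.x eager mode, and names like `partitions` and `flat_pred` are my own, not from any API:

```python
import tensorflow as tf  # assumes TF 2.x (eager execution)

y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])

# Partition id 1 for the first 5 entries along the last axis, 0 for the rest.
z = tf.shape(y_pred)[-1]
partitions_1d = tf.cast(tf.range(z) < 5, tf.int32)          # [1,1,1,1,1,0,0,0,0,0]
partitions = tf.broadcast_to(partitions_1d, tf.shape(y_pred))

# Flatten everything so we can partition element-wise.
flat_pred = tf.reshape(y_pred, [-1])
flat_true = tf.reshape(y_true, [-1])
flat_part = tf.reshape(partitions, [-1])

# Split values into "keep as-is" (partition 0) and "multiply" (partition 1).
pred_rest, pred_edit = tf.dynamic_partition(flat_pred, flat_part, 2)
_, true_edit = tf.dynamic_partition(flat_true, flat_part, 2)

# Partition the original positions the same way so we can stitch back.
idx_rest, idx_edit = tf.dynamic_partition(
    tf.range(tf.size(flat_pred)), flat_part, 2)

merged = tf.dynamic_stitch([idx_rest, idx_edit],
                           [pred_rest, pred_edit * true_edit])
result = tf.reshape(merged, tf.shape(y_pred))
# result → [[[1, 4, 9, 16, 25, 6, 7, 8, 9, 10]]]
```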
Upvotes: 1