puffadder

Reputation: 113

Multiply only certain columns of a TensorFlow array

I'm currently modifying the loss function for one of my object detection neural networks. I basically have two arrays:

- y_true: the ground-truth labels, a tf tensor of shape (x, y, z)
- y_pred: the predicted values, a tf tensor of shape (x, y, z)

The x dimension is the batch size, the y dimension is the number of predicted objects in the image, and the z dimension contains a one-hot encoding of the classes as well as the bounding boxes of said classes.

Now to the real question: what I want to do is multiply the first 5 z-values in y_pred by the first 5 z-values in y_true. All other values should remain unaffected. In NumPy this is extremely straightforward:

y_pred[:,:,:5] *= y_true[:,:,:5]
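As a reference point, here is a small NumPy sketch of what that line does, using made-up sample values (one image, one predicted object, 10 z-values). Note that it mutates y_pred in place, which is exactly what a TensorFlow tensor does not allow, since tf.Tensor objects are immutable.

```python
import numpy as np

# Made-up example data: batch of 1, 1 predicted object, 10 values per object.
y_pred = np.arange(1, 11).reshape(1, 1, 10)
y_true = np.arange(1, 11).reshape(1, 1, 10)

# In-place update: only the first 5 entries of the last (z) axis are multiplied.
y_pred[:, :, :5] *= y_true[:, :, :5]
print(y_pred)  # [[[ 1  4  9 16 25  6  7  8  9 10]]]
```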

I'm finding this very hard to do in TensorFlow, because I can't assign values to the original tensor and I want to keep all other values the same. How do I go about doing this in TensorFlow?

Upvotes: 1

Views: 1357

Answers (1)

benjaminplanche

Reputation: 15119

Since v1.1, TensorFlow supports such NumPy-like indexing; see Tensor.__getitem__.

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    print((y_pred[:,:,:5] * y_true[:,:,:5]).eval()) 
    # [[[   1    4    9   16   25]
    #   [ 100  400  900 1600 2500]]]
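The snippet above uses the TensorFlow 1.x tf.Session API. Assuming TensorFlow 2.x, where eager execution is the default, the same slicing can be sketched without a session. Note that the result only contains the 5 sliced columns; the remaining columns are dropped, which is why the item-assignment problem still has to be solved separately.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution enabled by default, no tf.Session).
y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                       [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]]])
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                       [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]]])

# NumPy-style slicing on the last axis, then element-wise multiplication.
product = y_pred[:, :, :5] * y_true[:, :, :5]
print(product.numpy())
```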

EDIT after comment:

Now, the problem is the "*=" part, i.e. item assignment. This isn't a straightforward operation in TensorFlow. In your case, however, it can easily be solved using tf.concat or tf.where (tf.dynamic_partition + tf.dynamic_stitch could be used for more complex cases).
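The tf.dynamic_partition + tf.dynamic_stitch route is overkill for this case, but for completeness it could be sketched as follows, assuming TensorFlow 2.x eager mode. Because tf.dynamic_partition splits along the leading axes of its input, the z axis is first moved to the front (the same transpose trick used with tf.where); the cut-off n = 5 comes from the question, and the variable names are illustrative only.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution).
y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])
n = 5  # number of leading z-values to multiply

# Move the z axis to the front: (x, y, z) -> (z, x, y).
pred_t = tf.transpose(y_pred, [2, 0, 1])
true_t = tf.transpose(y_true, [2, 0, 1])

# Partition id per z position: 0 for the first n entries, 1 for the rest.
z = tf.shape(pred_t)[0]
partitions = tf.cast(tf.range(z) >= n, tf.int32)

pred_parts = tf.dynamic_partition(pred_t, partitions, num_partitions=2)
true_parts = tf.dynamic_partition(true_t, partitions, num_partitions=2)
# Original z positions of each partition, used to restore the order.
index_parts = tf.dynamic_partition(tf.range(z), partitions, num_partitions=2)

# Multiply partition 0 only, keep partition 1 as-is, then stitch back together.
stitched = tf.dynamic_stitch(index_parts,
                             [pred_parts[0] * true_parts[0], pred_parts[1]])

# Move the z axis back: (z, x, y) -> (x, y, z).
result = tf.transpose(stitched, [1, 2, 0])
print(result.numpy())
```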

Below is a quick implementation of the first two solutions.

Solution using Tensor.__getitem__ and tf.concat:

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # Multiply the first 5 z-values, take the remaining ones unchanged,
    # then join both parts back together along the last axis:

    y_pred_edit = y_pred[:,:,:5] * y_true[:,:,:5]
    y_pred_rest = y_pred[:,:,5:]  # untouched z-values, from index 5 onward

    y_pred = tf.concat((y_pred_edit, y_pred_rest), axis=2)
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
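Inside a loss function, the same slice-and-concat idea could be wrapped in a small helper. The sketch below assumes TensorFlow 2.x; multiply_first_n is a hypothetical name, and the "..." in the slices lets it work on any tensor whose last axis is z. Slicing and tf.concat are differentiable, so the sketch also checks with tf.GradientTape that gradients flow through both the edited and the untouched parts.

```python
import tensorflow as tf

def multiply_first_n(y_pred, y_true, n=5):
    """Hypothetical helper: multiply the first n entries of the last axis of
    y_pred by the matching entries of y_true, keep the other entries as-is."""
    head = y_pred[..., :n] * y_true[..., :n]  # edited part
    tail = y_pred[..., n:]                     # untouched part
    return tf.concat([head, tail], axis=-1)

# Made-up float data (gradients need floating-point tensors).
y_true = tf.constant([[[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]]])
y_pred = tf.Variable([[[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]]])

with tf.GradientTape() as tape:
    out = multiply_first_n(y_pred, y_true)
    total = tf.reduce_sum(out)

# d(total)/d(y_pred) is y_true on the first 5 entries and 1 on the rest.
grad = tape.gradient(total, y_pred)
print(out.numpy())
print(grad.numpy())
```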

Solution using tf.where:

import tensorflow as tf

def select_n_first_indices(n, length):
    """ Return a boolean tensor of shape (length,) whose n first elements are True
        and the rest False, i.e. [True] * n + [False] * (length - n).
    """
    n_ones = tf.ones([n])
    rest_zeros = tf.zeros([length - n])
    indices = tf.cast(tf.concat((n_ones, rest_zeros), axis=0), dtype=tf.bool)

    return indices

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # tf.where can't apply the condition along an arbitrary axis (see doc): a
    # rank-1 condition is matched against the first dimension of x and y.
    # In your case (condition on the last axis, z), we need either to manually
    # broadcast the condition tensor, or to transpose the target tensors.
    # Here is a quick demonstration with the 2nd solution:
    y_pred_transposed = tf.transpose(y_pred, [2, 0, 1])  # (x, y, z) -> (z, x, y)
    y_true_transposed = tf.transpose(y_true, [2, 0, 1])

    edit_indices = select_n_first_indices(5, tf.shape(y_pred_transposed)[0])

    y_pred_transposed = tf.where(condition=edit_indices,
                                 x=y_pred_transposed * y_true_transposed, y=y_pred_transposed)

    # Transpose back: (z, x, y) -> (x, y, z)
    y_pred = tf.transpose(y_pred_transposed, [1, 2, 0])
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
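The transposes are needed because the TensorFlow 1.x tf.where only accepts a condition that either has the same shape as x or is rank 1 and matches the first dimension of x. In TensorFlow 2.x, tf.where broadcasts condition, x and y against each other, so under that assumption a sketch can apply a mask over the last axis directly:

```python
import tensorflow as tf

# Assumes TensorFlow 2.x, where tf.where broadcasts condition, x and y.
y_pred = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])
y_true = tf.constant([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]])

# Boolean mask over the z axis: True for the first 5 positions, shape (z,).
mask = tf.range(tf.shape(y_pred)[-1]) < 5

# Where the mask is True take y_pred * y_true, otherwise keep y_pred.
result = tf.where(mask, y_pred * y_true, y_pred)
print(result.numpy())
```

The mask of shape (10,) broadcasts against the (1, 1, 10) tensors, so no helper function or transpose is required.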

Upvotes: 1
