Murali Kadambi

Reputation: 101

Replace a slice of one tensor with another, bigger tensor in Python

In a custom loss function in TensorFlow 2.4.1 and Python 3.7, the y_pred tensor has shape (batch, 21, 21, 3), where 3 is the number of classes output by the model.

y_true is an integer indicating the class the model should have predicted. To do the calculation, I create a zeros tensor with the same shape as y_pred, named y_true_array.

There should be an easier way to do this than the workaround below:

  import tensorflow as tf

  # Fallback target: all zeros, same shape as y_pred.
  y_true_array = tf.zeros_like(y_pred, dtype=tf.float32, name="loss_true_array")
  # Shape of a single class channel, e.g. (batch, 21, 21, 1).
  one_class_shape = y_pred.shape[:-1] + [1]

  ones = tf.ones(one_class_shape, dtype=tf.float32, name="loss_ones")
  zeros = tf.zeros(one_class_shape, dtype=tf.float32, name="loss_zeros")
  # Place the channel of ones at the index of the true class.
  if y_true == 0:
    y_true_array = tf.concat([ones, zeros, zeros], axis=-1)
  elif y_true == 1:
    y_true_array = tf.concat([zeros, ones, zeros], axis=-1)
  elif y_true == 2:
    y_true_array = tf.concat([zeros, zeros, ones], axis=-1)

For example, for a y_pred array of shape (3, 2, 3), where 2 in the second dimension is the feature size, if y_true = 1, y_true_array should be:

[[[0, 1, 0], [0, 1, 0]],
 [[0, 1, 0], [0, 1, 0]],
 [[0, 1, 0], [0, 1, 0]]]

Upvotes: 0

Views: 157

Answers (1)

thushv89

Reputation: 11333

It seems to me that you're just one-hot encoding the class labels in y_true:

  # One-hot encode the integer class label(s) in y_true (not y_pred) into 3 classes.
  y_true_array = tf.one_hot(y_true, depth=3)

For example, if y_true is the array

[0, 1, 2]

y_true_array will be:

[
 [1, 0, 0],
 [0, 1, 0],
 [0, 0, 1],
]
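To get exactly the target described in the question, where y_true is a single integer and the target must have the same shape as y_pred, the one-hot vector can be broadcast over y_pred's leading dimensions. Below is a minimal sketch assuming y_true is a scalar label; `one_hot_like` is a hypothetical helper name, and the (3, 2, 3) shape is borrowed from the example in the question. It only uses `tf.one_hot` and `tf.broadcast_to`.

```python
import tensorflow as tf

def one_hot_like(y_true, y_pred, num_classes=3):
    """Hypothetical helper: build a one-hot target shaped like y_pred.

    Assumes y_true is a scalar integer class label."""
    # tf.one_hot needs integer indices; cast in case the label arrives as float.
    label = tf.cast(y_true, tf.int32)
    # Shape (num_classes,): 1.0 at the true class, 0.0 elsewhere.
    one_hot = tf.one_hot(label, depth=num_classes, dtype=y_pred.dtype)
    # Repeat the class vector over every leading dimension of y_pred.
    return tf.broadcast_to(one_hot, tf.shape(y_pred))

y_pred = tf.zeros((3, 2, 3))
y_true = tf.constant(1)
print(one_hot_like(y_true, y_pred).numpy())
```

If y_true instead holds one label per position, with shape (batch, 21, 21), `tf.one_hot(y_true, depth=3)` already returns shape (batch, 21, 21, 3) and no broadcast is needed.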

Upvotes: 1
