Reputation: 823
I want to change the typical MSE loss function. Right now I have the following code:
squared_difference = tf.reduce_sum(tf.square(target - output), [1])
mse_loss = tf.reduce_mean(squared_difference)
The shape of both tensors is [batch_size, 10], and an example target is [0,1,2,3,0.5,0.5,0.5,7,8,9]. The 0.5s are always at indices 4, 5 and 6.
What I want to do now is ignore these indices completely and not increase the loss if the output of the network doesn't have 0.5 at these indices.
So if the output is [0,1,2,3,20,10,14,7,8,9], the loss should be 0.
What is the best way to achieve this?
Upvotes: 1
Views: 480
Reputation: 3276
There are many ways you can handle this. One straightforward way is to use the weights parameter of tf.losses.mean_squared_error. Pass a batch_size x labels tensor that serves as a sort of mask, with 1s for the values you want to consider and 0s for the ones to ignore. The weights parameter exists for most loss functions.
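A minimal sketch of this idea, using the example values from the question. I'm assuming TF 2.x here, where the same function lives under tf.compat.v1.losses (on TF 1.x you would call tf.losses.mean_squared_error directly):

```python
import tensorflow as tf

# One example row (batch_size = 1, 10 labels each).
target = tf.constant([[0., 1., 2., 3., 0.5, 0.5, 0.5, 7., 8., 9.]])
output = tf.constant([[0., 1., 2., 3., 20., 10., 14., 7., 8., 9.]])

# Mask: 1s for positions that count, 0s for indices 4, 5 and 6,
# which should never contribute to the loss.
weights = tf.constant([[1., 1., 1., 1., 0., 0., 0., 1., 1., 1.]])

# tf.losses.mean_squared_error in TF 1.x; the compat alias in TF 2.x.
loss = tf.compat.v1.losses.mean_squared_error(target, output, weights=weights)
# loss is 0.0 here: the only mismatches fall on the masked indices.
```

With the default reduction, the weighted squared errors are summed and divided by the number of nonzero weights, so the masked positions neither add to the loss nor inflate the denominator.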
Upvotes: 1