Reddspark

Reputation: 7567

Weighted cost function in TensorFlow

I'm trying to introduce weighting into the following cost function:

_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y))

I'd like to do this without computing the softmax cross entropy myself, so I was thinking of splitting the cost calculation into cost1 and cost2 and feeding a modified version of my logits and y values into each one.

I want to do something like this, but I'm not sure what the correct code is:

mask=(y==0)
y0 = tf.boolean_mask(y,mask)*y1Weight

(This gives the error that the mask cannot be a scalar.)
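A minimal plain-Python sketch of why the split is not strictly necessary: one weighted cost term per class adds up to the same number as weighting each example's loss by its class weight and taking the mean, so the logits and labels do not have to be masked separately. The logits, labels and weights are made up, and sparse_softmax_xent is a hypothetical helper standing in for TensorFlow's per-example loss, not a TensorFlow function.

```python
import math

def sparse_softmax_xent(logits, label):
    """Per-example softmax cross-entropy: logsumexp(logits) - logits[label], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Hypothetical batch: 2-class logits and integer labels
logits = [[2.0, 0.5], [0.1, 1.2], [1.0, 1.0], [-0.3, 0.8]]
labels = [0, 1, 0, 1]
y0Weight, y1Weight = 0.75, 0.25
n = len(labels)
losses = [sparse_softmax_xent(z, y) for z, y in zip(logits, labels)]

# Option A: split the cost into one term per class (the cost1 / cost2 idea)
cost0 = y0Weight * sum(l for l, y in zip(losses, labels) if y == 0) / n
cost1 = y1Weight * sum(l for l, y in zip(losses, labels) if y == 1) / n
split_cost = cost0 + cost1

# Option B: weight each example's loss by its class weight, then take the mean
weights = [y0Weight if y == 0 else y1Weight for y in labels]
weighted_cost = sum(w * l for w, l in zip(weights, losses)) / n

print(split_cost, weighted_cost)
```

On the error itself: in TF 1.x, y==0 on a Tensor falls back to Python's identity comparison and returns the plain bool False, which would explain why boolean_mask sees a scalar mask; tf.equal(y, 0) builds the element-wise comparison instead.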

Upvotes: 1

Views: 1779

Answers (2)

Vijay Mariappan

Reputation: 17191

The weight mask can be computed using tf.where. Here is an example of a weighted cost:

import tensorflow as tf  # TF 1.x API

batch_size = 100
y1Weight = 0.25
y0Weight = 0.75

_logits = tf.Variable(tf.random_normal(shape=(batch_size, 2), stddev=1.))
y = tf.random_uniform(shape=(batch_size,), maxval=2, dtype=tf.int32)

# Unweighted per-example loss, shape (batch_size,)
_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)

# Weight mask: label 0 gets y0Weight, label 1 gets y1Weight.
# tf.cast(y, tf.bool) is True for label 1, and tf.where picks its second argument where the condition is True.
y_w = tf.where(tf.cast(y, tf.bool), tf.ones((batch_size,))*y1Weight, tf.ones((batch_size,))*y0Weight)

# New weighted cost
cost_w = tf.reduce_mean(tf.multiply(_cost, y_w))
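To double-check which weight ends up on which label, here is a plain-Python sketch of the same selection. With three arguments, tf.where(condition, a, b) takes a where the condition is true and b elsewhere, and casting the labels to bool is true exactly for label 1, so the label-1 weight has to be the second argument. The where function below is a hypothetical list-based stand-in that mirrors the logic only; it is not the TensorFlow op.

```python
def where(cond, a, b):
    """Element-wise select, like three-argument tf.where: a[i] if cond[i] else b[i]."""
    return [ai if c else bi for c, ai, bi in zip(cond, a, b)]

y0Weight, y1Weight = 0.75, 0.25
labels = [0, 1, 1, 0, 1]
batch_size = len(labels)

# Equivalent of tf.cast(y, tf.bool): label 1 -> True, label 0 -> False
cond = [bool(label) for label in labels]

# True (label 1) must pick y1Weight; False (label 0) picks y0Weight
y_w = where(cond, [y1Weight] * batch_size, [y0Weight] * batch_size)
print(y_w)  # [0.75, 0.25, 0.25, 0.75, 0.25]
```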

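As suggested by @user1761806, a simpler solution is to use tf.losses.sparse_softmax_cross_entropy(), which takes a weights argument (a scalar or a tensor broadcastable to the labels), so a per-example weight tensor like y_w can be passed in directly.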
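One detail worth checking when switching: how the weighted losses are reduced. Assuming the TF 1.x default reduction for tf.losses (tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS, which divides the weighted sum by the number of non-zero weights), the result should equal cost_w whenever every weight is non-zero. The plain-Python sketch below mimics that reduction on made-up numbers; sum_by_nonzero_weights is a hypothetical helper, not a TensorFlow function.

```python
def sum_by_nonzero_weights(losses, weights):
    """Weighted sum of losses divided by the count of non-zero weights (0.0 if there are none)."""
    weighted_sum = sum(w * l for w, l in zip(weights, losses))
    num_present = sum(1 for w in weights if w != 0)
    return weighted_sum / num_present if num_present else 0.0

losses = [0.2, 1.5, 0.7, 0.9]   # hypothetical per-example cross-entropy values
y_w = [0.75, 0.25, 0.75, 0.25]  # hypothetical per-example weights, all non-zero

# Same as tf.reduce_mean(tf.multiply(_cost, y_w)) in the answer above
cost_w = sum(w * l for w, l in zip(y_w, losses)) / len(losses)
print(sum_by_nonzero_weights(losses, y_w), cost_w)
```

If some weights are zero the two differ: cost_w still divides by the full batch size, while this reduction only counts the examples with a non-zero weight.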

Upvotes: 1

Ishant Mrinal

Reputation: 4918

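You can calculate the weighted cost as follows: use a predefined weights_per_class tensor with shape (num_classes, 1), and one-hot encode the labels to pick out each example's class weight.

import tensorflow as tf  # TF 1.x API

# y holds integer class ids, shape [batch_size]; _cost has shape [batch_size]
_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)

# Here you can define a deterministic weights tensor, e.g.
# weights_per_class = tf.constant([[y0weight], [y1weight], ...])
# (random_uniform keeps these placeholder weights non-negative)
weights_per_class = tf.random_uniform(shape=(num_classes, 1), dtype=tf.float32)

# One-hot labels [batch_size, num_classes] x [num_classes, 1] -> [batch_size, 1], squeezed to [batch_size]
y_onehot = tf.one_hot(y, depth=num_classes)
weights_per_example = tf.squeeze(tf.matmul(y_onehot, weights_per_class), axis=1)

# Use the per-example weights to compute the weighted loss
_weighted_cost = tf.reduce_mean(_cost * weights_per_example)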
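The one-hot product is just a lookup of each label's weight. A plain-Python sketch with a hypothetical (num_classes, 1) weight column shows that multiplying the one-hot rows by that column gives the same values as indexing the weights by label; one_hot and matmul here are small list-based stand-ins, not the TensorFlow functions.

```python
def one_hot(labels, num_classes):
    """Integer labels -> one-hot rows, like tf.one_hot(y, depth=num_classes)."""
    return [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in labels]

def matmul(a, b):
    """Plain matrix product of nested lists: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

num_classes = 3
weights_per_class = [[0.5], [1.0], [2.0]]  # hypothetical (num_classes, 1) column
labels = [2, 0, 1, 2]

# [batch, num_classes] @ [num_classes, 1] -> [batch, 1], then squeeze to [batch]
per_example = [row[0] for row in matmul(one_hot(labels, num_classes), weights_per_class)]

# Same result as looking each label's weight up directly
lookup = [weights_per_class[label][0] for label in labels]
print(per_example)  # [2.0, 0.5, 1.0, 2.0]
```

In TensorFlow the direct lookup would be tf.gather, e.g. tf.gather(tf.squeeze(weights_per_class, axis=1), y); the one-hot matmul is the matrix form of that same lookup.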

Upvotes: 0
