Reputation: 24111
I'm building an image classifier in TensorFlow, and there is a class imbalance in my training data. Therefore, when computing the loss, I need to weight the loss for each class by the inverse frequency of that class in the training data.
So, here is my code:
import tensorflow as tf

# Get the softmax from the final layer of the network
softmax = tf.nn.softmax(final_layer)
# Weight the softmax by the inverse frequency of each class
weighted_softmax = tf.mul(softmax, class_weights)
# Compute the weighted cross entropy
cross_entropy = -tf.reduce_sum(y_ * tf.log(weighted_softmax))
# Define the optimisation
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)
# Run the training
session.run(tf.initialize_all_variables())
for i in range(10000):
    # Get the next batch
    batch = datasets.train.next_batch(64)
    # Run a training step
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
My question is: can I store class_weights as just a tf.constant(...) in global scope, or do I need to pass it as a parameter when computing cross_entropy?
The reason I am wondering is that class_weights is different for every batch. Therefore, I am worried that if it is just defined in global scope, then when the TensorFlow graph is constructed it will only take the initial values of class_weights and never update them. Whereas if I were to pass class_weights via the feed_dict when computing weighted_softmax, then I would be explicitly telling TensorFlow to use the most recent, updated values of class_weights.
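For concreteness, the feed_dict approach I have in mind would look something like this. It is only a sketch: num_classes and the per-batch weight computation are placeholders I made up.

import numpy as np

# Sketch: class_weights as a placeholder, fed fresh values each batch
class_weights = tf.placeholder(tf.float32, shape=[num_classes])  # num_classes is assumed
weighted_softmax = tf.mul(softmax, class_weights)
cross_entropy = -tf.reduce_sum(y_ * tf.log(weighted_softmax))
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)

for i in range(10000):
    batch = datasets.train.next_batch(64)
    # Inverse frequency of each class in this batch (labels assumed one-hot)
    counts = batch[1].sum(axis=0)
    batch_weights = counts.sum() / np.maximum(counts, 1.0)
    train_step.run(feed_dict={x: batch[0], y_: batch[1], class_weights: batch_weights})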
Any help would be appreciated. Thanks!
Upvotes: 1
Views: 991
Reputation: 4143
I think having class_weights as a tf.constant is fine. Class weighting should be done over the whole dataset, not per mini-batch.
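For example, something along these lines, assuming train_labels is a one-hot NumPy array holding the labels of the entire training set (the name is just illustrative):

import numpy as np

# Per-class counts over the whole training set (one-hot labels)
class_counts = train_labels.sum(axis=0)
# Inverse-frequency weights, fixed once for the whole run
class_weights = tf.constant(class_counts.sum() / class_counts, dtype=tf.float32)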
Another approach you might want to consider is balanced sampling, so that each batch contains an equal number of examples from each class.
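A rough sketch of what I mean, assuming the images and one-hot labels are NumPy arrays (the function name and the with-replacement sampling are just illustrative):

import numpy as np

def balanced_batch(images, labels, batch_size):
    # Draw an equal number of examples from each class (with replacement)
    num_classes = labels.shape[1]
    per_class = batch_size // num_classes
    class_ids = labels.argmax(axis=1)
    chosen = []
    for c in range(num_classes):
        idx = np.where(class_ids == c)[0]
        chosen.extend(np.random.choice(idx, per_class, replace=True))
    chosen = np.random.permutation(chosen)
    return images[chosen], labels[chosen]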
Upvotes: 1