Karnivaurus

Reputation: 24111

TensorFlow: defining variables in global scope

I'm building an image classifier in TensorFlow, and there is a class imbalance in my training data. Therefore, when computing the loss, I need to weight the loss for each class by the inverse frequency of that class in the training data.

So, here is my code:

# Get the softmax from the final layer of the network
softmax = tf.nn.softmax(final_layer)
# Weight the softmax by the inverse frequency of each class
weighted_softmax = tf.mul(softmax, class_weights)
# Compute the weighted cross entropy
cross_entropy = -tf.reduce_sum(y_ * tf.log(weighted_softmax))
# Define the optimisation
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)

# Run the training
session.run(tf.initialize_all_variables())
for i in range(10000):
  # Get the next batch
  batch = datasets.train.next_batch(64)
  # Run a training step
  train_step.run(feed_dict = {x: batch[0], y_: batch[1]})

My question is: Can I store class_weights as just a tf.constant(...) in global scope? Or do I need to pass it as a parameter when computing cross_entropy?

The reason I am wondering is that class_weights is different for every batch. I am therefore worried that, if it is just defined in global scope, the TensorFlow graph will simply capture the initial values of class_weights when it is constructed and never update them. Whereas if I were to pass class_weights through the feed_dict when computing weighted_softmax, I would be explicitly telling TensorFlow to use the most recent, updated values of class_weights.
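For concreteness, the feed_dict version I have in mind would look roughly like this (class_weights_ph and compute_batch_class_weights are just illustrative names for a placeholder and a helper that recomputes the weights for each batch; num_classes is assumed to be defined elsewhere):

# Placeholder for the per-batch class weights
class_weights_ph = tf.placeholder(tf.float32, shape=[num_classes])
# Same graph as above, but the weights now come from the feed_dict
weighted_softmax = tf.mul(softmax, class_weights_ph)
cross_entropy = -tf.reduce_sum(y_ * tf.log(weighted_softmax))
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)

session.run(tf.initialize_all_variables())
for i in range(10000):
  batch = datasets.train.next_batch(64)
  # Recompute the weights for this particular batch and feed them in
  batch_weights = compute_batch_class_weights(batch[1])
  train_step.run(feed_dict = {x: batch[0], y_: batch[1],
                              class_weights_ph: batch_weights})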

Any help would be appreciated. Thanks!

Upvotes: 1

Views: 991

Answers (1)

Daniel Slater

Reputation: 4143

I think having class_weights as a tf.constant is fine. Class weighting should be computed over the whole dataset, not per mini-batch.
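For example, something along these lines would compute the inverse frequencies once over the full training labels (assumed here to be a one-hot numpy array, datasets.train.labels) and bake them into the graph as a constant:

import numpy as np

# Count how many training examples there are of each class
class_counts = datasets.train.labels.sum(axis=0)
# Inverse-frequency weights, fixed for the whole training run
class_weights = tf.constant(class_counts.sum() / class_counts,
                            dtype=tf.float32)
# Used exactly as in your snippet
weighted_softmax = tf.mul(softmax, class_weights)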

Another approach you might want to consider is sampling the training data so that each batch contains equal numbers of each class.
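A minimal sketch of that idea, assuming the labels are a one-hot numpy array (the helper name next_balanced_batch is just illustrative, not part of the TensorFlow API):

import numpy as np

def next_balanced_batch(images, one_hot_labels, batch_size):
  # Integer class id of each example
  class_ids = one_hot_labels.argmax(axis=1)
  num_classes = one_hot_labels.shape[1]
  per_class = batch_size // num_classes
  # Draw the same number of examples from every class (with replacement)
  idx = np.concatenate([
      np.random.choice(np.where(class_ids == c)[0], per_class)
      for c in range(num_classes)])
  np.random.shuffle(idx)
  return images[idx], one_hot_labels[idx]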

Upvotes: 1
