Reputation: 2854
I am training a model with TensorFlow Estimator, and my data is not balanced. I want to correct for this by weighting each training example.
In raw TensorFlow one might do it like this. Is there an easy way to do this in Estimator? Perhaps by building a custom input_fn?
Upvotes: 4
Views: 1945
Reputation: 233
If you are building a custom Estimator, forward the class weight for each sample of your dataset as a feature to your model_fn. Then, when defining the loss op, pass that feature to the loss function's weights parameter.
Example:
tf.losses.softmax_cross_entropy(onehot_labels, logits, weights=features['weight'])
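To make concrete what the weights argument does, here is a framework-free sketch in plain Python (an illustration, not TensorFlow's implementation): each example's softmax cross-entropy is computed from its logits and one-hot label, then multiplied by that example's weight. TensorFlow would then reduce these per-example values to a scalar according to the loss function's reduction argument, which this sketch deliberately leaves out. The function names here are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def weighted_softmax_cross_entropy(
    onehot_labels: Sequence[Sequence[float]],
    logits: Sequence[Sequence[float]],
    weights: Sequence[float],
) -> List[float]:
    """Per-example cross-entropy, each scaled by that example's weight."""
    losses = []
    for y, z, w in zip(onehot_labels, logits, weights):
        ce = -sum(yi * lp for yi, lp in zip(y, log_softmax(z)))
        losses.append(w * ce)  # a weight of 0 removes the example entirely
    return losses


# Two examples with identical logits: the second (weight 3.0) contributes
# three times as much loss as the first (weight 1.0).
print(weighted_softmax_cross_entropy([[1, 0], [0, 1]], [[0.0, 0.0], [0.0, 0.0]], [1.0, 3.0]))
```

Boosting a minority-class example therefore has the same effect on the summed loss as repeating it that many times, which is why per-example weights are a common alternative to resampling.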
Upvotes: 1
Reputation: 53758
I assume you're doing classification. If so, use tf.estimator.DNNClassifier and its weight_column argument:
weight_column: A string or a _NumericColumn created by tf.feature_column.numeric_column defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. If it is a string, it is used as a key to fetch weight tensor from the features. If it is a _NumericColumn, raw tensor is fetched by key weight_column.key, then weight_column.normalizer_fn is applied on it to get weight tensor.
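Both approaches need a per-example weight to be present in the features dict that the input_fn yields. A minimal sketch of producing one, assuming integer class labels and the common "balanced" heuristic weight = n_examples / (n_classes * count(class)) (the same formula scikit-learn uses for class_weight='balanced'); the helper names and the "weight" key are illustrative, not part of any TensorFlow API:

```python
from collections import Counter
from typing import Dict, List, Sequence


def inverse_frequency_weights(labels: Sequence[int]) -> List[float]:
    """Weight each example by n_examples / (n_classes * count of its class)."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return [n / (k * counts[y]) for y in labels]


def add_weight_feature(
    features: Dict[str, list], labels: Sequence[int], key: str = "weight"
) -> Dict[str, list]:
    """Return a copy of `features` with a per-example weight column under `key`."""
    out = dict(features)  # shallow copy so the caller's dict is not mutated
    out[key] = inverse_frequency_weights(labels)
    return out


labels = [0, 0, 0, 1]
features = add_weight_feature({"x": [1.0, 2.0, 3.0, 4.0]}, labels)
# Each class now carries the same total weight: 3 * (2/3) == 1 * 2.0
print(features["weight"])
```

Inside an input_fn, a dict like this could be passed to tf.data.Dataset.from_tensor_slices((features, labels)) and consumed either as features['weight'] in a custom model_fn or via DNNClassifier(..., weight_column='weight'). Keep the weight key out of feature_columns so it is not also used as a model input.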
Upvotes: 1