Derk

Reputation: 1395

Training with tf.data API and sample weights

All my training images are in TFRecord files. Currently they are consumed in the standard way, like this:

dataset = dataset.apply(tf.data.experimental.map_and_batch(
            map_func=lambda x: preprocess(x, data_augmentation_options=data_augmentation),
            batch_size=images_per_batch))

where preprocess returns the decoded image and the label which both come from the tfrecord file.

Now for the new situation: I also want a sample weight for each example. So instead of

return image, label

in preprocess, it should be

return image, label, sample_weight

However, this sample_weight is not in the TFRecord file. It is computed when training starts, based on the number of examples in each class. Basically it is a Python dictionary: weights[label] = sample_weight.
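The question does not say exactly how the per-class weights are derived. One common choice, assumed here, is inverse class frequency, weight[c] = total / (num_classes * count[c]). A minimal sketch of building such a dictionary; the helper name compute_class_weights is hypothetical, and the labels are assumed to have been collected in one pass over the data before training:

```python
from collections import Counter


def compute_class_weights(labels):
    """Hypothetical helper: inverse-frequency weight for each class label.

    weight[c] = total / (num_classes * count[c]), so rarer classes get larger
    weights and a perfectly balanced dataset gets 1.0 for every class.
    """
    counts = Counter(labels)
    total = sum(counts.values())
    num_classes = len(counts)
    return {label: total / (num_classes * count) for label, count in counts.items()}


# Example: class 0 appears three times, class 1 once.
weights = compute_class_weights([0, 0, 0, 1])
print(weights)
```

Here class 1 is three times rarer than class 0, so it ends up with three times the weight.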

The question is how to use these sample weights in the tf.data pipeline. Because label is a Tensor, it cannot be used to index the Python dictionary.

Upvotes: 3

Views: 2509

Answers (1)

Tuco

Reputation: 128

Some things in your question are not clear, such as what x is. It would be better to post a complete code example with your question.

I'm assuming that x is a tensor with an image and a label. If so, you can use the map function to add a tensor of sample weights to your dataset. Something like this (note that this code was not tested):

import numpy as np
import tensorflow as tf

# weights is the Python dict computed before training: weights[label] = sample_weight
def im_add_weight(image, label):
    # Inside tf.py_func the inputs arrive as NumPy values, so the label can be
    # used as a dictionary key. Outputs must be NumPy values matching Tout.
    sample_weight = np.float32(weights[int(label)])
    return image.astype(np.float32), np.float32(label), sample_weight

# Apply this before batching so that label is a single scalar.
# (TF 1.x API; in TF 2.x, tf.numpy_function behaves the same way.)
dataset = dataset.map(
    lambda image, label: tuple(tf.py_func(
        im_add_weight, [image, label], [tf.float32, tf.float32, tf.float32])))
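An alternative that avoids tf.py_func entirely (a different technique from the one above) is to convert the dictionary into a dense tensor once and index it with tf.gather inside the map function, so the lookup stays in the graph and works on a label Tensor. A sketch, assuming TensorFlow 2.x and integer labels in the range 0..K-1; add_weight, weight_vector and the toy data are hypothetical stand-ins for the real pipeline:

```python
import numpy as np
import tensorflow as tf

# Hypothetical class-weight dictionary, as computed before training.
weights = {0: 0.5, 1: 2.0, 2: 1.0}

# Assumes labels are integers 0..K-1: position k of the vector holds the
# weight of class k (classes missing from the dict default to 1.0).
num_classes = max(weights) + 1
weight_vector = tf.constant(
    [weights.get(k, 1.0) for k in range(num_classes)], dtype=tf.float32)


def add_weight(image, label):
    # tf.gather indexes the vector with the label tensor inside the graph,
    # which is what a Python dict lookup cannot do.
    sample_weight = tf.gather(weight_vector, label)
    return image, label, sample_weight


# Toy stand-in for the decoded TFRecord data: 4 "images" with integer labels.
images = np.zeros((4, 2, 2, 3), dtype=np.float32)
labels = np.array([0, 1, 2, 1], dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# map followed by batch (map_and_batch is deprecated in favour of this).
dataset = dataset.map(add_weight).batch(2)

for image_batch, label_batch, weight_batch in dataset:
    print(label_batch.numpy(), weight_batch.numpy())
```

For labels that are not contiguous integers, tf.lookup.StaticHashTable can play the same role as the dense vector. When such a dataset is passed to tf.keras Model.fit, each 3-tuple element is interpreted as (inputs, targets, sample_weights).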

Upvotes: 3
