Mous

Reputation: 83

How to make Dataset for triplet loss

I am trying to build a Dataset that provides batches of TFRecords in which each batch contains 2 random records from one class and the rest from other random classes.

OR

A Dataset of batches where each batch contains 2 random records from each class that fits into it.

I tried to do this with tf.data.Dataset.from_generator and with tf.data.experimental.choose_from_datasets, but without success. Do you have any idea how to do this?
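The first variant can be sketched in plain Python, assuming the records have already been grouped by class into lists of record IDs (`records_by_class`, `sample_batch` and `batch_generator` are hypothetical names, not TensorFlow APIs). It also assumes "the rest" means one record from each of `batch_size - 2` distinct other classes. A generator like this could then be wrapped with `tf.data.Dataset.from_generator`, with the yielded IDs used to look up or read the actual records.

```python
import random
from typing import Dict, Iterator, List, Sequence

def sample_batch(records_by_class: Dict[int, Sequence[int]], batch_size: int,
                 rng: random.Random) -> List[int]:
    """One batch: 2 distinct records from a random class, then one record
    from each of batch_size - 2 other, distinct random classes."""
    classes = list(records_by_class)
    if batch_size < 2 or len(classes) < batch_size - 1:
        raise ValueError("need batch_size >= 2 and at least batch_size - 1 classes")
    # Draw the pair class and the other classes together so no class repeats.
    pair_class, *other_classes = rng.sample(classes, batch_size - 1)
    batch = rng.sample(list(records_by_class[pair_class]), 2)
    batch += [rng.choice(list(records_by_class[c])) for c in other_classes]
    return batch

def batch_generator(records_by_class: Dict[int, Sequence[int]], batch_size: int,
                    seed: int = 0) -> Iterator[List[int]]:
    """Endless stream of batches built by sample_batch (hypothetical helper)."""
    rng = random.Random(seed)
    while True:
        yield sample_batch(records_by_class, batch_size, rng)

# Toy record IDs per class, mirroring the value ranges of the test code in the EDIT.
records = {0: list(range(0, 10)), 1: list(range(20, 30)), 2: list(range(60, 70))}
gen = batch_generator(records, batch_size=4, seed=42)
print(next(gen))  # two IDs from one class, then one ID from each of the other two
```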

EDIT: I think I have now implemented the second variant. Here is the code I tested it with.

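import tensorflow as tf

def input_fn():
  # Three toy "classes", each an endless, lightly shuffled stream of values.
  partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
  partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
  partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
  l = [partial1, partial2, partial3]

  # Emit each class index twice in a row: 0, 0, 1, 1, 2, 2, ...
  def gen(x):
    return tf.data.Dataset.range(x, x + 1).repeat(2)

  dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)

  # Pick elements from the class datasets in that order; a batch of 4
  # then holds 2 records from each of 2 classes.
  choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
  return choice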

which, when evaluated, returns

[ 0  2 21 22]
[60 61  1  4]
[20 23 62 63]
[ 3  5 24 25]
[64 66  6  7]
[26 27 65 68]
[ 8  0 28 29]
[67 69  9  2]
[20 22 60 62]
[ 3  1 23 24]
[63 61  4  6]
[25 26 65 64]
[ 7  5 27 28]
[67 66  9  8]
[21 20 69 68]
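The class pattern in this output follows directly from the choice stream. The plain-Python sketch below (helper names are hypothetical) reproduces the index sequence that `Dataset.range(3).flat_map(gen).repeat(10)` feeds to `choose_from_datasets`, maps each printed value back to its toy dataset using the ranges from `input_fn`, and checks that the two agree.

```python
from typing import Iterator, List

# Value ranges of the toy datasets in input_fn():
# partial1 -> 0..9, partial2 -> 20..29, partial3 -> 60..69.
RANGES = [range(0, 10), range(20, 30), range(60, 70)]

def choice_schedule(num_classes: int, repeats: int) -> Iterator[int]:
    """Mirror Dataset.range(num_classes).flat_map(gen).repeat(repeats):
    each class index is emitted twice in a row, and the cycle repeats."""
    for _ in range(repeats):
        for c in range(num_classes):
            yield c
            yield c

def expected_batches(num_classes: int, repeats: int, batch_size: int) -> List[List[int]]:
    """Group the schedule the way .batch(batch_size) does (remainder kept)."""
    sched = list(choice_schedule(num_classes, repeats))
    return [sched[i:i + batch_size] for i in range(0, len(sched), batch_size)]

def class_of(value: int) -> int:
    for cls, r in enumerate(RANGES):
        if value in r:
            return cls
    raise ValueError(f"{value} does not belong to any class")

# The batches printed above.
output = [
    [0, 2, 21, 22], [60, 61, 1, 4], [20, 23, 62, 63], [3, 5, 24, 25],
    [64, 66, 6, 7], [26, 27, 65, 68], [8, 0, 28, 29], [67, 69, 9, 2],
    [20, 22, 60, 62], [3, 1, 23, 24], [63, 61, 4, 6], [25, 26, 65, 64],
    [7, 5, 27, 28], [67, 66, 9, 8], [21, 20, 69, 68],
]

expected = expected_batches(3, 10, 4)
observed = [[class_of(v) for v in batch] for batch in output]
print(expected[:3])          # [[0, 0, 1, 1], [2, 2, 0, 0], [1, 1, 2, 2]]
print(observed == expected)  # True
```

It also shows a limitation: the class pairing is fixed by the cycle, so with 3 classes and a batch size of 4 the batches always combine classes (0, 1), (2, 0) and (1, 2), in that order.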

Upvotes: 4

Views: 2725

Answers (2)

郑启航

Reputation: 41

In TF 2.0 you can use dataset.interleave to read the TFRecords of the different classes, and dataset.batch to build the triplets:

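import matplotlib.pyplot as plt
import tensorflow as tf

# FcaeRecHelper is a custom helper class (not part of TensorFlow).
h = FcaeRecHelper('data/ms1m_img_ann.npy', [112, 112], 128, use_softmax=False)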
len(h.train_list)
img_shape = list(h.in_hw) + [3]

is_augment = True
is_normlize = False

def parser(stream: bytes):
    # parse one serialized TFRecord example into (image, label)
    examples: dict = tf.io.parse_single_example(
        stream,
        {'img': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)})
    return tf.image.decode_jpeg(examples['img'], 3), examples['label']

def pair_parser(raw_imgs, labels):
    # apply the same augmentation to all four images
    if is_augment:
        raw_imgs, _ = h.augment_img(raw_imgs, None)
    # normalize images
    if is_normlize:
        imgs: tf.Tensor = h.normlize_img(raw_imgs)
    else:
        imgs = tf.cast(raw_imgs, tf.float32)

    imgs.set_shape([4] + img_shape)
    labels.set_shape([4, ])
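    # imgs[0] and imgs[1] share a class (anchor, positive); imgs[2] comes
    # from another class (negative). Note y_true shape will be [batch, 3]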
    return (imgs[0], imgs[1], imgs[2]), (labels[:3])

batch_size = 1
# h.train_list : ['a.tfrecords','b.tfrecords','c.tfrecords',...]
ds = (tf.data.Dataset.from_tensor_slices(h.train_list)
        .interleave(lambda x: tf.data.TFRecordDataset(x)
                    .shuffle(100)
                    .repeat(), cycle_length=-1,
                    # block_length=2 is important: 2 consecutive records per class
                    block_length=2,
                    num_parallel_calls=-1)
        .map(parser, -1)
        .batch(4, True)
        .map(pair_parser, -1)
        .batch(batch_size, True))

iters = iter(ds)
for i in range(20):
    imgs, labels = next(iters)
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(imgs[0].numpy().astype('uint8')[0])
    axs[1].imshow(imgs[1].numpy().astype('uint8')[0])
    axs[2].imshow(imgs[2].numpy().astype('uint8')[0])
    plt.show()
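Why `block_length=2` matters can be seen with a plain-Python model of this pipeline: interleave hands out two consecutive records per class, `.batch(4, True)` groups them as [A, A, B, B], and `pair_parser` keeps the first three as (anchor, positive, negative). The helper names and toy records are hypothetical, and the model ignores shuffling and parallelism, treating the interleave as a deterministic round-robin over endlessly repeated files (as `.repeat()` does in the answer).

```python
import itertools
from typing import Iterator, List, Sequence, Tuple

Record = Tuple[str, int]  # (record name, class label)

def interleave(sources: Sequence[Sequence[Record]], block_length: int) -> Iterator[Record]:
    """Simplified model of Dataset.interleave over repeated sources: take
    block_length consecutive records from each source in turn."""
    iters = [itertools.cycle(s) for s in sources]  # like .repeat()
    while True:
        for it in iters:
            for _ in range(block_length):
                yield next(it)

def make_triplets(stream: Iterator[Record], count: int) -> List[Tuple[Record, Record, Record]]:
    """Take blocks of 4 (like .batch(4, True)) and keep the first three,
    the way pair_parser returns imgs[0], imgs[1], imgs[2]."""
    triplets = []
    for _ in range(count):
        block = [next(stream) for _ in range(4)]
        triplets.append((block[0], block[1], block[2]))
    return triplets

# Hypothetical toy "files", one per class.
files = [
    [("a0", 0), ("a1", 0), ("a2", 0)],
    [("b0", 1), ("b1", 1), ("b2", 1)],
    [("c0", 2), ("c1", 2), ("c2", 2)],
]
for anchor, positive, negative in make_triplets(interleave(files, block_length=2), 3):
    print(anchor, positive, negative)
# ('a0', 0) ('a1', 0) ('b0', 1)
# ('c0', 2) ('c1', 2) ('a2', 0)
# ('b2', 1) ('b0', 1) ('c2', 2)
```

With `block_length=1` the stream would alternate classes record by record, so `imgs[0]` and `imgs[1]` would never share a class. Note also that the fourth record of every batch of 4 is discarded.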

Upvotes: 4

Mous

Reputation: 83

OK, I figured it out. The Dataset is generated successfully and the randomness of the data seems decent. It is not an ideal solution for triplet loss, though, because the triplets are random rather than semi-hard.
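For context on the semi-hard remark: with a margin α, FaceNet-style triplet loss is max(d(a, p) - d(a, n) + α, 0) on squared embedding distances, and a negative is called semi-hard when it is farther from the anchor than the positive but still inside the margin. Randomly drawn triplets are often easy, i.e. they give zero loss. The sketch below classifies a single triplet of embedding vectors; the function names, the 0.2 margin and the convention at the equality boundaries are assumptions.

```python
from typing import Sequence

def sq_dist(u: Sequence[float], v: Sequence[float]) -> float:
    """Squared Euclidean distance between two embedding vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def triplet_loss(anchor, positive, negative, margin: float = 0.2) -> float:
    """max(d(a, p) - d(a, n) + margin, 0) on squared distances."""
    return max(sq_dist(anchor, positive) - sq_dist(anchor, negative) + margin, 0.0)

def triplet_kind(anchor, positive, negative, margin: float = 0.2) -> str:
    """hard:      d(a, n) <  d(a, p)
    semi-hard: d(a, p) <= d(a, n) < d(a, p) + margin
    easy:      d(a, n) >= d(a, p) + margin   (loss is zero)"""
    d_ap = sq_dist(anchor, positive)
    d_an = sq_dist(anchor, negative)
    if d_an < d_ap:
        return "hard"
    if d_an < d_ap + margin:
        return "semi-hard"
    return "easy"

a, p = (0.0, 0.0), (1.0, 0.0)           # d(a, p) = 1.0
print(triplet_kind(a, p, (0.5, 0.0)))   # hard      (d(a, n) = 0.25)
print(triplet_kind(a, p, (1.05, 0.0)))  # semi-hard (d(a, n) is about 1.1025)
print(triplet_kind(a, p, (2.0, 0.0)))   # easy      (d(a, n) = 4.0)
print(triplet_loss(a, p, (2.0, 0.0)))   # 0.0
```

In practice semi-hard negatives are usually mined from the embeddings inside each batch at training time (TensorFlow Addons, for example, provides `tfa.losses.TripletSemiHardLoss`), which is one reason batches containing at least two samples per class are useful.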

def input_fn(self, params):
    batch_size = params['batch_size']

    assert self.data_dir, 'data_dir is required'
    shuffle = self.is_training

    dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*'), self.dirs))

    def prefetch_dataset(filename): 
      dataset = tf.data.TFRecordDataset( 
          filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
      return dataset

    datasets = []
    for glob in dirs:
      dataset = tf.data.Dataset.list_files(glob)
      dataset = dataset.apply( 
        tf.contrib.data.parallel_interleave( 
            prefetch_dataset, 
            cycle_length=FLAGS.num_files_infeed, 
            sloppy=True)) # sloppy=True allows nondeterministic order for speed; set False if order is important
      dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
      datasets.append(dataset)

    def gen(x):
      return tf.data.Dataset.range(x,x+1).repeat(2)

    choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)

    dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map( # apply function to each element of the dataset in parallel
        self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)

    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)

    return dataset
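One possible refinement, not part of the original answer: since `choice` cycles through the classes in a fixed order, the same neighboring classes always share a batch. The hypothetical generator below reshuffles the class order every round while still emitting each index twice in a row. It could be wrapped with `tf.data.Dataset.from_generator` to produce the scalar `tf.int64` indices that `choose_from_datasets` expects.

```python
import random
from itertools import islice
from typing import Iterator

def shuffled_pair_schedule(num_classes: int, seed: int = 0) -> Iterator[int]:
    """Endless stream of class indices: every round visits each class once,
    in a freshly shuffled order, and emits each index twice in a row."""
    rng = random.Random(seed)
    order = list(range(num_classes))
    while True:
        rng.shuffle(order)
        for c in order:
            yield c
            yield c

# One round for 4 classes: each class appears once, as a consecutive pair.
print(list(islice(shuffled_pair_schedule(4, seed=1), 8)))
```

Either way, `batch_size` should be even: the pairs start at even positions in the choice stream, so with an odd batch size some pairs would be split across two batches.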

Upvotes: 1
