Reputation: 83
I am trying to build a Dataset that provides batches of TFRecords where, within one batch, there are 2 random records from one class and the rest come from other random classes.
OR
A Dataset of batches where there are 2 random records from each class that fits into the batch.
I tried to do this with tf.data.Dataset.from_generator
and with tf.data.experimental.choose_from_datasets
but with no success. Do you have an idea of how to do this?
EDIT: Today I think I implemented the second variant. Here is the code I was testing it on:
import tensorflow as tf

def input_fn():
    # three toy "classes": values 0-9, 20-29 and 60-69
    partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
    partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
    partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
    l = [partial1, partial2, partial3]

    def gen(x):
        # emit each dataset index twice in a row
        return tf.data.Dataset.range(x, x + 1).repeat(2)

    # choice sequence: 0, 0, 1, 1, 2, 2, repeated
    dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)
    choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
    return choice
which, when evaluated, returns
[ 0 2 21 22]
[60 61 1 4]
[20 23 62 63]
[ 3 5 24 25]
[64 66 6 7]
[26 27 65 68]
[ 8 0 28 29]
[67 69 9 2]
[20 22 60 62]
[ 3 1 23 24]
[63 61 4 6]
[25 26 65 64]
[ 7 5 27 28]
[67 66 9 8]
[21 20 69 68]
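For reference, the pipeline yields exactly 15 batches (60 choice indices / batch size 4), so the output above can be reproduced with a plain loop; a minimal sketch, assuming TF 2.x eager execution (under TF 1.x you would use a one-shot iterator and sess.run instead):

for batch in input_fn():
    print(batch.numpy())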
Upvotes: 4
Views: 2725
Reputation: 41
In TF 2.0 you can now use dataset.interleave to read TFRecords from different classes, and dataset.batch to make triplet pairs:
import tensorflow as tf
import matplotlib.pyplot as plt

# FcaeRecHelper is a helper class from the author's own project
h = FcaeRecHelper('data/ms1m_img_ann.npy', [112, 112], 128, use_softmax=False)
len(h.train_list)
img_shape = list(h.in_hw) + [3]
is_augment = True
is_normlize = False

def parser(stream: bytes):
    # parse one serialized tfrecord example
    examples: dict = tf.io.parse_single_example(
        stream,
        {'img': tf.io.FixedLenFeature([], tf.string),
         'label': tf.io.FixedLenFeature([], tf.int64)})
    return tf.image.decode_jpeg(examples['img'], 3), examples['label']

def pair_parser(raw_imgs, labels):
    # apply the same augmentation to all imgs
    if is_augment:
        raw_imgs, _ = h.augment_img(raw_imgs, None)
    # normalize image
    if is_normlize:
        imgs: tf.Tensor = h.normlize_img(raw_imgs)
    else:
        imgs = tf.cast(raw_imgs, tf.float32)
    imgs.set_shape([4] + img_shape)
    labels.set_shape([4, ])
    # NOTE y_true shape will be [batch, 3]
    return (imgs[0], imgs[1], imgs[2]), (labels[:3])

batch_size = 1
# h.train_list : ['a.tfrecords', 'b.tfrecords', 'c.tfrecords', ...]
ds = (tf.data.Dataset.from_tensor_slices(h.train_list)
      .interleave(lambda x: tf.data.TFRecordDataset(x)
                  .shuffle(100)
                  .repeat(),
                  cycle_length=-1,
                  # block_length=2 is important: it keeps 2 consecutive
                  # records of the same class next to each other
                  block_length=2,
                  num_parallel_calls=-1)
      .map(parser, -1)
      .batch(4, True)
      .map(pair_parser, -1)
      .batch(batch_size, True))

iters = iter(ds)
for i in range(20):
    imgs, labels = next(iters)
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(imgs[0].numpy().astype('uint8')[0])
    axs[1].imshow(imgs[1].numpy().astype('uint8')[0])
    axs[2].imshow(imgs[2].numpy().astype('uint8')[0])
    plt.show()
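To see why block_length=2 matters, here is a minimal toy sketch (not part of the original answer; made-up range datasets stand in for per-class TFRecord files): interleave pulls 2 consecutive elements from each input dataset before cycling to the next, so records of the same class stay adjacent.

import tensorflow as tf

files = tf.data.Dataset.range(3)  # stand-ins for 3 per-class tfrecord files
ds = files.interleave(
    lambda x: tf.data.Dataset.range(x * 10, x * 10 + 4),
    cycle_length=3,
    block_length=2)
print([int(e) for e in ds])
# [0, 1, 10, 11, 20, 21, 2, 3, 12, 13, 22, 23]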
Upvotes: 4
Reputation: 83
OK, I figured it out. The Dataset is generated successfully and the randomness of the data seems decent. It's not an ideal solution for triplet loss, as the triplets are random and not semi-hard.
def input_fn(self, params):
    batch_size = params['batch_size']
    assert self.data_dir, 'data_dir is required'
    shuffle = self.is_training
    dirs = list(map(lambda x: os.path.join(
        x, 'train-*' if self.is_training else 'validation-*'), self.dirs))

    def prefetch_dataset(filename):
        dataset = tf.data.TFRecordDataset(
            filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
        return dataset

    # one dataset per class directory
    datasets = []
    for glob in dirs:
        dataset = tf.data.Dataset.list_files(glob)
        dataset = dataset.apply(
            tf.contrib.data.parallel_interleave(
                prefetch_dataset,
                cycle_length=FLAGS.num_files_infeed,
                sloppy=True))  # use sloppy=False if order is important
        dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
        datasets.append(dataset)

    def gen(x):
        # emit each dataset index twice so that choose_from_datasets
        # takes 2 consecutive records per class
        return tf.data.Dataset.range(x, x + 1).repeat(2)

    choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)
    dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map(
        # apply the parser to each element of the dataset in parallel
        self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)
    return dataset
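You can check the choice pattern in isolation; a minimal sketch (assuming TF 2.x eager mode and 3 stand-in datasets) showing that the choice dataset emits each index twice per cycle, which is what puts 2 consecutive records of the same class into every batch of 4:

import tensorflow as tf

def gen(x):
    return tf.data.Dataset.range(x, x + 1).repeat(2)

choice = tf.data.Dataset.range(3).repeat().flat_map(gen)
print([int(i) for i in choice.take(8)])  # [0, 0, 1, 1, 2, 2, 0, 0]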
Upvotes: 1