Reputation: 702
I'd like to train a classifier on the ImageNet dataset (1000 classes, each with around 1300 images). For some reason, I need each batch to contain 64 images from the same class, and consecutive batches to come from different classes. Is this possible (and efficient) with the latest TensorFlow?
tf.contrib.data.sample_from_datasets in TF 1.9 allows sampling from a list of tf.data.Dataset objects, with weights indicating the probabilities. I wonder if the following idea makes sense:
- Pass a tf.data.Dataset.from_generator object as the weights. The object samples from a Categorical distribution such that each sample looks like [0,...,0,1,0,...,0], with 999 0s and one 1;
- Create 1000 tf.data.Dataset objects, each linked to the tfrecord file of one class.
I thought that, in this way, at each iteration sample_from_datasets would first sample a sparse weight vector that indicates which tf.data.Dataset to sample from, then sample from that class.
Is it correct? Are there any other efficient ways?
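To make the intended behaviour concrete, here is a minimal plain-Python sketch of the selection logic only. It is not tf.data code, and the names class_batches, single_class_batches and per_class are made up for illustration. It assumes at least two classes and that each class's samples fit in a list.

```python
import random
from itertools import cycle, islice


def class_batches(samples, batch_size):
    """Yield fixed-size batches forever from one class's samples."""
    stream = cycle(samples)
    while True:
        yield list(islice(stream, batch_size))


def single_class_batches(per_class, batch_size, seed=0):
    """Pick a class at random each step, never the same class twice in a row."""
    rng = random.Random(seed)
    streams = {label: class_batches(items, batch_size)
               for label, items in per_class.items()}
    labels = list(streams)
    previous = None
    while True:
        # Same effect as drawing a one-hot weight vector over the classes,
        # with the previously chosen class excluded (needs >= 2 classes).
        label = rng.choice([c for c in labels if c != previous])
        previous = label
        yield label, next(streams[label])


# Toy data: 5 hypothetical classes with 10 samples each, encoded as (label, index).
per_class = {c: [(c, i) for i in range(10)] for c in range(5)}
batches = list(islice(single_class_batches(per_class, batch_size=4), 6))
for label, batch in batches:
    print(label, batch)
```

The sketch does not shuffle within a class. In tf.data, tf.data.experimental.choose_from_datasets may be a closer fit for the selector step than weighted sampling, though I have not checked that against TF 1.9.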
Update
As kindly suggested by P-Gn, one way to sample data from one class would be:
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(some_parser_fun)  # parse one datum from tfrecord
dataset = dataset.shuffle(buffer_size)
if sample_same_class:
    group_fun = tf.contrib.data.group_by_window(
        key_func=lambda data_x, data_y: data_y,
        reduce_func=lambda key, d: d.batch(batch_size),
        window_size=batch_size)
    dataset = dataset.apply(group_fun)
else:
    dataset = dataset.batch(batch_size)
dataset = dataset.repeat()
data_batch = dataset.make_one_shot_iterator().get_next()
A follow-up question can be found at "How to sample batch from a specific class?"
Upvotes: 4
Views: 1985
Reputation: 683
P-Gn's solution of creating a separate dataset for each class is probably optimal. However, separate datasets can be avoided by bucketing on the label with tf.data.experimental.bucket_by_sequence_length, as follows:
import tensorflow as tf

# Init some dataset
num_classes = 10
label = tf.range(num_classes, dtype=tf.int32)
features = tf.cast(label * 10, dtype=tf.float32) + tf.random.uniform(shape=[tf.shape(label)[0]], maxval=0.01)
dataset = tf.data.Dataset.from_tensor_slices({'label': label, 'features': features})
dataset = dataset.repeat()
# Split to buckets by label: the label is used as the "length" of each element
batch_size = 4
dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(
    element_length_func=lambda s: s['label'],
    bucket_boundaries=list(range(1, num_classes)),
    bucket_batch_sizes=[batch_size] * num_classes,
))
# Show result
iterator = dataset.as_numpy_iterator()
for i in range(5):
    print(next(iterator))
# {'label': array([0, 0, 0, 0], dtype=int32), 'features': array([0.00370963, 0.00370963, 0.00370963, 0.00370963], dtype=float32)}
# {'label': array([1, 1, 1, 1], dtype=int32), 'features': array([10.009371, 10.009371, 10.009371, 10.009371], dtype=float32)}
# {'label': array([2, 2, 2, 2], dtype=int32), 'features': array([20.001854, 20.001854, 20.001854, 20.001854], dtype=float32)}
# {'label': array([3, 3, 3, 3], dtype=int32), 'features': array([30.005934, 30.005934, 30.005934, 30.005934], dtype=float32)}
# {'label': array([4, 4, 4, 4], dtype=int32), 'features': array([40.001686, 40.001686, 40.001686, 40.001686], dtype=float32)}
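Why the boundaries list(range(1, num_classes)) yields one bucket per label may not be obvious. Below is a small plain-Python sketch of the bucket lookup, assuming (as I understand the implementation) that each bucket is a half-open interval [low, high) over the "length". The helper bucket_index is made up for illustration and is not part of TensorFlow.

```python
from bisect import bisect_right

num_classes = 10
bucket_boundaries = list(range(1, num_classes))  # [1, 2, ..., 9]


def bucket_index(length, boundaries):
    """Index of the half-open bucket [low, high) that contains `length`."""
    return bisect_right(boundaries, length)


# Each integer label k falls in [k, k + 1), so it gets a bucket of its own.
buckets = [bucket_index(label, bucket_boundaries) for label in range(num_classes)]
print(buckets)
```

With 9 boundaries there are 10 buckets, which is why bucket_batch_sizes needs num_classes entries.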
Upvotes: -1
Reputation: 24591
I don't think your solution could work, if I understand it correctly, because sample_from_datasets expects a list of values for its weights, not a Tensor.
However, if you don't mind having 1000 Datasets as in your proposed solution, then I would suggest to simply
- create one Dataset per class,
- batch each of these datasets, so that each batch has samples from a single class,
- zip all of them into one big Dataset of batches,
- shuffle this Dataset; the shuffling will occur on the batches, not on the samples, so it won't change the fact that batches are single class.

A more sophisticated way is to rely on tf.contrib.data.group_by_window. Let me illustrate that with a synthetic example.
import numpy as np
import tensorflow as tf

def gen():
    # Endless stream of (value, label) pairs with labels in 0..9
    while True:
        x = np.random.normal()
        label = np.random.randint(10)
        yield x, label

batch_size = 4
batch = (tf.data.Dataset
    .from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
    # Group elements by label and emit a batch once batch_size of them share a label
    .apply(tf.contrib.data.group_by_window(
        key_func=lambda x, label: label,
        reduce_func=lambda key, d: d.batch(batch_size),
        window_size=batch_size))
    .make_one_shot_iterator()
    .get_next())
sess = tf.InteractiveSession()
sess.run(batch)
# (array([ 0.04058843, 0.2843775 , -1.8626076 , 1.1154234 ], dtype=float32),
# array([6, 6, 6, 6], dtype=int64))
sess.run(batch)
# (array([ 1.3600663, 0.5935658, -0.6740045, 1.174328 ], dtype=float32),
# array([3, 3, 3, 3], dtype=int64))
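For intuition about what the grouping does, here is a rough plain-Python imitation of the windowing behaviour. It is a sketch, not the TensorFlow implementation: the function below only borrows the name group_by_window, and unlike the real transformation it silently drops incomplete windows left over at the end of the input.

```python
import random
from collections import defaultdict


def group_by_window(elements, key_func, window_size):
    """Buffer elements per key; emit a window once a key has window_size items."""
    pending = defaultdict(list)
    for element in elements:
        key = key_func(element)
        pending[key].append(element)
        if len(pending[key]) == window_size:
            yield pending.pop(key)


rng = random.Random(0)
# Synthetic (value, label) pairs, mirroring the generator above.
stream = [(rng.gauss(0.0, 1.0), rng.randrange(10)) for _ in range(200)]
batches = list(group_by_window(stream, key_func=lambda e: e[1], window_size=4))
for batch in batches[:3]:
    print([label for _, label in batch])
```

Note the cost this makes visible: one buffer is kept per key, so with 1000 classes and a batch size of 64 up to 1000 * 63 elements can be held in memory before any batch is emitted. Nothing here prevents two consecutive batches from having the same label either; it is only unlikely when there are many classes.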
Upvotes: 3