dim_tz

Reputation: 1531

TensorFlow - shuffling at "batch-level" instead of "example-level"

I have a problem that I will try to explain with an example for easier understanding.

I want to classify oranges (O) and apples (A). For technical/legacy reasons (a component in the network), each batch must contain either only O examples or only A examples. Traditional example-level shuffling is therefore not an option, since I cannot afford a batch that mixes O and A examples. Some kind of shuffling is still desirable, though, since it is common practice when training deep networks.

These are the steps that I take:

Upvotes: 0

Views: 133

Answers (1)

RobR

Reputation: 2190

If you use the Dataset API, it's fairly straightforward. Zip the O and A batches, then apply a random selection function with Dataset.map():

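import tensorflow as tf  # written against the TensorFlow 1.x graph API

# Two toy datasets standing in for the O and A examples; each yields
# homogeneous batches of five elements.
ds0 = tf.data.Dataset.from_tensor_slices([0])
ds0 = ds0.repeat()
ds0 = ds0.batch(5)
ds1 = tf.data.Dataset.from_tensor_slices([1])
ds1 = ds1.repeat()
ds1 = ds1.batch(5)

def rand_select(batch0, batch1):
    # Pick one of the two batches with equal probability.
    rval = tf.random_uniform([])
    return tf.cond(rval < 0.5, lambda: batch0, lambda: batch1)

# zip() is a static method, so call it on the class rather than an instance.
# Each step draws one batch from both datasets; the unselected one is dropped.
dataset = tf.data.Dataset.zip((ds0, ds1)).map(rand_select)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        print(sess.run(next_batch))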

> [0 0 0 0 0]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [0 0 0 0 0]
  [0 0 0 0 0]
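
The output above is one random draw, so the order will differ from run to run. For comparison, here is a framework-independent sketch of the same idea in plain Python. It is not the Dataset approach shown above: instead of picking a class at random at every step, it builds all single-class batches up front and then shuffles the order of the batches, so every example is used exactly once per pass. The function name batch_level_shuffle and its parameters are hypothetical and chosen only for illustration.

```python
import random
from collections import defaultdict


def batch_level_shuffle(examples, labels, batch_size, seed=None):
    """Return a shuffled list of (label, batch) pairs with single-class batches.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    rng = random.Random(seed)

    # Group the examples by class so that no batch can mix classes.
    by_label = defaultdict(list)
    for example, label in zip(examples, labels):
        by_label[label].append(example)

    batches = []
    for label, items in by_label.items():
        # Example-level shuffling is still safe *within* a class.
        rng.shuffle(items)
        # Cut each class into consecutive chunks; the last one may be smaller.
        for start in range(0, len(items), batch_size):
            batches.append((label, items[start:start + batch_size]))

    # Batch-level shuffling: randomize the order in which batches are served.
    rng.shuffle(batches)
    return batches


if __name__ == "__main__":
    toy_examples = list(range(12))
    toy_labels = ["O"] * 6 + ["A"] * 6
    for label, batch in batch_level_shuffle(toy_examples, toy_labels, batch_size=3, seed=0):
        print(label, batch)
```

Calling the function again at the start of each epoch (with a different seed, or with none) gives a fresh batch order while keeping every batch homogeneous.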

Upvotes: 2
