cbournho

Reputation: 145

Tensorflow, how to concatenate multiple datasets with varying batch sizes

Imagine I have:

- dataset 1 with data [5, 5, 5, 5, 5], which I read with batch size 2
- dataset 2 with data [4, 4], which I read with batch size 1

I want to take batches from both datasets and concatenate them so that I get batches of size 3, where 2 elements come from dataset 1 and 1 element comes from dataset 2.

I also want to read the final batch if some datasets get emptied first. In this instance, I would get [5, 5, 4], [5, 5, 4], [5] as my final result.
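In code, the setup is roughly the following (a sketch with the same toy values, in the TF 1.x style used in the rest of this question):

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])  # to be read with batch size 2
ds2 = tf.data.Dataset.from_tensor_slices([4, 4])           # to be read with batch size 1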

How can I do this? I've seen the answer here: Tensorflow how to generate unbalanced combined data sets

It is a good attempt, but it does not work if one of the datasets gets emptied before the others: tf.errors.OutOfRangeError is raised prematurely as soon as you try to fetch elements from the dataset that empties first, so I never see the final batch. Therefore I only get [5, 5, 4], [5, 5, 4].
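To make the failure mode concrete, here is a minimal sketch of the combine-one-batch-from-each-dataset pattern (not the linked answer's exact code), continuing the setup above; it stops as soon as ds2 is exhausted, so the leftover [5] from ds1 is never produced:

batch1 = ds1.batch(2).make_one_shot_iterator().get_next()
batch2 = ds2.batch(1).make_one_shot_iterator().get_next()
combined = tf.concat((batch1, batch2), 0)

sess = tf.Session()
while True:
    try:
        print(sess.run(combined))  # prints [5 5 4], then [5 5 4]
    except tf.errors.OutOfRangeError:
        break  # raised when ds2 empties; the final [5] never appears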

I thought of using tf.contrib.data.choose_from_datasets:

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4]).batch(1)
choice_dataset = tf.data.Dataset.from_tensor_slices(tf.constant([0, 1, 0, 1, 0], dtype=tf.int64))
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)

This kind of works, but it is rather inelegant (there is an unbatch and then a batch); also, I don't really have great control over exactly what goes into a batch. (For instance, if ds1 were [7] * 7 with batch size 2 and ds2 were [1, 1] with batch size 1, I would get [7, 7, 1], [7, 7, 1], [7, 7, 7]. But what if I actually want to have [7, 7, 1], [7, 7, 1], [7, 7], [7]? I.e., keep the number of elements from each dataset fixed.)

Is there another better solution?

Another idea I had was to try to use tf.data.Dataset.flat_map:

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4])
batch_sizes = [2, 1]
import functools

def _concat_and_batch(*inputs):
  concat = functools.partial(functools.reduce, lambda x, y: x.concatenate(y))
  datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
  datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
  return concat(datasets)
dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(_concat_and_batch)
           .batch(sum(batch_sizes)))

but it does not seem to work.
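For what it's worth, the idea can be made to run by batching the datasets before zipping and letting flat_map emit the concatenated batch (a sketch, not a fix of the exact code above):

ds = tf.data.Dataset.zip((ds1.batch(2), ds2.batch(1)))
ds = ds.flat_map(lambda x, y: tf.data.Dataset.from_tensors(tf.concat((x, y), 0)))

but this has the same limitation as the zip-based approach discussed earlier: it stops as soon as ds2 is exhausted, so the final [5] is still lost.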

Upvotes: 6

Views: 7088

Answers (3)

Lior

Reputation: 2019

Here is a solution that requires you to use a "control input" to choose which batch to use; you decide on this according to which dataset was consumed first, which can be detected using the thrown exception.

To explain this solution, I will first present an attempt that does not work.

Attempted solution #1

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)

# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       lambda:batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       lambda:batch1,
        lambda:batch2)) # else, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

This solution does not work, since for any value of which_batch, the tf.cond() command will evaluate all the predecessors of its branches (see this answer). Therefore, even when which_batch has the value 1, batch2 will be evaluated and an OutOfRangeError will be thrown.
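To see this behaviour in isolation, here is a minimal sketch (separate from the solution attempts): the iterator's get_next op is created outside the branch functions, so its element is consumed even when the branch that uses it is not taken.

ds = tf.data.Dataset.from_tensor_slices([1])      # a single element
nxt = ds.make_one_shot_iterator().get_next()      # created OUTSIDE the tf.cond branches

flag = tf.placeholder(tf.bool)
out = tf.cond(flag, lambda: nxt, lambda: tf.constant(-1))

sess = tf.Session()
print(sess.run(out, feed_dict={flag: False}))  # prints -1, but `nxt` was still evaluated
print(sess.run(out, feed_dict={flag: False}))  # raises OutOfRangeError, although the branch using `nxt` is never taken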

Attempted solution #2

This problem can be fixed by moving the definitions of batch1, batch2 and batch12 into functions.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

def get_batch1():
    batch1 = iter1.get_next(name='batch1')
    return batch1

def get_batch2():
    batch2 = iter2.get_next(name='batch2')
    return batch2

def get_batch12():
    batch1 = iter1.get_next(name='batch1_')
    batch2 = iter2.get_next(name='batch2_')
    batch12 = tf.concat((batch1, batch2), 0)
    return batch12

# this is a "control" placeholder. Its value determines whether to use `batch1`, `batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       get_batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       get_batch1,
        get_batch2)) # elif `which_batch`==2, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

However, this doesn't work either. The reason is that at the step where batch12 is being formed and ds2 turns out to be consumed, we have already taken a batch from dataset ds1, and that batch gets "dropped" without ever being used.

Solution

We need a mechanism to make sure that we don't "drop" a batch when the other dataset is consumed. We can do this by defining a variable to which the current batch of ds1 is assigned, but only immediately before trying to obtain batch12; otherwise, the variable keeps its previous value. Now, if batch12 fails because ds1 was consumed, then this assignment fails before batch2 is ever fetched, so no batch of ds2 is dropped and we can fetch batch2 next time. If, instead, batch12 fails because ds2 was consumed, then we have a backup of batch1 in the variable we defined, and after using this backup we can proceed with taking batch1.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)

def get_batch12():
    batch1 = iter1.get_next(name='batch1')

    # form the combined batch `batch12` only after backing-up `batch1`
    with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
        batch2 = iter2.get_next(name='batch2')
        batch12 = tf.concat((batch1, batch2), 0)
    return batch12

def get_batch1():
    batch1 = iter1.get_next()
    return batch1

def get_batch2():
    batch2 = iter2.get_next()
    return batch2

# this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       get_batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
                       lambda:batch1_backup,
        lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
                       get_batch1,
       get_batch2))) # else, use `batch2`

sess = tf.Session()
sess.run(tf.global_variables_initializer())

which = 0  # this value will be fed into the control placeholder
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))

        # if just used `batch1_backup`, proceed with `batch1`
        if which==1:
            which = 2
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 3
        else:
            break
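# With the toy datasets above, this loop should print:
#   [5 5 4]
#   [5 5 4]
#   [5]   <- the leftover batch of ds1, recovered from `batch1_backup`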

Upvotes: 1

Lior

Reputation: 2019

If you don't mind running a session during the construction of the new dataset, you can do the following:

import tensorflow as tf
import numpy as np

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next()
batch2 = iter2.get_next()

sess = tf.Session()

# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
    while True:
        try:
            b1 = sess.run(batch1)
        except tf.errors.OutOfRangeError:
            b1 = None
        try:
            b2 = sess.run(batch2)
        except tf.errors.OutOfRangeError:
            b2 = None
        if (b1 is None) and (b2 is None):
            break
        elif b1 is None:
            yield b2
        elif b2 is None:
            yield b1
        else:
            yield np.concatenate((b1,b2))

# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)

Notice that the above session can be hidden/encapsulated if you wish (for example, inside a function). The resulting dataset can then be consumed like any other dataset, for example:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()

sess2 = tf.Session()

while True:
    try:
        print(sess2.run(batch))
    except tf.errors.OutOfRangeError:
        break
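If you want to actually hide the helper session, one possible way (a sketch, not from the original answer; the function name make_combined_dataset is made up, and it uses the same imports as above) is to build the iterators, the helper session and the generator inside a single function, so that only the resulting dataset is exposed:

def make_combined_dataset(ds1, ds2):
    # `ds1` and `ds2` are already-batched datasets; the helper session lives only inside this function
    batch1 = ds1.make_one_shot_iterator().get_next()
    batch2 = ds2.make_one_shot_iterator().get_next()
    helper_sess = tf.Session()

    def gen():
        while True:
            try:
                b1 = helper_sess.run(batch1)
            except tf.errors.OutOfRangeError:
                b1 = None
            try:
                b2 = helper_sess.run(batch2)
            except tf.errors.OutOfRangeError:
                b2 = None
            if (b1 is None) and (b2 is None):
                break
            elif b1 is None:
                yield b2
            elif b2 is None:
                yield b1
            else:
                yield np.concatenate((b1, b2))

    return tf.data.Dataset.from_generator(gen, tf.int32)

ds = make_combined_dataset(tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2),
                           tf.data.Dataset.from_tensor_slices([4, 4]).batch(1))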

Upvotes: 3

Lior

Reputation: 2019

Here is a solution. It has some problems, but I hope it satisfies your needs.

The idea is as follows: you batch each of the two datasets, zip them together, and apply a map function that combines each zipped tuple into one batch (so far, this is similar to what is suggested in this answer and this one).

The problem, as you noticed, is that the zipping only works well for two datasets that are of the same length. Otherwise, one dataset is consumed before the other, and the remaining unconsumed elements are not used.

My (kind of hacky) solution to this is to concatenate an infinite dummy dataset to both of the datasets. This dummy dataset consists only of values which you know will never appear in your real datasets. This eliminates the problem with zipping. However, you then need to get rid of all the dummy elements, which can easily be done with a filter and a map.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1 

# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat() 

# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))

# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))

# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))

There are two major problems with this solution:

(1) We need a value for UNUSED_VALUE that we are certain will not appear in the datasets. I suspect there is a workaround for this, maybe by making the dummy dataset consist of empty tensors (instead of tensors with a constant value), but I couldn't figure out how to do this yet.

(2) Although this dataset has a finite number of elements, the following loop will never terminate:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()
while True:
    print(sess.run(batch))

The reason is that the iterator keeps filtering out dummy examples, without knowing when to stop. This can be addressed by changing the repeat() call above to repeat(n), where n is a number that you know is larger than the difference between the lengths of the two datasets.
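With the toy datasets above the difference in length is 3 elements, so something like the following (a sketch) makes the pipeline finite; the zipped dataset then ends after four batches, the all-dummy batch is filtered out, and the loop terminates with an ordinary OutOfRangeError that you can catch as usual:

# a finite dummy tail instead of an infinite one:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat(3)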

Upvotes: 2
