Reputation: 145
Imagine I have:
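For concreteness, two datasets along these lines (the exact definitions are implied by the batch sizes and example results below):

import tensorflow as tf

# dataset 1, to be read in batches of 2
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])

# dataset 2, to be read in batches of 1
ds2 = tf.data.Dataset.from_tensor_slices([4, 4])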
I want to take batches from both datasets and concatenate them so that I get batches of size 3, where the first two elements come from the first dataset and the third element comes from the second.
I also want to read out the final batches when one of the datasets is emptied before the other. In this instance, I would get [5, 5, 4], [5, 5, 4], [5] as my final result.
How can I do this? I've seen the answer here: Tensorflow how to generate unbalanced combined data sets
It is a good try, but it does not work if one of the datasets is emptied before the others, because tf.errors.OutOfRangeError is raised prematurely when you try to fetch elements from the dataset that is emptied first, and I do not get the final batch. Therefore I only get [5, 5, 4], [5, 5, 4].
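If I understand that answer correctly, it boils down to fetching one batch from each dataset with separate iterators and concatenating them in a single sess.run, roughly like this minimal sketch (using the example datasets above):

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4]).batch(1)

batch1 = ds1.make_one_shot_iterator().get_next()
batch2 = ds2.make_one_shot_iterator().get_next()
combined = tf.concat((batch1, batch2), 0)

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(combined))  # prints [5 5 4] twice
        except tf.errors.OutOfRangeError:
            # raised as soon as ds2 is exhausted, so the trailing [5] from ds1 is lost
            break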
I thought of using tf.contrib.data.choose_from_datasets:
import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)

# `choose_from_datasets` expects a dataset of 0-based tf.int64 indices
choice_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant([0, 1, 0, 1, 0], dtype=tf.int64))

ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)
This kind of works but is rather inelegant (there is an unbatch followed by a batch); also, I don't really have fine control over exactly what goes into a batch. For instance, if ds1 were [7] * 7 with batch size 2 and ds2 were [2, 2] with batch size 1, I would get [7, 7, 2], [7, 7, 2], [7, 7, 7]. But what if I actually want [7, 7, 2], [7, 7, 2], [7, 7], [7]? That is, keep the number of elements taken from each dataset fixed.
Is there another better solution?
Another idea I had was to try to use tf.data.Dataset.flat_map:
import functools
from functools import partial

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])

batch_sizes = [2, 1]

def concat_and_batch(*inputs):
    concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
    datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
    datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
    return concat(datasets)

dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(concat_and_batch)
           .batch(sum(batch_sizes)))
but it does not seem to work.
Upvotes: 6
Views: 7088
Reputation: 2019
Here is a solution that uses a "control input" to choose which batch to take; you decide on its value according to which dataset was consumed first, and that can be detected from the thrown exception.
To explain this solution, I will first present an attempt that does not work.
import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)

# this is a "control" placeholder. Its value determines whether to use `batch1`, `batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
    tf.equal(which_batch, 0),      # if `which_batch` == 0, use `batch12`
    lambda: batch12,
    lambda: tf.cond(tf.equal(which_batch, 1),  # elif `which_batch` == 1, use `batch1`
                    lambda: batch1,
                    lambda: batch2))           # else, use `batch2`

sess = tf.Session()

which = 0  # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch, feed_dict={which_batch: which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break
This solution does not work, since for any value of which_batch, the tf.cond() command will evaluate all the predecessors of its branches (see this answer). Therefore, even when which_batch has the value 1, batch2 will be evaluated and an OutOfRangeError will be thrown.
This problem can be fixed by moving the definitions of batch1, batch2 and batch12 into functions.
import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

def get_batch1():
    batch1 = iter1.get_next(name='batch1')
    return batch1

def get_batch2():
    batch2 = iter2.get_next(name='batch2')
    return batch2

def get_batch12():
    batch1 = iter1.get_next(name='batch1_')
    batch2 = iter2.get_next(name='batch2_')
    batch12 = tf.concat((batch1, batch2), 0)
    return batch12

# this is a "control" placeholder. Its value determines whether to use `batch1`, `batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
    tf.equal(which_batch, 0),      # if `which_batch` == 0, use `batch12`
    get_batch12,
    lambda: tf.cond(tf.equal(which_batch, 1),  # elif `which_batch` == 1, use `batch1`
                    get_batch1,
                    get_batch2))               # else, use `batch2`

sess = tf.Session()

which = 0  # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch, feed_dict={which_batch: which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break
However, this doesn't work either. The reason is that at the step when batch12 is being formed and the dataset ds2 has already been consumed, we have taken a batch from dataset ds1 and "dropped" it without using it.
We need a mechanism to make sure that we don't "drop" any batch in the event that the other dataset is consumed. We can do this by defining a variable which will be assigned the current batch of ds1, but only immediately before trying to obtain batch12. Otherwise, this variable will keep its previous value. Then, if batch12 fails because ds1 is consumed, this assignment fails too, batch2 is not dropped, and we can use it next time. Otherwise, if batch12 fails because ds2 is consumed, we have a backup of batch1 in the variable that we have defined, and after using this backup we can proceed with taking batch1.
import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)

def get_batch12():
    batch1 = iter1.get_next(name='batch1')
    # form the combined batch `batch12` only after backing up `batch1`
    with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
        batch2 = iter2.get_next(name='batch2')
        batch12 = tf.concat((batch1, batch2), 0)
    return batch12

def get_batch1():
    batch1 = iter1.get_next()
    return batch1

def get_batch2():
    batch2 = iter2.get_next()
    return batch2

# this is a "control" input (fed through `feed_dict`). Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.Variable(0, trainable=False)

batch = tf.cond(
    tf.equal(which_batch, 0),      # if `which_batch` == 0, use `batch12`
    get_batch12,
    lambda: tf.cond(tf.equal(which_batch, 1),      # elif `which_batch` == 1, use `batch1_backup`
                    lambda: batch1_backup,
                    lambda: tf.cond(tf.equal(which_batch, 2),  # elif `which_batch` == 2, use `batch1`
                                    get_batch1,
                                    get_batch2)))              # else, use `batch2`

sess = tf.Session()
sess.run(tf.global_variables_initializer())

which = 0  # this value will be fed into the control input `which_batch`
while True:
    try:
        print(sess.run(batch, feed_dict={which_batch: which}))
        # if we just used `batch1_backup`, proceed with `batch1`
        if which == 1:
            which = 2
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 3
        else:
            break
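With the example datasets above, this should print [5 5 4], [5 5 4] and finally the leftover [5] (served from the backup variable) before terminating, which is the desired behaviour.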
Upvotes: 1
Reputation: 2019
If you don't mind running a session during the construction of the new dataset, you can do the following:
import tensorflow as tf
import numpy as np

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next()
batch2 = iter2.get_next()

sess = tf.Session()

# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
    while True:
        try:
            b1 = sess.run(batch1)
        except tf.errors.OutOfRangeError:
            b1 = None
        try:
            b2 = sess.run(batch2)
        except tf.errors.OutOfRangeError:
            b2 = None
        if (b1 is None) and (b2 is None):
            break
        elif b1 is None:
            yield b2
        elif b2 is None:
            yield b1
        else:
            yield np.concatenate((b1, b2))

# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch, tf.int32)
Notice that the above session can be hidden/encapsulated if you wish (for example, inside a function). The resulting dataset can then be consumed as usual:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()

sess2 = tf.Session()

while True:
    print(sess2.run(batch))
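Going one step further, the helper session and generator can live entirely inside a factory function, roughly along these lines (a sketch; the function and variable names are my own):

import tensorflow as tf
import numpy as np

def make_combined_dataset(ds1_batched, ds2_batched):
    # the helper session and iterators live only inside this function
    batch1 = ds1_batched.make_one_shot_iterator().get_next()
    batch2 = ds2_batched.make_one_shot_iterator().get_next()
    helper_sess = tf.Session()

    def gen():
        while True:
            try:
                b1 = helper_sess.run(batch1)
            except tf.errors.OutOfRangeError:
                b1 = None
            try:
                b2 = helper_sess.run(batch2)
            except tf.errors.OutOfRangeError:
                b2 = None
            if b1 is None and b2 is None:
                break
            elif b1 is None:
                yield b2
            elif b2 is None:
                yield b1
            else:
                yield np.concatenate((b1, b2))

    return tf.data.Dataset.from_generator(gen, tf.int32)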
Upvotes: 3
Reputation: 2019
Here is a solution. It has some problems, but I hope it satisfies your needs.
The idea is as follows: you batch each of the two datasets, zip them together, and use a map function to combine each zipped tuple into one batch (so far, this is similar to what is suggested in this and this answer).
The problem, as you noticed, is that the zipping only works well for two datasets of the same length. Otherwise, one dataset is consumed before the other, and the remaining unconsumed elements are not used.
My (kind of hacky) solution is to concatenate an infinite dummy dataset to both of the datasets. This dummy dataset consists only of values which you know will never appear in your real datasets. This eliminates the zipping problem. However, you then need to get rid of all the dummy elements, which can easily be done with filtering and mapping.
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1
# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat()
# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))
# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))
# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))
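With the example datasets, this should yield [5 5 4], [5 5 4] and then [5] once the dummy elements are masked out.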
There are two major problems with this solution:
(1) We need a value for UNUSED_VALUE which we are certain will not appear in the datasets. I suspect that there is a workaround for this, maybe by making the dummy dataset consist of empty tensors (instead of tensors with a constant value), but I couldn't figure out how to do this yet.
(2) Although this dataset has a finite number of elements, the following loop will never terminate:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()

sess = tf.Session()

while True:
    print(sess.run(batch))
The reason is that the iterator keeps filtering out dummy examples without knowing when to stop. This can be addressed by changing the repeat() call above to repeat(n), where n is a number that you know is larger than the difference between the lengths of the two datasets.
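A sketch of that repeat(n) variant, putting everything together (here I arbitrarily pad with 10 dummy elements, which is more than the length difference between the two example datasets):

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

UNUSED_VALUE = -1

# a *finite* dummy tail; 10 exceeds the length difference between the two datasets
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat(10)

ds1 = ds1.concatenate(dummy_ds).batch(2)
ds2 = ds2.concatenate(dummy_ds).batch(1)

ds = tf.data.Dataset.zip((ds1, ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1, x2), 0))
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x, UNUSED_VALUE)))
ds = ds.map(lambda x: tf.boolean_mask(x, tf.not_equal(x, UNUSED_VALUE)))

iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()

# this loop now terminates once the padded (finite) datasets are exhausted
while True:
    try:
        print(sess.run(batch))
    except tf.errors.OutOfRangeError:
        break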
Upvotes: 2