Reputation: 13
I want to concatenate 3 or more datasets in TensorFlow. To concatenate 2 datasets, I can do this:
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)
However, this only concatenates two datasets at a time. I would like to do something like this:
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1,dataset2,dataset3)
Is there any way to do this?
Upvotes: 1
Views: 232
Reputation: 270
import tensorflow as tf
dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
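def func(*datasets):
    # Merge the zipped elements key by key. This assumes every element is a
    # dict of tensors, which is not the case for the scalar range datasets above.
    out = {}
    for dataset in datasets:
        for key in dataset:
            if key in out:
                _value = out[key]
                # Join features that share a key along the last axis.
                out[key] = tf.concat([_value, dataset[key]], axis=-1)
            else:
                out[key] = dataset[key]
    return out
# zip pairs the i-th elements of each dataset; func then merges each tuple.
tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)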
Upvotes: 0
Reputation: 10474
In this specific example you could just do
concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
Note that you have to assign the result of concatenate
to a new variable! It doesn't operate in-place.
Of course, chaining the calls by hand doesn't scale well if you have many datasets, but a loop should work:
datasets = [dataset1, dataset2, dataset3] # can be more than 3 of course
concat_dataset = datasets[0]
for dset in datasets[1:]:
    concat_dataset = concat_dataset.concatenate(dset)
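If you want something closer to the concatenate(dataset1, dataset2, dataset3) call from the question, the same loop can be wrapped in a small helper. This is only a sketch: concatenate_all is a made-up name, not part of TensorFlow, and it assumes all datasets have compatible element types and shapes.

```python
from functools import reduce

import tensorflow as tf


def concatenate_all(*datasets):
    """Chain any number of datasets in order (hypothetical helper, not a TensorFlow API)."""
    if not datasets:
        raise ValueError("need at least one dataset")
    # reduce applies concatenate pairwise: ((d1 + d2) + d3) + ...
    return reduce(lambda acc, dset: acc.concatenate(dset), datasets)


dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)

concat_dataset = concatenate_all(dataset1, dataset2, dataset3)
print([int(x) for x in concat_dataset.as_numpy_iterator()])
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

The input datasets are left untouched, since every concatenate call returns a new dataset.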
Upvotes: 1