サバ缶

Reputation: 13

Is there any way to concatenate 3 or more tf.data.Dataset objects?

I want to concatenate 3 or more datasets in TensorFlow. To concatenate 2 datasets, I can write:

import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)

However, concatenate only accepts one other dataset, so this way I can't concatenate 3 or more datasets in a single call. I'd like to do something like:

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1, dataset2, dataset3)  # hypothetical function, not part of TensorFlow

Is there any way to do this?

Upvotes: 1

Views: 232

Answers (2)

asif abdullah

Reputation: 270

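import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)


def func(*datasets):
    # Merges dict-structured elements key by key, concatenating the
    # tensors stored under the same key along the last axis.
    # Caveat: Dataset.range yields scalar int64 tensors, not dicts, so
    # this function fails on dataset1..dataset3 as defined above.
    out = {}
    for dataset in datasets:
        for key in dataset:
            if key in out:
                _value = out[key]
                out[key] = tf.concat([_value, dataset[key]], axis=-1)
            else:
                out[key] = dataset[key]
    return out


# zip pairs elements position by position and stops at the shortest
# dataset, so this combines elements side by side; it does not append
# one dataset after another.
tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)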
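To see what the zip step in this answer actually produces, here is a minimal sketch (assuming TensorFlow 2.x running eagerly, so as_numpy_iterator() is available) that iterates over the zipped dataset without the map. Each element is a tuple holding one value from each dataset, and iteration stops once the shortest dataset (dataset1, 3 elements) runs out:

```python
import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)   # 1, 2, 3
dataset2 = tf.data.Dataset.range(4, 8)   # 4, 5, 6, 7
dataset3 = tf.data.Dataset.range(8, 12)  # 8, 9, 10, 11

# zip combines the datasets element-wise, like Python's built-in zip.
zipped = tf.data.Dataset.zip((dataset1, dataset2, dataset3))

# Convert the numpy int64 scalars to plain ints for readable output.
pairs = [tuple(int(x) for x in t) for t in zipped.as_numpy_iterator()]
print(pairs)  # [(1, 4, 8), (2, 5, 9), (3, 6, 10)]
```

This is why zip followed by map is not a substitute for concatenate when the goal is to append datasets end to end.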

Upvotes: 0

xdurch0

Reputation: 10474

In this specific example you could just do

concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)

Note that you have to assign the result of concatenate to a new variable! It doesn't operate in-place.
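A quick check of the chained call, as a minimal sketch assuming TensorFlow 2.x in eager mode. It also shows that dataset1 itself is left unchanged by concatenate:

```python
import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)   # 1, 2, 3
dataset2 = tf.data.Dataset.range(4, 8)   # 4, 5, 6, 7
dataset3 = tf.data.Dataset.range(8, 12)  # 8, 9, 10, 11

# concatenate returns a new dataset; chaining appends each one in order.
concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)

values = [int(x) for x in concat_dataset.as_numpy_iterator()]
original = [int(x) for x in dataset1.as_numpy_iterator()]
print(values)    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print(original)  # [1, 2, 3]
```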

Of course, chaining calls by hand doesn't scale well if you have many datasets, but a loop over a list should work:

datasets = [dataset1, dataset2, dataset3]  # can be more than 3 of course

concat_dataset = datasets[0]
for dset in datasets[1:]:
    concat_dataset = concat_dataset.concatenate(dset)
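The same loop can be folded into a variadic helper with functools.reduce, which gives roughly the concatenate(dataset1, dataset2, dataset3) call shape asked for in the question. concatenate_all is a hypothetical name, not a TensorFlow API; this sketch assumes TensorFlow 2.x and that all datasets have compatible element types:

```python
import functools

import tensorflow as tf


def concatenate_all(*datasets):
    """Hypothetical helper: append each dataset after the previous one."""
    if not datasets:
        raise ValueError("concatenate_all needs at least one dataset")
    # reduce applies acc.concatenate(ds) left to right, exactly like the loop.
    return functools.reduce(lambda acc, ds: acc.concatenate(ds), datasets)


dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)

concat_dataset = concatenate_all(dataset1, dataset2, dataset3)
values = [int(x) for x in concat_dataset.as_numpy_iterator()]
print(values)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

With a single argument, reduce simply returns that dataset unchanged, so the helper also works for one dataset.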

Upvotes: 1
