I'm coding a Pix2Pix network, with my own load_input/real_image function, and I'm currently creating the dataset with tf.data.Dataset. The problem is that my dataset has the wrong shape:
I've tried applying a few tf.data.experimental functions, but none of them produced the structure I want.
# Each call to load_image_train returns an input/real image pair
raw_data = [load_image_train(category)
            for category in SELECTED_CATEGORIES
            for _ in range(min(MAX_SAMPLES_PER_CATEGORY, category[1]))]
train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(1)
I have: <BatchDataset shapes: (None, 2, 256, 256, 3), types: tf.float32>
I want: <DatasetV1Adapter shapes: ((None, 256, 256, 3), (None, 256, 256, 3)), types: (tf.float32, tf.float32)>
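The extra dimension of size 2 appears because each element of raw_data is an (input, real) pair, and from_tensor_slices converts the whole Python list into one dense tensor before slicing it. Here is a minimal reproduction, assuming a hypothetical fake_load_image_train that stands in for the real loader (whose internals aren't shown) and just returns two random 256x256x3 images:

```python
import tensorflow as tf

def fake_load_image_train(category):
    # Hypothetical stand-in for load_image_train: returns an (input, real) pair.
    input_image = tf.random.uniform((256, 256, 3))
    real_image = tf.random.uniform((256, 256, 3))
    return input_image, real_image

raw_data = [fake_load_image_train(None) for _ in range(4)]

# The list of (input, real) tuples is packed into ONE tensor of shape
# (4, 2, 256, 256, 3), so each dataset element is a single (2, 256, 256, 3) tensor.
dataset = tf.data.Dataset.from_tensor_slices(raw_data).batch(1)
print(dataset.element_spec)
```

The element_spec is a single TensorSpec rather than a tuple of two, which is the same shape problem as above.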
You can split each stacked pair into two separate dataset components in two ways.
Option 1 (Preferred)
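# Split the stacked (N, 2, 256, 256, 3) data into two (N, 256, 256, 3) tensors
raw_data1, raw_data2 = tf.unstack(raw_data, axis=1)
# Passing a tuple makes each dataset element an (input, real) pair
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data1, raw_data2))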
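Putting Option 1 together with the cache/shuffle/batch steps from the question might look like the sketch below. It again uses a hypothetical fake_load_image_train in place of the real loader, and BUFFER_SIZE = 4 is an arbitrary placeholder value.

```python
import tensorflow as tf

BUFFER_SIZE = 4  # placeholder value; use your real BUFFER_SIZE

def fake_load_image_train(category):
    # Hypothetical stand-in for load_image_train: returns an (input, real) pair.
    return tf.random.uniform((256, 256, 3)), tf.random.uniform((256, 256, 3))

raw_data = [fake_load_image_train(None) for _ in range(4)]

# tf.unstack packs the list into (4, 2, 256, 256, 3), then splits along axis 1.
raw_data1, raw_data2 = tf.unstack(raw_data, axis=1)

# A tuple of tensors gives a dataset whose elements are (input, real) tuples.
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data1, raw_data2))
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE).batch(1)

# Each batch unpacks directly into the two image tensors.
for input_image, real_image in train_dataset.take(1):
    print(input_image.shape, real_image.shape)
```

The element_spec is now a tuple of two (None, 256, 256, 3) specs, matching the structure asked for in the question.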
Option 2
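def map_fn(data):
    # data has shape (2, 256, 256, 3): split it into (input, real)
    input_image, real_image = tf.unstack(data, num=2, axis=0)
    return input_image, real_image

train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.map(map_fn)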
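A self-contained sketch of Option 2, again assuming a hypothetical fake_load_image_train in place of the real loader. Here map_fn returns an explicit tuple so the two-component element structure is stated directly rather than relying on tf.data converting a returned list.

```python
import tensorflow as tf

def fake_load_image_train(category):
    # Hypothetical stand-in for load_image_train: returns an (input, real) pair.
    return tf.random.uniform((256, 256, 3)), tf.random.uniform((256, 256, 3))

def map_fn(data):
    # Each element has shape (2, 256, 256, 3); split it into the two images.
    input_image, real_image = tf.unstack(data, num=2, axis=0)
    return input_image, real_image

raw_data = [fake_load_image_train(None) for _ in range(4)]

train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.map(map_fn).batch(1)
print(train_dataset.element_spec)
```

The difference from Option 1 is where the split happens: Option 1 splits the whole array once before the dataset is built, while Option 2 keeps the data stacked and splits each element inside the pipeline.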