I want to use the TensorFlow Dataset API to create one batch per folder (each folder containing images). I have the following simple code snippet:
import tensorflow as tf
import os
import pdb

def parse_file(filename):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_png(image_string)
    image_resized = tf.image.resize_images(image_decoded, [48, 48])
    return image_resized  # , label

def parse_dir(frame_dir):
    filenames = tf.gfile.ListDirectory(frame_dir)
    batch = tf.constant(5)
    batch = tf.map_fn(parse_file, filenames)
    return batch

directory = "../Detections/NAC20171125"
# filenames = tf.constant([os.path.join(directory, f) for f in os.listdir(directory)])
frames = [os.path.join(directory, str(f)) for f in range(10)]

dataset = tf.data.Dataset.from_tensor_slices((frames))
dataset = dataset.map(parse_dir)
dataset = dataset.batch(256)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            batch = sess.run(next_element)
            print(batch.shape)
        except tf.errors.OutOfRangeError:
            break
However, tf.gfile.ListDirectory (called in parse_dir) expects a plain string instead of a Tensor, so I now get the error:
TypeError: Expected binary or unicode string, got <tf.Tensor 'arg0:0' shape=() dtype=string>
Is there a simple way to solve this?
Upvotes: 2
Views: 2086
The problem here is that tf.gfile.ListDirectory() is a Python function that expects a Python string, whereas the frame_dir argument to parse_dir() is a tf.Tensor. You therefore need an equivalent TensorFlow operation to list the files in a directory, and tf.data.Dataset.list_files() (which is based on tf.matching_files()) is probably the closest equivalent:
directory = "../Detections/NAC20171125"
frames = [os.path.join(directory, str(f)) for f in range(10)]
# Start with a dataset of directory names.
dataset = tf.data.Dataset.from_tensor_slices(frames)
# Maps each subdirectory to the list of files in that subdirectory and flattens
# the result.
dataset = dataset.flat_map(lambda dir: tf.data.Dataset.list_files(dir + "/*"))
# Maps each filename to the parsed and resized image data.
dataset = dataset.map(parse_file)
dataset = dataset.batch(256)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
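For intuition, flat_map behaves like nested iteration: each input element is mapped to a whole dataset, and the resulting datasets are concatenated in order, which is exactly the "flattening" the comment above describes. A minimal pure-Python sketch of that semantics (no TensorFlow; the directory names and listing here are made up for illustration):

```python
def flat_map(elements, fn):
    # Mimics tf.data.Dataset.flat_map: fn maps each element to an
    # iterable, and the per-element iterables are concatenated in order.
    for element in elements:
        for item in fn(element):
            yield item

# Hypothetical directory listing, standing in for list_files(dir + "/*").
listing = {
    "dir_a": ["dir_a/0.png", "dir_a/1.png"],
    "dir_b": ["dir_b/0.png"],
}

files = list(flat_map(["dir_a", "dir_b"], lambda d: listing[d]))
# files == ["dir_a/0.png", "dir_a/1.png", "dir_b/0.png"]
```

This is why a plain map() would not work here: map() must produce exactly one output element per input element, whereas each directory yields a variable number of filenames.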
Upvotes: 4