I'm currently trying to use tf.data to load the VOC2012 dataset for semantic segmentation. The VOC2012 labels are colour-mapped (palette) PNGs, and PIL handles this automatically: opening a label with PIL gives me the class indices. That is not the case when I read the file with tf.read_file.
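A small sketch of what PIL is doing here (PIL and NumPy only, no TensorFlow). It builds a hypothetical 2x2 palette PNG in memory, using the colour map from the VOC development kit (index 0 = background, 15 = person, 255 = void/boundary), and reads it back. PIL opens it in "P" mode, so `np.asarray` returns the class indices directly, while converting to RGB returns the colours. The helper `voc_colormap` and the pixel values are illustrative assumptions, not part of the original question.

```python
import io

import numpy as np
from PIL import Image


def voc_colormap(n=256):
    """VOC colour map: spreads the bits of each class index over R, G and B."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap


# Hypothetical 2x2 label map: 0 = background, 15 = person, 255 = void/boundary
label = Image.new("P", (2, 2))
label.putdata([0, 15, 255, 0])
label.putpalette(voc_colormap().flatten().tolist())

# Round-trip through an in-memory PNG, as if reading a VOC label file
buf = io.BytesIO()
label.save(buf, format="PNG")
buf.seek(0)
reloaded = Image.open(buf)

print(reloaded.mode)                         # P
print(np.asarray(reloaded))                  # the class indices
print(np.asarray(reloaded.convert("RGB")))   # the palette colours
```

Reading the same bytes as RGB is what turns the label into colours; keeping the image in "P" mode is what preserves the indices.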
```python
import tensorflow as tf
from PIL import Image

train_data = tf.data.Dataset.from_tensor_slices((img_filename_list, lbl_filename_list))

def preprocessing(img_filename, lbl_filename):
    # Load the image
    train_img = tf.read_file(img_path + img_filename)
    train_img = tf.image.decode_jpeg(train_img, channels=3)
    # Cast first so the division happens in float32, then normalize to [0, 1]
    train_img = tf.cast(train_img, tf.float32) / 255.0
    return train_img, lbl_filename

train_data = train_data.map(preprocessing).shuffle(100).repeat().batch(2)
iterator = train_data.make_initializable_iterator()
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(train_data)

with tf.Session() as sess:
    sess.run(training_init_op)
    while True:
        train_images, lbl_filename = sess.run(next_element)
```
This is what I'm doing right now. Ideally, though, I want the preprocessing function to return a label image loaded with PIL so that I can create one-hot vectors from it.
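Once the label is a 2-D array of class indices, the one-hot step itself is plain NumPy. This sketch assumes the 21 VOC classes (20 object classes plus background) and treats the boundary value 255 as "ignore" by giving those pixels an all-zero vector; `to_one_hot` and the ignore handling are my own illustrative choices, not a library API.

```python
import numpy as np


def to_one_hot(label, num_classes=21, ignore_index=255):
    """Turn an (H, W) array of class indices into an (H, W, num_classes) one-hot array.

    Pixels equal to ignore_index (the VOC boundary value) get an all-zero vector.
    """
    label = np.asarray(label)
    one_hot = np.zeros(label.shape + (num_classes,), dtype=np.float32)
    rows, cols = np.nonzero(label != ignore_index)  # coordinates of valid pixels
    one_hot[rows, cols, label[rows, cols]] = 1.0
    return one_hot


label = np.array([[0, 15], [255, 0]], dtype=np.uint8)
one_hot = to_one_hot(label)
print(one_hot.shape)  # (2, 2, 21)
```

Masking the void pixels instead of mapping them to a class keeps them out of any per-pixel loss computed from these vectors.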
```python
def preprocessing(img_filename, lbl_filename):
    # ... load train images ...
    train_lbl = Image.open(lbl_path + lbl_filename)  # lbl_filename is a tf.Tensor here
    # ... do some other stuff ...
    return train_img, train_lbl
```
This fails with the following error:

```
AttributeError: 'Tensor' object has no attribute 'read'
```
Is there any solution to this?
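The error comes from how `Image.open` treats its argument: anything that is not a filename or path is assumed to be a file-like object, and PIL falls back to calling `.read()` on it. Inside `map`, `lbl_path + lbl_filename` is a symbolic `tf.Tensor`, not a string, so that call fails. A minimal reproduction with a stand-in class (`FakeTensor` is hypothetical and only mimics "neither a path nor a file"):

```python
from PIL import Image


class FakeTensor:
    """Hypothetical stand-in for a symbolic tf.Tensor: neither a path nor a file."""


message = None
try:
    Image.open(FakeTensor())
except AttributeError as err:
    message = str(err)

print(message)  # 'FakeTensor' object has no attribute 'read'
```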
As suggested by @GPhilo, wrapping the PIL call in tf.py_func solves this problem. Here is my solution:
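```python
import numpy as np
import tensorflow as tf
from PIL import Image

def read_labels(lbl_filename):
    # Runs as ordinary Python: the filename arrives as bytes, not a Tensor
    train_lbl = Image.open(lbl_path + lbl_filename.decode("utf-8"))
    return np.asarray(train_lbl)  # palette PNG -> (H, W) uint8 class indices

def preprocessing(img_filename, lbl_filename):
    # ... load train_img as before ...
    train_lbl = tf.py_func(read_labels, [lbl_filename], tf.uint8)
    return train_img, train_lbl
```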
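Because `tf.py_func` passes the evaluated string tensor to `read_labels` as a Python `bytes` object, the function can be checked without TensorFlow by calling it the same way. This sketch assumes `lbl_path` is a directory prefix ending in a path separator (here a throwaway temporary directory) and writes a hypothetical 2x2 palette label named `example.png`; both are illustrative, not from the answer.

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Assumption: lbl_path is a directory prefix ending in a separator
lbl_path = tempfile.mkdtemp() + os.sep


def read_labels(lbl_filename):
    # lbl_filename arrives as bytes when called through tf.py_func
    train_lbl = Image.open(lbl_path + lbl_filename.decode("utf-8"))
    return np.asarray(train_lbl)


# Write a hypothetical 2x2 palette label to disk
label = Image.new("P", (2, 2))
label.putdata([0, 15, 255, 0])
label.putpalette([0] * 768)  # the colours do not matter for the stored indices
label.save(lbl_path + "example.png")

result = read_labels(b"example.png")  # the same call tf.py_func would make
print(result.dtype, result.shape)      # uint8 (2, 2)
```

The returned dtype is `uint8`, which matches the `tf.uint8` output type declared in the `tf.py_func` call.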