Andy Wei

Opening an image with PIL in a tf.data pipeline

I'm currently trying to use tf.data to load the VOC2012 dataset for semantic segmentation. The VOC2012 labels are stored with a colour map (palette PNGs), which PIL handles automatically: opening a label with PIL gives me the class index of each pixel. This is not the case when I read the file with tf.read_file.
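To make the colour-map behaviour concrete, here is a small sketch that builds a tiny palette ("P" mode) PNG in memory as a stand-in for a VOC2012 label file. The 2x3 array and the palette are made-up assumptions. Reading the file back with PIL and np.asarray yields the palette indices (the class IDs), while converting to RGB yields the looked-up colours, which is what a decoder that expands the palette gives you instead.

```python
import io

import numpy as np
from PIL import Image

# Hypothetical 2x3 label map: 0 = background, 1 = first object class
indices = np.array([[0, 1, 1],
                    [0, 0, 1]], dtype=np.uint8)

# Attach a colour map, as VOC2012 label PNGs do (class 1 drawn as dark red here)
label_img = Image.fromarray(indices)            # mode "L"
palette = [0, 0, 0, 128, 0, 0] + [0, 0, 0] * 254
label_img.putpalette(palette)                   # mode becomes "P"

# Round-trip through an in-memory PNG file
buf = io.BytesIO()
label_img.save(buf, format="PNG")
buf.seek(0)
reloaded = Image.open(buf)

as_indices = np.asarray(reloaded)               # palette indices, shape (2, 3)
as_rgb = np.asarray(reloaded.convert("RGB"))    # looked-up colours, shape (2, 3, 3)

print(reloaded.mode)   # P
print(as_indices)
print(as_rgb[0, 1])    # [128   0   0]
```

The index array is what segmentation training needs; the RGB array would have to be mapped back to class IDs colour by colour.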

import tensorflow as tf
from PIL import Image

# img_path, img_filename_list and lbl_filename_list are defined elsewhere
train_data = tf.data.Dataset.from_tensor_slices((img_filename_list, lbl_filename_list))

def preprocessing(img_filename, lbl_filename):
    # Load image
    train_img = tf.read_file(img_path + img_filename)
    train_img = tf.image.decode_jpeg(train_img, channels=3)
    # Cast uint8 -> float32 and scale to [0, 1] (dividing a uint8 tensor by 255.0 directly is a dtype mismatch)
    train_img = tf.image.convert_image_dtype(train_img, tf.float32)

    return train_img, lbl_filename

train_data = train_data.map(preprocessing).shuffle(100).repeat().batch(2)
iterator = train_data.make_initializable_iterator()
next_element = iterator.get_next()
training_init_op = iterator.initializer

with tf.Session() as sess:
    sess.run(training_init_op)
    while True:
        train_images, lbl_filename = sess.run(next_element)

This is what I'm doing right now. Ideally, though, I want the preprocessing function to return the label image loaded with PIL, so that I can turn it into one-hot vectors:

def preprocessing(img_filename, lbl_filename):
    # ... load train image as above ...
    train_lbl = Image.open(lbl_path + lbl_filename)
    # ... do some other stuff ...
    return train_img, train_lbl

This fails with the following error:

AttributeError: 'Tensor' object has no attribute 'read'

Is there any solution to this?
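For the one-hot step the question is aiming at, a minimal NumPy sketch follows. It assumes the label array holds VOC2012 class indices (21 classes: 20 object classes plus background) and that boundary/"void" pixels carry the value 255, as in the VOC2012 segmentation labels. Turning void pixels into all-zero rows is a design choice made here, not something from the question; the function name one_hot_labels is hypothetical.

```python
import numpy as np

NUM_CLASSES = 21   # VOC2012: 20 object classes + background
VOID_LABEL = 255   # VOC2012 marks boundary / void pixels with 255

def one_hot_labels(label, num_classes=NUM_CLASSES, void_label=VOID_LABEL):
    """Turn an (H, W) array of class indices into an (H, W, num_classes) one-hot array.

    Void pixels become all-zero vectors instead of raising an index error.
    """
    label = np.asarray(label)
    valid = label != void_label
    # Replace void pixels with 0 so they are safe to use as indices...
    safe = np.where(valid, label, 0)
    one_hot = np.eye(num_classes, dtype=np.float32)[safe]
    # ...then zero those rows out again with the mask
    return one_hot * valid[..., np.newaxis]

lbl = np.array([[0, 1], [255, 20]], dtype=np.uint8)
oh = one_hot_labels(lbl)
print(oh.shape)  # (2, 2, 21)
```

With a standard softmax cross-entropy, an all-zero target row contributes zero loss, which is one common way to ignore void pixels; masking the loss explicitly is the usual alternative.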

Answers (1)

Andy Wei

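Inside Dataset.map the arguments are symbolic tensors, not Python strings, so PIL cannot open them directly (Image.open treats the Tensor as a file object and looks for a read method, hence the error). As suggested by @GPhilo, wrapping the PIL call in tf.py_func solves this: the wrapped function runs as ordinary Python and receives the filename as bytes. Here's my solution code: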

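import numpy as np
import tensorflow as tf
from PIL import Image

def read_labels(lbl_filename):
    # Runs as plain Python; py_func passes the filename in as bytes
    train_lbl = Image.open(lbl_path + lbl_filename.decode("utf-8"))
    return np.asarray(train_lbl)  # palette PNG -> uint8 class indices

def preprocessing(img_filename, lbl_filename):
    train_img = tf.image.decode_jpeg(tf.read_file(img_path + img_filename), channels=3)
    train_lbl = tf.py_func(read_labels, [lbl_filename], tf.uint8)
    train_lbl.set_shape([None, None])  # py_func loses the static shape
    return train_img, train_lbl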
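Because tf.py_func runs read_labels as ordinary Python, the function can be checked outside TensorFlow by calling it the way py_func does, with the filename as bytes. The sketch below is only a test harness under assumptions: a temporary directory stands in for lbl_path, and a small palette PNG named example.png (hypothetical) stands in for a VOC2012 label file.

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Hypothetical label directory, standing in for the answer's lbl_path
lbl_path = tempfile.mkdtemp() + os.sep

def read_labels(lbl_filename):
    # Same function as in the answer: py_func hands the filename over as bytes
    train_lbl = Image.open(lbl_path + lbl_filename.decode("utf-8"))
    return np.asarray(train_lbl)

# Write a small palette PNG as a stand-in label file (15 and 255 are example values)
indices = np.array([[0, 15], [255, 15]], dtype=np.uint8)
label_img = Image.fromarray(indices)
label_img.putpalette([c for i in range(256) for c in (i, 0, 255 - i)])  # arbitrary palette
label_img.save(lbl_path + "example.png")

# Call it exactly as py_func would, with a bytes filename
label = read_labels(b"example.png")
print(label.dtype, label.shape)  # uint8 (2, 2)
```

The returned dtype is uint8, which is why tf.uint8 is the matching output type in the tf.py_func call.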
