BBQuercus

Reputation: 879

Loading segmentation data with tf.data as Dataset?

I want to load and augment a custom dataset for segmentation. I prepared an npz file containing four subsets:

with np.load(PATH) as data:
    train_x = data['x_train']
    valid_x = data['x_valid']
    train_y = data['y_train']
    valid_y = data['y_valid']

train and valid have their usual meanings; x is the input image and y is the segmentation mask. During training, the model takes x as input and the loss is computed between the model output and y.

My question is how to build a tf.data Dataset that I can iterate over during training. I have tried the following:

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))

>>> train_dataset
<TensorSliceDataset shapes: ((520, 696), (520, 696)), types: (tf.uint16, tf.uint8)>

def load(data_group):
    image, mask = data_group
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return image, mask

def normalize(image):
    return (image / 65535/2) - 1

def load_image_train(data_group):
    image, mask = load(data_group)
    image = normalize(image)
    # Perform augmentation (not shown)
    return image, mask

train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

This, however, fails when mapping the load_image_train function, with the error: tf__load_image_train() takes 1 positional argument but 2 were given. In general, this approach feels slightly clunky, and I would love to know about alternatives or ways to improve this data import.

Thanks in advance

Upvotes: 0

Views: 1676

Answers (1)

Shubham Shaswat

Reputation: 1310

Define the mapped function with two parameters instead of one:

def load_image_train(image, mask):
    # Cast both tensors to float32 before normalizing the image
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    image = normalize(image)
    return image, mask

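Because the dataset was built from the tuple (train_x, train_y), each element is a pair of tensors, and Dataset.map unpacks that pair into separate positional arguments when it calls your function. A function that expects a single data_group argument therefore receives two arguments and raises the error you saw.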
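For reference, here is a minimal end-to-end sketch of the pipeline with the two-parameter function. The arrays are random stand-ins for the npz contents, with shapes and dtypes taken from the TensorSliceDataset output in the question, and the BUFFER_SIZE and BATCH_SIZE values are hypothetical. It also assumes the goal of normalize is to scale uint16 intensities to [-1, 1], which needs parentheses around 65535 / 2; without them the expression divides by 65535 and then by 2.

```python
import numpy as np
import tensorflow as tf

# Random stand-ins for data['x_train'] and data['y_train'] (assumed shapes/dtypes).
rng = np.random.default_rng(0)
train_x = rng.integers(0, 65536, size=(8, 520, 696), dtype=np.uint16)
train_y = rng.integers(0, 2, size=(8, 520, 696), dtype=np.uint8)

BUFFER_SIZE = 8  # hypothetical value for this sketch
BATCH_SIZE = 4   # hypothetical value for this sketch


def normalize(image):
    # Assumption: scale uint16 intensities from [0, 65535] to [-1, 1].
    return image / (65535.0 / 2.0) - 1.0


def load_image_train(image, mask):
    # map() passes the two components of each element as separate arguments.
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    image = normalize(image)
    # Augmentation would go here (not shown).
    return image, mask


train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Pull one batch to check shapes and dtypes.
images, masks = next(iter(train_dataset))
print(images.shape, images.dtype, masks.shape, masks.dtype)
```

If you would rather keep the original single-argument load_image_train(data_group), wrapping it as train_dataset.map(lambda image, mask: load_image_train((image, mask))) should also work, since the lambda receives the unpacked pair and repacks it into a tuple.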

Also check out the TensorFlow Guide.

Upvotes: 2
