PyStraw45

Reputation: 214

Error when getting features from a tensorflow-datasets dataset

I'm getting an error when attempting to load the Caltech101 dataset from tensorflow-datasets. I'm using the standard code found in the tensorflow-datasets GitHub repository.

The error is this:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [204,300,3] and element 1 had shape [153,300,3]. [Op:IteratorGetNextSync]

The traceback points to the line for features in ds_train.take(1), i.e. the first time the batched dataset is iterated.

Code:

import tensorflow_datasets as tfds

ds_train, ds_test = tfds.load(name="caltech101", split=["train", "test"])

ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
for features in ds_train.take(1):
    image, label = features["image"], features["label"]

Upvotes: 0

Views: 557

Answers (1)

GPhilo

Reputation: 19123

The issue comes from the fact that the dataset contains variable-sized images (see the caltech101 dataset description in the TFDS catalog). TensorFlow can only batch together tensors that have the same shape, so you first need to either resize the images to a common shape (e.g., the input shape of your network) or pad them accordingly.
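To see the failure in isolation, here is a minimal sketch that reproduces it without downloading caltech101. It assumes TensorFlow 2.4+ (for the output_signature argument of tf.data.Dataset.from_generator) and uses two hypothetical all-zero arrays with the shapes from the error message in place of real images.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for caltech101: two RGB "images" with different heights,
# mirroring the [204, 300, 3] vs [153, 300, 3] shapes from the error message.
images = [np.zeros((204, 300, 3), dtype=np.uint8),
          np.zeros((153, 300, 3), dtype=np.uint8)]

# Height and width are left unknown (None), just like variable-sized dataset images.
ds = tf.data.Dataset.from_generator(
    lambda: iter(images),
    output_signature=tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8))

def first_batch_fails(dataset):
    """Return True if pulling the first batch raises InvalidArgumentError."""
    try:
        for _ in dataset.batch(2).take(1):
            pass
    except tf.errors.InvalidArgumentError:
        return True
    return False

print(first_batch_fails(ds))  # True
```

The error only surfaces when the iterator actually pulls a batch, which is why the traceback points at the for loop rather than at the batch() call.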

If you want to resize, use tf.image.resize (called tf.image.resize_images in TensorFlow 1.x):

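import tensorflow as tf
YOUR_TARGET_SIZE = [224, 224]  # placeholder: set to your network's input (height, width)

def preprocess(features):
  # tfds.load yields one dict per example by default, so map() passes a single argument
  features['image'] = tf.image.resize(features['image'], YOUR_TARGET_SIZE)
  # Other possible transformations (e.g., normalizing to [0, 1]); tf.image.resize already returns float32
  return features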

If, instead, you want to pad, use tf.image.pad_to_bounding_box (just swap it into the preprocess function above and adapt the parameters as needed; the target size must be at least as large as every image). In practice, most of the networks I'm aware of use resizing.
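A minimal sketch of the padding variant, assuming TensorFlow 2.4+ and the single feature dict that tfds.load yields by default. The 300x300 target and the synthetic arrays are illustrative assumptions, not values taken from the dataset description; for the real caltech101 data, pick a target at least as large as the biggest image.

```python
import numpy as np
import tensorflow as tf

# Hypothetical target size; it must be at least as large as every image,
# otherwise pad_to_bounding_box raises an error.
TARGET_HEIGHT, TARGET_WIDTH = 300, 300

def preprocess_pad(features):
    # Zero-pad at the bottom/right (offsets 0, 0) to TARGET_HEIGHT x TARGET_WIDTH
    features['image'] = tf.image.pad_to_bounding_box(
        features['image'], 0, 0, TARGET_HEIGHT, TARGET_WIDTH)
    return features

# Synthetic stand-in for the tfds feature dicts (shapes from the error message)
images = [np.zeros((204, 300, 3), np.uint8), np.zeros((153, 300, 3), np.uint8)]
ds = tf.data.Dataset.from_generator(
    lambda: ({'image': img, 'label': 0} for img in images),
    output_signature={'image': tf.TensorSpec((None, None, 3), tf.uint8),
                      'label': tf.TensorSpec((), tf.int64)})

# Every padded image now has the same shape, so batching succeeds
batch = next(iter(ds.map(preprocess_pad).batch(2)))
print(batch['image'].shape)  # (2, 300, 300, 3)
```

With offsets of 0 the original pixels sit in the top-left corner; tf.image.resize_with_crop_or_pad is a related function that centers the image instead and also crops anything larger than the target.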

Finally, map the function over your dataset before shuffling and batching:

ds_train = (ds_train
            .map(preprocess)
            .shuffle(1000)
            .batch(128)
            .prefetch(10))
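Putting the pieces together, here is a sketch of the full pipeline under the same assumptions: TensorFlow 2.4+, a hypothetical 224x224 target, synthetic images standing in for caltech101, and a batch size of 4 instead of 128 so the toy dataset fills a batch. Because every image leaves preprocess with the same shape, batch() no longer fails.

```python
import numpy as np
import tensorflow as tf

YOUR_TARGET_SIZE = [224, 224]  # hypothetical; use your network's input size

def preprocess(features):
    # Resize to a common (height, width); tf.image.resize returns float32
    features['image'] = tf.image.resize(features['image'], YOUR_TARGET_SIZE)
    # Scale pixel values from [0, 255] to [0, 1]
    features['image'] = features['image'] / 255.0
    return features

# Synthetic stand-in for ds_train from tfds.load (dicts with 'image' and 'label')
images = [np.zeros((204, 300, 3), np.uint8), np.zeros((153, 300, 3), np.uint8)] * 4
ds_train = tf.data.Dataset.from_generator(
    lambda: ({'image': img, 'label': i % 2} for i, img in enumerate(images)),
    output_signature={'image': tf.TensorSpec((None, None, 3), tf.uint8),
                      'label': tf.TensorSpec((), tf.int64)})

# Same order as in the answer: map first, then shuffle, batch and prefetch
ds_train = (ds_train
            .map(preprocess)
            .shuffle(1000)
            .batch(4)
            .prefetch(10))

for features in ds_train.take(1):
    image, label = features['image'], features['label']
    print(image.shape, label.shape)  # (4, 224, 224, 3) (4,)
```

With the real data, replace the synthetic ds_train with the one returned by tfds.load and use your own batch size.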

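Note: the particular shapes reported in the error message depend on the shuffle call, which changes which images land in the first batch; the root cause is still the variable image sizes.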

Upvotes: 1
