Lelouche Lamperouge

Reputation: 191

How to use custom dataset generators with TPU?

My dataset is large (around 13 GB). It is stored in an HDF5 file, and I'm using a custom generator to load samples from it. My model runs fine on a Kaggle GPU, but when I switch to TPU I get an error. Below are my generator function and the error I receive when running model.fit.

def generate_data():
    while True:  # Loop forever so the generator never terminates
        for _ in range(0, num_samples, BATCH_SIZE):
            # Pick one random sample from the HDF5 dataset
            offset = np.random.randint(num_samples)
            X_train = hdf5_file['data'][offset]
            X_train = X_train.transpose(1, 2, 0)  # move axis 0 to the end (e.g. CHW -> HWC)
            X_train = X_train.astype(np.float32)
            X_train = (X_train - 127.5) / 127.5  # maps [0, 255] to [-1, 1]
            X_train = cv2.resize(X_train, dsize=(IMG_SHAPE[1], IMG_SHAPE[1]), interpolation=cv2.INTER_CUBIC)
            X_train = np.array(X_train)

            # Yield a single sample; batching is done by tf.data below
            yield X_train

Here's the code that builds the tf.data dataset from the generator function.

dataset = tf.data.Dataset.from_generator(generate_data, tf.float32)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

Here's the error I receive when I run model.fit with the above dataset.

TypeError: in user code:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
    outputs = self.distribute_strategy.run(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/tpu_strategy.py:174 run  **
    return self.extended.tpu_run(fn, args, kwargs, options)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/tpu_strategy.py:867 tpu_run
    return func(args, kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/tpu_strategy.py:916 tpu_function
    maximum_shape = tensor_shape.TensorShape([None] * rank)

TypeError: can't multiply sequence by non-int of type 'NoneType'

As I mentioned, the code works fine on GPU without any modification. What should I do to make it work on TPU?

Upvotes: 3

Views: 1099

Answers (1)

Stsh4lson

Reputation: 41

Have you tried specifying the output shape in the from_generator() call? It accepts an output_shapes argument alongside the output types.
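
A minimal sketch of what that might look like, assuming the generator yields one float32 image of a fixed shape. `IMG_SHAPE`, `BATCH_SIZE` and the random data are placeholders standing in for the question's HDF5 read. `output_signature` needs TF 2.4 or newer; on older releases, such as the one in the traceback, the `output_types` and `output_shapes` arguments serve the same purpose. Without a declared shape the dataset elements have unknown rank, which appears to be why `[None] * rank` fails in the traceback.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the globals used in the question.
IMG_SHAPE = (64, 64, 3)
BATCH_SIZE = 8


def generate_data():
    """Yield one float32 image per step (random data replaces the HDF5 read)."""
    rng = np.random.default_rng(0)
    while True:
        yield rng.uniform(-1.0, 1.0, size=IMG_SHAPE).astype(np.float32)


# Declaring shape and dtype up front gives every element a known, static shape.
# On TF < 2.4 the equivalent would be:
#   from_generator(generate_data, output_types=tf.float32,
#                  output_shapes=tf.TensorShape(IMG_SHAPE))
dataset = tf.data.Dataset.from_generator(
    generate_data,
    output_signature=tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.float32),
)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

# The batched element spec should now be fully defined.
print(dataset.element_spec)
```

Keeping drop_remainder=True, as in the question, also makes the batch dimension static, so the whole element shape is known before model.fit runs.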

Upvotes: 1
