kynnemall

Reputation: 888

How to fix Tensorflow Datasets memory leak when shuffling?

I want to train a model on the Stanford Dog Breed dataset, which I download using TensorFlow Datasets, but when I train the model in Google Colab with a GPU, it runs out of memory and Colab restarts the runtime:

tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.

I used the example tutorial from TensorFlow, so I know the order of operations is right. In the code below, I found that shuffling the dataset was the issue, but this only became apparent when I called model.fit(). How can I shuffle the dataset and avoid the memory error?

import tensorflow_datasets as tfds
import tensorflow as tf

batch_size = 16  # example batch size

# Load the train and test data splits
(ds_train, ds_test), ds_info = tfds.load('stanford_dogs',
    split=['train', 'test'], shuffle_files=True, as_supervised=True, with_info=True,
)

def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
# ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples) # this line causes the OOM errors
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(1)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Upvotes: 0

Views: 1597

Answers (1)

I'mahdi

Reputation: 24059

I don't see the network you're using for training, but note that with shuffle_files=True the input files are already read in shuffled order.
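The commented-out ds_train.shuffle(ds_info.splits['train'].num_examples) is what actually runs out of memory: Dataset.shuffle fills a buffer of buffer_size elements in RAM, so passing the full example count keeps every decoded image resident at once. If you also want element-level shuffling on top of the file-level shuffling, a bounded buffer keeps memory use fixed; a minimal sketch (buffer_size=1000 is just an example value to tune to your RAM):

# shuffle() holds `buffer_size` decoded elements in memory at a time,
# so a bounded buffer avoids keeping the whole dataset resident
ds_train = ds_train.shuffle(buffer_size=1000)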

If I understand correctly (IIUC), the error comes from the size of the images in the dataset: the raw images are large and variable-sized. You can solve this by resizing the images before training, like below:

def normalize_img(image, label):
    image = tf.image.resize(image, (64, 64))
    return tf.cast(image, tf.float32) / 255., label
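To see why resizing helps: a raw Stanford Dogs image is variable-sized, often on the order of 500×400 pixels, which as float32 is about 500 × 400 × 3 × 4 bytes ≈ 2.4 MB per image; a 64×64×3 float32 image is about 48 KB, roughly 50× smaller. Every buffer in the pipeline (shuffle, batch, prefetch) shrinks by the same factor.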

Full code:

import tensorflow_datasets as tfds
import tensorflow as tf

# Load the train and test data splits
(ds_train, ds_test), ds_info = tfds.load('stanford_dogs',
    split=['train', 'test'], shuffle_files=True, as_supervised=True, with_info=True,
)

def normalize_img(image, label):
    image = tf.image.resize(image, (64, 64))
    return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
# ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples) # this line causes the OOM errors
ds_train = ds_train.batch(16)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(1)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(64, 64, 3)))

model.add(tf.keras.layers.Conv2D(128, (3,3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(rate=.4))

model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(rate=.4))

model.add(tf.keras.layers.Conv2D(128, (3,3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(rate=.4))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.4))            
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.4))
model.add(tf.keras.layers.Dense(120, activation='softmax'))  # stanford_dogs has 120 breed classes
model.compile(loss='sparse_categorical_crossentropy',  # labels are integer class IDs (as_supervised=True)
              optimizer='Adam', metrics=['accuracy'])
model.summary()

model.fit(ds_train, epochs=2)

Output:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 62, 62, 128)       3584      
                                                                 
 batch_normalization (BatchN  (None, 62, 62, 128)      512       
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 62, 62, 128)       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 60, 60, 64)        73792     
                                                                 
 batch_normalization_1 (Batc  (None, 60, 60, 64)       256       
 hNormalization)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 60, 60, 64)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 58, 58, 128)       73856     
                                                                 
 batch_normalization_2 (Batc  (None, 58, 58, 128)      512       
 hNormalization)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 58, 58, 128)       0         
                                                                 
 flatten (Flatten)           (None, 430592)            0         
                                                                 
 dense (Dense)               (None, 512)               220463616 
                                                                 
 dropout_3 (Dropout)         (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 128)               65664     
                                                                 
 dropout_4 (Dropout)         (None, 128)               0         
                                                                 
 dense_2 (Dense)             (None, 120)               15480     
                                                                 
=================================================================
Total params: 220,697,272
Trainable params: 220,696,632
Non-trainable params: 640
_________________________________________________________________
Epoch 1/2
750/750 [==============================] - 64s 78ms/step - loss: ... - accuracy: ....
Epoch 2/2
750/750 [==============================] - 55s 73ms/step - loss: ... - accuracy: ....
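
A note on the output layer: stanford_dogs has 120 breed classes, and as_supervised=True yields integer labels, so the head is Dense(120, activation='softmax') trained with sparse_categorical_crossentropy. Rather than hard-coding 120, you can read the class count from the dataset metadata:

num_classes = ds_info.features['label'].num_classes  # 120 for stanford_dogs
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))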

Upvotes: 0
