rzaratx

Reputation: 824

Does TensorFlow's MirroredStrategy() split the training model?

Does TensorFlow's MirroredStrategy() split the training model across multiple GPUs? I am trying to run a 3D U-Net, and on a single GPU I am limited to a 224x224x224 volume for my training data. I am trying to use MirroredStrategy() and with tf.device(): to place parts of the model on a second GPU, but I still cannot get past the 224x224x224 limit. If I use a larger volume, I get a ResourceExhaustedError.

Code:

import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv3D, Conv3DTranspose,
                                     MaxPooling3D, Dropout, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def get_model(optimizer, loss_metric, metrics, lr=1e-3):
    # Encoder (contracting path) placed on the first GPU
    with tf.device('/job:localhost/replica:0/task:0/device:GPU:0'):
        inputs = Input((sample_width, sample_height, sample_depth, 1))
        conv1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(inputs)
        conv1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(conv1)
        pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv1)
        drop1 = Dropout(0.5)(pool1)
        conv2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(drop1)
        conv2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(conv2)
        pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv2)
        drop2 = Dropout(0.5)(pool2)
        conv3 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(drop2)
        conv3 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(conv3)
        pool3 = MaxPooling3D(pool_size=(2, 2, 2))(conv3)
        drop3 = Dropout(0.3)(pool3)
        conv4 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(drop3)
        conv4 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(conv4)
        pool4 = MaxPooling3D(pool_size=(2, 2, 2))(conv4)
        drop4 = Dropout(0.3)(pool4)
        conv5 = Conv3D(512, (3, 3, 3), activation='relu', padding='same')(drop4)
        conv5 = Conv3D(512, (3, 3, 3), activation='relu', padding='same')(conv5)
    # Decoder (expanding path) placed on the second GPU
    with tf.device('/job:localhost/replica:0/task:0/device:GPU:1'):
        up6 = concatenate([Conv3DTranspose(256, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv5), conv4], axis=4)
        conv6 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(up6)
        conv6 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(conv6)
        up7 = concatenate([Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv6), conv3], axis=4)
        conv7 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(up7)
        conv7 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(conv7)
        up8 = concatenate([Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv7), conv2], axis=4)
        conv8 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(up8)
        conv8 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(conv8)
        up9 = concatenate([Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv8), conv1], axis=4)
        conv9 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(up9)
        conv9 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(conv9)
        conv10 = Conv3D(1, (1, 1, 1), activation='sigmoid')(conv9)
    # Build and compile the full model on the CPU
    with tf.device('/job:localhost/replica:0/task:0/device:CPU:0'):
        model = Model(inputs=[inputs], outputs=[conv10])
        model.compile(optimizer=optimizer(lr=lr), loss=loss_metric, metrics=metrics)
        return model

mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/job:localhost/replica:0/task:0/device:GPU:0", "/job:localhost/replica:0/task:0/device:GPU:1"],
                    cross_device_ops = tf.distribute.HierarchicalCopyAllReduce())
with mirrored_strategy.scope():
    model = get_model(optimizer=Adam, loss_metric=dice_coef_loss, metrics=[dice_coef], lr=1e-3)

ResourceExhaustedError:

ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-1-7a601312fa7a> in <module>
    405     # e_drive_model_dir = '\\models\\'
    406     model_checkpoint = ModelCheckpoint('unet_seg_cs9300_3d_{epoch:04}.model', monitor=observe_var, save_best_only=False, save_freq = 1000)
--> 407     model.fit(train_x, train_y, batch_size= 2, epochs= 10000, verbose=1, shuffle=True, validation_split=0, callbacks=[model_checkpoint])
    408 
    409     model.save('unet_seg_final_3d_test.model')

~\.conda\envs\gputest\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    647             steps_per_epoch=steps_per_epoch,
    648             validation_steps=validation_steps,
--> 649             validation_freq=validation_freq)
    650 
    651     batch_size = self._validate_or_infer_batch_size(

~\.conda\envs\gputest\lib\site-packages\tensorflow\python\keras\engine\training_distributed.py in fit_distributed(model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq)
    141         validation_steps=validation_steps,
    142         validation_freq=validation_freq,
--> 143         steps_name='steps_per_epoch')
    144 
    145 

~\.conda\envs\gputest\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    272           # `ins` can be callable in tf.distribute.Strategy + eager case.
    273           actual_inputs = ins() if callable(ins) else ins
--> 274           batch_outs = f(actual_inputs)
    275         except errors.OutOfRangeError:
    276           if is_dataset:

~\.conda\envs\gputest\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3290 
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

~\.conda\envs\gputest\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

ResourceExhaustedError: 2 root error(s) found.
  (0) Resource exhausted: OOM when allocating tensor with shape[1,32,240,240,240] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node Adam/gradients/conv3d_17_1/Conv3D_grad/Conv3DBackpropInputV2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: OOM when allocating tensor with shape[1,32,240,240,240] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node Adam/gradients/conv3d_17_1/Conv3D_grad/Conv3DBackpropInputV2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

     [[GroupCrossDeviceControlEdges_0/Adam/Adam/update_1/Const/_1070]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

1 successful operations.
0 derived errors ignored.

Upvotes: 0

Views: 777

Answers (1)

Susmit Agrawal

Reputation: 3764

Though it's late, I hope this answer helps others in the future.

Note that this has been tested on TF 2.0, so it may not work for older versions.

Short answer to the first part of the question:

MirroredStrategy() does not split the model across separate GPUs; it replicates the full model on each GPU and splits the batches between them. So if the model is trained with a batch size of 32 on a 2-GPU machine, each GPU gets 16 examples. The gradients from the replicas are then reduced, and the model is updated once for all 32 examples.
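
For illustration only, here is a minimal data-parallel sketch of that behaviour (the toy model and random data are placeholders, not the 3D U-Net from the question). Note that every GPU still holds a full copy of the model, which is why MirroredStrategy() alone does not raise the 224x224x224 memory ceiling:

import numpy as np
import tensorflow as tf

# Each replica receives global_batch_size / num_replicas examples per step.
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 32  # with 2 GPUs, each replica sees 16 examples

with strategy.scope():
    # Variables created inside the scope are mirrored onto every GPU.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy')

# Dummy data, only to make the sketch runnable.
x = np.random.rand(256, 10).astype('float32')
y = np.random.randint(0, 2, size=(256, 1)).astype('float32')

# Keras splits each global batch of 32 across the replicas; gradients from
# the replicas are reduced before the mirrored variables are updated.
model.fit(x, y, batch_size=global_batch_size, epochs=1)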


How can the model itself be split?

After a lot of trial and error, I found the following:

  1. You can have individual layers and ops on separate devices, but once you wrap them under a single instance of tf.keras.Model, you can call the whole model on a single device only.

  2. The layers in a model can be referenced and used outside the model, as individual ops instead of only as a collective whole.

  3. When saving and restoring the model, you can get away with restoring only the weights, then using those weights as described in point 2 together with new instances of the layers that don't have variables (a sketch of this appears at the end of the answer).

Combining these three points, one way to split the model across multiple GPUs for training and inference is to first create the graph (the tf.keras.Model) on a single device, and then place its individual components on separate GPUs.

A bare minimum example:

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Concatenate

def create_model():
    input = Input((None, None, 3))
    x = Conv2D(64, (3, 3), activation='relu')(input)
    y = Conv2D(64, (3, 3), activation='relu')(input)
    z = Concatenate()([x, y])
    output = Conv2D(3, (3, 3), activation='sigmoid')(z)
    return tf.keras.Model(inputs=[input], outputs=[output])

def model_graph(model, inputs):
    # get all layers that contain trainable parameters
    layers = []
    for layer in model.layers:
        if len(layer.trainable_variables) != 0:
            layers.append(layer)

    # use the list to access layers with trainable variables
    layer_num = 0
    with tf.device('/gpu:0'):
        x = layers[layer_num](inputs); layer_num += 1
    with tf.device('/gpu:1'):
        y = layers[layer_num](inputs); layer_num += 1
    # You can create new instances of layers that don't have variables
    z = Concatenate()([x, y])
    output = layers[layer_num](z)
    return output

    
model = create_model()

When you want to use the model on a single device, you can use:

output = model(inputs)

When you want to split it across two devices, you can use:

output = model_graph(model, inputs)
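
And as a rough illustration of point 3, the weights of the single-device model can be saved and later restored into a fresh graph that is then driven through model_graph(). The file name and dummy input below are placeholders:

import tensorflow as tf

# Save only the weights of the single-device model (path is a placeholder).
model = create_model()
model.save_weights('model_weights.h5')

# Later: rebuild the same graph, restore the weights, and run the layers
# through the split version defined in model_graph().
restored = create_model()
restored.load_weights('model_weights.h5')

inputs = tf.random.uniform((1, 128, 128, 3))  # dummy input for illustration
output = model_graph(restored, inputs)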

Upvotes: 1
