Reputation: 7457
I am trying to train a large graph embedding with the Watch Your Step algorithm in StellarGraph.
For some reason, the model trains only on the CPU and does not use the GPUs.
I have tried:
1. enabling device-placement logging with `tf.debugging.set_log_device_placement(True)`;
2. wrapping training in `with tf.device('/GPU:0'):`;
3. using `tf.distribute.MirroredStrategy()`.
Nevertheless, when running `nvidia-smi`, I don't see any activity on the GPUs, and the training is very slow.
How can I debug this?
def watch_your_step_model():
    """Build and compile a WatchYourStep model from the loaded config.

    Reads ``num_walks``, ``embedding_dimension`` and ``learning_rate`` from
    ``load_config()``. Builds a ``WatchYourStep`` on top of
    ``generator_for_watch_your_step()`` with an L2(0.5) attention regularizer.
    Wraps its input/output tensors in a Keras ``Model`` compiled with
    ``graph_log_likelihood`` and Adam.

    Returns:
        ``(model, generator, wys)``: the compiled Keras model, the generator
        it was built from, and the ``WatchYourStep`` instance.
    """
    config = load_config()
    walk_generator = generator_for_watch_your_step()
    num_walks, embedding_dim, lr = (
        config['num_walks'],
        config['embedding_dimension'],
        config['learning_rate'],
    )

    wys = WatchYourStep(
        walk_generator,
        num_walks=num_walks,
        embedding_dimension=embedding_dim,
        attention_regularizer=regularizers.l2(0.5),
    )
    inputs, outputs = wys.in_out_tensors()

    model = Model(inputs=inputs, outputs=outputs)
    adam = optimizers.Adam(lr)
    model.compile(loss=graph_log_likelihood, optimizer=adam)
    return model, walk_generator, wys
def train_watch_your_step_model(epochs=3000, strategy=None):
    """Build the WatchYourStep model from the config and train it.

    Args:
        epochs: number of epochs passed to ``model.fit``.
        strategy: optional ``tf.distribute.Strategy`` (for example
            ``tf.distribute.MirroredStrategy()``). When given, the model is
            built and compiled inside ``strategy.scope()`` so its variables are
            created by that strategy. When ``None`` (the default), the model is
            built in the current context exactly as before.

    Returns:
        ``(history, checkpoint_file)``: the object returned by ``model.fit``
        and the checkpoint path returned by ``watch_your_step_callbacks(cfg)``.
    """
    cfg = load_config()
    batch_size = cfg['batch_size']
    steps_per_epoch = cfg['steps_per_epoch']
    callbacks, checkpoint_file = watch_your_step_callbacks(cfg)

    # Variables must be created under the strategy's scope for it to
    # replicate them; model.fit itself may run outside the scope.
    if strategy is None:
        model, generator, _ = watch_your_step_model()
    else:
        with strategy.scope():
            model, generator, _ = watch_your_step_model()

    train_gen = generator.flow(batch_size=batch_size, num_parallel_calls=8)
    # Dataset.prefetch returns a *new* dataset; the original code discarded it,
    # so no prefetching happened. Reassign it, and let tf.data pick the buffer
    # size instead of requesting ~20M buffered batches of host memory.
    train_gen = train_gen.prefetch(tf.data.experimental.AUTOTUNE)

    history = model.fit(
        train_gen,
        epochs=epochs,
        verbose=1,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
    )
    copy_last_trained_wys_weights_to_data()
    return history, checkpoint_file
# In TF2, soft device placement is on by default: if TensorFlow cannot see a
# GPU (e.g. only a CPU-only `tensorflow` build is installed), ops requested on
# '/GPU:0' silently run on the CPU. Make that fallback visible.
if not tf.config.list_physical_devices('GPU'):
    tf.get_logger().warning(
        "TensorFlow does not see any GPU; training will run on the CPU. "
        "Check that a GPU-enabled TensorFlow build is installed."
    )

with tf.device('/GPU:0'):
    train_watch_your_step_model()
Upvotes: 1
Views: 882
Reputation: 1
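I just followed these instructions: https://github.com/stellargraph/stellargraph/issues/546.
It worked for me.
Basically, you have to edit StellarGraph's `setup.py` (from the StellarGraph GitHub repository) and remove the `tensorflow` requirement (lines 25 and 27 of https://github.com/stellargraph/stellargraph/blob/develop/setup.py). The reason is that installing StellarGraph otherwise pulls in the CPU-only `tensorflow` package, which can take precedence over an existing GPU build. You can check for this: if `tf.config.list_physical_devices('GPU')` returns an empty list, TensorFlow cannot see your GPUs.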
Upvotes: 0