Reputation: 591
I am new to mxnet and am running a script lightly modified from the documentation on RNNs with gluon. I modified the code so that I am working with strictly numerical time series rather than an NLP problem. Everything was running great until I changed this line:
context = mx.gpu()
to
GPU_COUNT = 3
context = [mx.gpu(i) for i in range(GPU_COUNT)]
Now the parameter initialization triggers an error that causes a crash, specifically at this line:
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
The line causes this error:
mxnet.base.MXNetError: include/mxnet/./base.h:388: Invalid context string[]
I have only been using mxnet for a few days, so I am not very knowledgeable about what could go wrong. However, I have run another sample script, for an MLP, where I also swapped in multiple GPUs for a single one, and that ran fine.
This made me think that the RNN is the problem, and indeed when I removed the RNN portion of the code (so that it is essentially just a feed-forward network), the troublesome line runs fine with any valid number of GPUs. I also tried both the 'rnn_relu' option and the 'gru' option, and both failed with the same error.
So my question is: do mxnet RNNs currently work with multiple GPUs (on one machine) via the gluon API? I don't see this discussed one way or the other in the docs, although I have seen some discussions on GitHub about certain functions not being implemented for multi-device use. How would I confirm this theory? Are there other explanations I should be looking into?
Upvotes: 0
Views: 387
Reputation: 1063
Yes, you can train RNNs across multiple GPUs (and multiple machines) in MXNet. I just confirmed that the code below works with MXNet v1.3.0 on a machine with 4 GPUs.
import mxnet as mx
GPU_COUNT = 4
context = [mx.gpu(i) for i in range(GPU_COUNT)]
model = mx.gluon.rnn.RNN(hidden_size=10, num_layers=1)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
You might want to double-check that nothing is overriding your context, since it looks like you're using an empty context here (i.e. string[] in the error message). You'll also get a similar error when trying to create an array on multiple contexts at the same time:
mx.nd.zeros(shape=(10,10), ctx=context)
This gives the following error (notice that the context contains multiple devices):
MXNetError: [20:15:03] include/mxnet/./base.h:388: Invalid context string [gpu(0), gpu(1), gpu(2), gpu(3)]
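To make both failure modes above easier to spot, here is a minimal sketch of two small helpers. They are hypothetical (check_contexts and per_context are not part of MXNet) and written in plain Python so they run without a GPU. The first fails early with a readable message if the context list ends up empty. The second creates one object per device instead of passing the whole list to a call that expects a single context.

```python
def check_contexts(contexts):
    """Return the contexts as a non-empty list, or raise ValueError.

    Hypothetical helper, not an MXNet API. A single context is wrapped
    in a list; an empty list or None is rejected before it reaches MXNet.
    """
    if contexts is None:
        raise ValueError("context is None; expected one device or a list of devices")
    if not isinstance(contexts, (list, tuple)):
        contexts = [contexts]
    contexts = list(contexts)
    if not contexts:
        raise ValueError("context list is empty; check GPU_COUNT and any later reassignment")
    return contexts


def per_context(contexts, make):
    """Call make(ctx) once per device and return the results as a list.

    Hypothetical helper: use it for calls that accept only a single
    context, rather than handing them the whole list.
    """
    return [make(ctx) for ctx in check_contexts(contexts)]


# Intended use with MXNet (left as comments so the sketch runs without a GPU):
#   context = check_contexts([mx.gpu(i) for i in range(GPU_COUNT)])
#   model.collect_params().initialize(mx.init.Xavier(), ctx=context)
#   arrays = per_context(context, lambda c: mx.nd.zeros(shape=(10, 10), ctx=c))
```

Calling check_contexts right before the initialize line should show whether the list was emptied or overwritten somewhere earlier in the script, which is what an error ending in string[] suggests.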
Upvotes: 1