Reputation: 91
I'm fairly new to this and have little to no experience. I had a notebook running PyTorch that I wanted to run on a Google Cloud TPU VM. Machine specs:
- Ubuntu
- TPU v2-8
- pt-2.0
I should have 8 cores. Correct me if I'm wrong.
So, I followed the guidelines for making the notebook TPU-compatible via XLA. I did the following:
os.environ['PJRT_DEVICE'] = 'TPU'
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
device = xm.xla_device()
print(device)
It printed xla:0.
Then I moved the model to the device, wrapped the data loader with the pl.MpDeviceLoader function, and switched to the XLA optimizer step:

model.to(device)
loader = pl.MpDeviceLoader(loader, device)
xm.optimizer_step(optimizer)

Finally, I put the model creation, data preparation, and training loop into an _mp_fn function and spawned it on all cores:

def _mp_fn(index):
    # model creation
    # data preparation
    # training loop

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

When I run this, I get:
BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
I could be totally wrong about all of this, so I'm sorry for that.
If you need a closer look at the code, I can share the notebook.
When I follow the guidelines for single-core processing and don't use xmp.spawn, I get about 1.2 iterations/sec, which I expect could be increased significantly by using all cores.
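For reference, here is a stripped-down sketch of what I'm trying to do. The model, data, and hyperparameters are just placeholders for illustration, not my actual notebook:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # each spawned process gets its own XLA device (one per core)
    device = xm.xla_device()

    # placeholder model, optimizer, and loss
    model = nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # placeholder data
    dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    device_loader = pl.MpDeviceLoader(loader, device)  # moves batches onto the XLA device

    model.train()
    for x, y in device_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        xm.optimizer_step(optimizer)  # steps the optimizer and syncs across cores

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())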
Upvotes: 0
Views: 901
Reputation: 191
The PJRT runtime should be fully supported starting with TPU v4. On a v2-8 you still need to use the XRT runtime. For that you need to set two environment variables:
os.environ['TPU_NUM_DEVICES'] = '8'
os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
I would first suggest testing an example like https://pytorch.org/xla/release/2.0/index.html#running-on-multiple-xla-devices-with-multi-processing to make sure everything is set up and works correctly, and then move on to your model.
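A minimal sketch of such a test could look like the following. This assumes the PyTorch/XLA 2.0 API; the environment variables have to be set before torch_xla is imported, and since you set PJRT_DEVICE in your notebook, you likely also need to unset it so the XRT config is actually used:

import os

# XRT configuration for a v2-8; must be set before importing torch_xla
os.environ['TPU_NUM_DEVICES'] = '8'
os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
os.environ.pop('PJRT_DEVICE', None)  # XRT is ignored if PJRT_DEVICE is set

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # trivial per-core computation, just to verify all cores come up
    t = torch.randn(2, 2, device=device)
    print(index, device, t.sum().item())

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

If this prints one line per core without crashing, the multi-processing setup is working and you can move your own training loop into _mp_fn.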
Upvotes: 0