Adham Ali

Reputation: 91

Enable multiprocessing on PyTorch XLA for TPU VM

I'm fairly new to this and have little to no experience. I had a notebook running PyTorch that I wanted to run on a Google Cloud TPU VM. Machine specs:

- Ubuntu
- TPU v2-8
- pt-2.0

I should have 8 cores. Correct me if I'm wrong.

So, I followed the guidelines for making the notebook TPU-compatible via XLA. I did the following:

import os
os.environ['PJRT_DEVICE'] = 'TPU'

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

device = xm.xla_device()
print(device)

It printed xla:0.

def _mp_fn(index):
    # model creation
    # data preparation
    # training loop
    ...

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

I could be totally wrong about all of this, so apologies in advance; if you need a closer look at the code, I can share the notebook. When I follow the guidelines for single-core processing and don't use xmp.spawn, I get 1.2 iterations/sec, which should increase significantly if all cores were used.

Upvotes: 0

Views: 901

Answers (1)

Susie Sargsyan

Reputation: 191

The PJRT runtime should be fully supported starting with TPU v4. On a v2-8 you still need to use the XRT runtime. For that purpose you need to set two environment variables:

os.environ['TPU_NUM_DEVICES'] = '8'  # values in os.environ must be strings
os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
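As far as I can tell, these should be set before torch_xla is imported, so that the XRT runtime picks them up when it initializes.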

I would first suggest testing an example like https://pytorch.org/xla/release/2.0/index.html#running-on-multiple-xla-devices-with-multi-processing to make sure everything is set up and works correctly. Then you can work on your model.
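For reference, a minimal multi-process sanity check along the lines of that page could look roughly like this (a sketch of the setup, not your training code):

import os

# XRT configuration for a v2-8; both values must be strings.
os.environ['TPU_NUM_DEVICES'] = '8'
os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process gets its own XLA device.
    device = xm.xla_device()
    t = torch.randn(2, 2, device=device)
    print(index, t.device)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

If this prints one line per core, the runtime is working and you can move your model creation and training loop into _mp_fn.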

Upvotes: 0
