Reputation: 11
During the TPU Research Program, I tried to use a TPU v4-64, since I have 32 free on-demand TPU v4 chips. However, unlike with a TPU v4-8, the test code provided in the tutorial did not work whenever I used anything from TPU v4-16 up to TPU v4-64. My own sample code did not work either.
GCP TPU Tutorial: https://cloud.google.com/tpu/docs/tutorials/resnet-pytorch
I followed the tutorial above to set up torch_xla.
[TRIAL 1]
gcloud compute tpus queued-resources create RESOURCE_NAME \
--node-id NODE_NAME \
--project PROJECT_NAME \
--zone us-central2-b \
--accelerator-type v4-64 \
--runtime-version tpu-ubuntu2204-base
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b
pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
git clone --depth=1 --branch r2.3 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
[RESULT 1] I got the log messages below and then no progress for a long time. When I used a TPU v4-8, the training code worked properly (it utilized all of xla:0 ~ xla:3 and printed training/test loss).
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720161247.044048 21553 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/hungwon3626/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720161247.044129 21553 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720161247.044136 21553 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
...
[TRIAL 2] I paused the runtime and then ran the code below in a Python shell:
>>> import os
>>> os.environ['TPU_NUM_DEVICES'] = '32'
# I also retried the same code without the line above and got the same result.
>>> import torch_xla as xla
>>> xla.device()
[RESULT 2] I got the following error messages:
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720161247.044048 21553 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/hungwon3626/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720161247.044129 21553 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720161247.044136 21553 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 21, in device
return xm.xla_device(index)
File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 212, in xla_device
return runtime.xla_device(n, devkind)
File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 124, in xla_device
return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.15.203:8471.
[EXPECTATION] I want to utilize all 32 chips to train a single large model on XLA devices.
[Package versions]
torch 2.3.1
torch-xla 2.3.0
torchvision 0.18.1
Python 3.10.6
Upvotes: 1
Views: 456
Reputation: 11
You are using commands that install and run only on a single TPU VM host. You have to use the following format:
gcloud compute tpus tpu-vm ssh NODE --zone='zone' --worker=all --command='command to install and run'
A TPU v4-64 pod slice contains 8 VMs, one VM for every v4-8 worth of chips, connected together by the TPU pod interconnect. You have to use the format above so that the install and run commands execute on all eight VMs at the same time.
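For example, a minimal sketch of what your install and run steps might look like when sent to every worker (NODE_NAME and the zone are the placeholders from your question; adjust to your setup):

# install torch/torch_xla on all eight workers
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command='pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html'

# clone the test repo on all eight workers
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command='git clone --depth=1 --branch r2.3 https://github.com/pytorch/xla.git'

# launch the training script on all eight workers at once
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command='PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1'

Once the script is launched on all eight hosts, the PJRT runtime can initialize across the whole slice; the hang and the SliceBuilder gRPC error you saw are consistent with only one of the eight hosts trying to initialize it.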
Upvotes: 1