gnsrnjs

Reputation: 11

TPU V4-64 Runtime Error: TPU initialization failed: Failed to establish SliceBuilder grpc channel

During the TPU Research Program, I tried to use a TPU V4-64 slice, since I have 32 free on-demand TPU V4 chips. However, unlike with TPU V4-8, the test code provided in the tutorial did not work whenever I used anything from TPU V4-16 up to TPU V4-64, and my own sample code failed the same way.

GCP TPU Tutorial: https://cloud.google.com/tpu/docs/tutorials/resnet-pytorch

I followed the steps in the above tutorial to set up torch_xla.

[TRIAL 1]

  1. Create the TPU VM
gcloud compute tpus queued-resources create RESOURCE_NAME \
  --node-id NODE_NAME \
  --project PROJECT_NAME \
  --zone us-central2-b \
  --accelerator-type v4-64 \
  --runtime-version tpu-ubuntu2204-base
  2. Connect to the TPU VM
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b
  3. Install torch / torch_xla
pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
  4. Clone the official torch_xla GitHub repo
git clone --depth=1 --branch r2.3 https://github.com/pytorch/xla.git
  5. Run the test code (ResNet training)
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

[RESULT 1] I then got the log messages below, with no further progress for a long time. When I used TPU V4-8, the training code worked properly (utilizing all devices xla:0 ~ xla:3 and printing the training/test loss).

WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720161247.044048   21553 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/hungwon3626/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720161247.044129   21553 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720161247.044136   21553 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46. 

...

[TRIAL 2] I stopped the run above, then ran the following code:

>>> import os 
>>> os.environ['TPU_NUM_DEVICES'] = '32' 
# I also retried the same codes without the above lines, and got the same result.
>>> import torch_xla as xla
>>> xla.device()

[RESULT 2] I then got the following error messages:

WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720161247.044048   21553 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/hungwon3626/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720161247.044129   21553 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720161247.044136   21553 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 21, in device
    return xm.xla_device(index)
  File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 212, in xla_device
    return runtime.xla_device(n, devkind)
  File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/hungwon3626/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 124, in xla_device
    return torch.device(torch_xla._XLAC._xla_get_default_device())
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.15.203:8471.

[EXPECTATION] I want to utilize all 32 chips to train a single large model on XLA devices.

[package versions]

torch                    2.3.1
torch-xla                2.3.0
torchvision              0.18.1
Python                   3.10.6

Upvotes: 1

Views: 456

Answers (1)

KATHIR K S

Reputation: 11

You are using the commands in the form meant for installing and running on a single TPU VM host. For a multi-host slice you have to use the following format:

gcloud compute tpus tpu-vm ssh NODE --zone='zone' --worker=all --command='command to install and run'

A single TPU v4-64 slice contains 8 VMs, one VM for every v4-8 worth of chips, and they are connected together by the TPU pod interconnect. You have to use the above format so that the install and run commands execute on all eight VMs at the same time.
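For example, a minimal sketch reusing the commands from your question (NODE_NAME and us-central2-b are placeholders taken from there; substitute your own node name, project and zone):

# Install torch/torch_xla on every worker of the v4-64 slice
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command="pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"

# Clone the torch_xla repo on every worker
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command="git clone --depth=1 --branch r2.3 https://github.com/pytorch/xla.git"

# Launch the ResNet test script on all workers simultaneously
gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
  --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"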

https://pytorch.org/xla/release/2.3/index.html#:~:text=gcloud%20alpha%20compute%20tpus%20tpu%2Dvm%20ssh%20%24USER%2Dpjrt%20%2D%2Dzone%3D%24ZONE%20%2D%2Dproject%3D%24PROJECT%20%2D%2Dworker%3Dall%20%2D%2Dcommand%3D%22PJRT_DEVICE%3DTPU%20python3%20train_mnist_xla.py
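As a quick sanity check, here is a sketch of a small script you can launch on all workers the same way (the helper names below come from torch_xla's runtime module as of the 2.3 release; verify them against your installed version). Each worker should report the global device count of the whole slice (32 for v4-64) once every host participates in initialization; running it on only one host reproduces the SliceBuilder grpc error from your question.

# save as check_devices.py and run with:
#   gcloud compute tpus tpu-vm ssh NODE_NAME --zone=us-central2-b --worker=all \
#     --command="PJRT_DEVICE=TPU python3 check_devices.py"
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Device initialization only succeeds when all workers of the slice run this together.
print("process index:", xr.process_index())
print("local (addressable) devices:", xr.addressable_device_count())
print("global devices in slice:", xr.global_runtime_device_count())
print("default device:", xm.xla_device())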

Upvotes: 1
