Reputation: 1
I have access to 32 preemptible Cloud TPU v4 chips in a US zone and am attempting to run my PyTorch model using the following Python code:
import os
import sys
import pickle
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoTokenizer
import main_utils as utils
import multiprocessing as mp

lock = mp.Manager().Lock()

def _mp_fn(i):
    A_tasks, B_tasks, desired_output_lengths, keys = utils.get_total_tasks()

    import torch
    import torch_xla.core.xla_model as xm
    from modeling_mamba import MambaForCausalLM

    DEVICE_NAME = xm.xla_device()
    tokenizer = AutoTokenizer.from_pretrained("~/Mamba-1B/")
    model = MambaForCausalLM.from_pretrained("~/Mamba-1B/", torch_dtype="auto",
                                             device_map="auto", low_cpu_mem_usage=True)
    model = model.to_empty(device='cpu')
    model.apply(lambda module: module.reset_parameters() if hasattr(module, 'reset_parameters') else None)
    model = model.to(xm.xla_device())

    params_to_save = ["out_proj_y"]

    def generate_logits(task, logits_list, desired_output_length):
        # For each of the 48 layers: generate, pull the saved activation,
        # project it through the unembedding, and collect the logits on CPU.
        for layer in range(48):
            layer_to_save = [layer + 1]
            input_ids = tokenizer.encode(task, return_tensors="pt").to(DEVICE_NAME)
            model.saved_activation(params_to_save, layer_to_save, precision='r')
            output = model.generate(input_ids, max_length=desired_output_length, no_repeat_ngram_size=2)
            saved_dict = model.reset_everything_and_save()
            param = saved_dict[f"out_proj_y_{layer + 1}"][0, -1, :].to(DEVICE_NAME)
            xm.all_gather(param)
            logits = model.get_unembed_for_layer(param, norm="False")
            xm.all_gather(logits)
            logits_list.append(logits.to("cpu"))
            del input_ids, output, saved_dict, param, logits

    task_dict = {}
    for A_task, B_task, dol, key in zip(A_tasks, B_tasks, desired_output_lengths, keys):
        A_logits = []
        B_logits = []
        generate_logits(A_task, A_logits, dol)
        generate_logits(B_task, B_logits, dol)
        aux_dict = {'A': A_logits, 'B': B_logits}
        task_dict[key] = aux_dict

    device = xm.xla_device()
    with lock:
        print(f'Process {i}, Device {device}')
    xm.mark_step()

    with open('~/main_file.pickle', 'wb') as handle:
        pickle.dump(task_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')
To execute this code, I'm using Cloud Shell with the command (after setting up TPUs and installing libraries on all 8 workers):
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="PJRT_DEVICE=TPU python3 ~/test.py"
I expected the script to run the tasks defined in _mp_fn() across all TPU devices, gather the logits as specified, and finally store the results in main_file.pickle. However, after I launch the command from Cloud Shell, each worker prints the expected warnings during model loading and then makes no further progress; after some time the shell session terminates unexpectedly. How can I troubleshoot this?
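In case it is useful for troubleshooting, the only variation I can think of is detaching the job and redirecting each worker's output to a log file so that something survives if the SSH session drops; the log path below is just a placeholder:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="PJRT_DEVICE=TPU nohup python3 ~/test.py > ~/test_run.log 2>&1 &"

But I am not sure whether that would address the underlying stall or only hide it, since the workers already stop progressing before the session terminates.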
Upvotes: 0
Views: 41