Pranjal Garg

Attempting multiprocessing with GCP TPUs but the shell dies unexpectedly

I have access to 32 preemptible Cloud TPU v4 chips in a US zone and am attempting to run my PyTorch model using the following Python code:

import os 
import sys 
import pickle 
import torch_xla.distributed.xla_multiprocessing as xmp 
from transformers import AutoTokenizer
import main_utils as utils 
import multiprocessing as mp

lock = mp.Manager().Lock()

def _mp_fn(i): 

    A_tasks, B_tasks, desired_output_lengths, keys = utils.get_total_tasks()
    
    
    import torch
    import torch_xla.core.xla_model as xm
    from modeling_mamba import MambaForCausalLM
    
    DEVICE_NAME = xm.xla_device()
    tokenizer = AutoTokenizer.from_pretrained("~/Mamba-1B/")
    model = MambaForCausalLM.from_pretrained("~/Mamba-1B/", torch_dtype="auto",
                                             device_map="auto", low_cpu_mem_usage=True)
    model = model.to_empty(device='cpu')
    model.apply(lambda module: module.reset_parameters() if hasattr(module, 'reset_parameters') else None)
    model = model.to(xm.xla_device())
    
    params_to_save = ["out_proj_y"]
    
    def generate_logits(task, logits_list, desired_output_length):
        for i in range(48):
            layer_to_save = [i + 1]
            input_ids = tokenizer.encode(task, return_tensors="pt").to(DEVICE_NAME)
            model.saved_activation(params_to_save, layer_to_save, precision='r')
            output = model.generate(input_ids, max_length=desired_output_length, no_repeat_ngram_size=2)
    
            saved_dict = model.reset_everything_and_save()
            param = saved_dict[f"out_proj_y_{i + 1}"][0, -1, :].to(DEVICE_NAME)
    
            xm.all_gather(param)
            logits = model.get_unembed_for_layer(param, norm="False")
    
            xm.all_gather(logits)
            logits_list.append(logits.to("cpu"))
    
            del input_ids, output, saved_dict, param, logits
    
    task_dict = {}
    for A_task, B_task, dol, key in zip(A_tasks, B_tasks, desired_output_lengths, keys):
        A_logits = []
        B_logits = []
        generate_logits(A_task, A_logits, dol)
        generate_logits(B_task, B_logits, dol)
    
        aux_dict = {'A': A_logits, 'B': B_logits}
        task_dict[key] = aux_dict
    
        device = xm.xla_device()
        with lock:
            print(f'Process {i}, Device {device}')
    
    xm.mark_step()
    # open() does not expand "~" itself, so resolve the home directory first.
    with open(os.path.expanduser('~/main_file.pickle'), 'wb') as handle:
        pickle.dump(task_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')

After setting up the TPUs and installing the libraries on all 8 workers, I run the script from Cloud Shell with this command:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="PJRT_DEVICE=TPU python3 ~/test.py"

I expected the tasks defined in _mp_fn() to run across all TPUs, with the logits gathered and saved as specified and the results finally stored in main_file.pickle. However, after I launched the command, Cloud Shell printed some expected warnings during model loading and then made no further progress. After a while, the Cloud Shell session terminated unexpectedly.

How can I troubleshoot this?
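In the code above, every process iterates over the full task list and all processes target the same main_file.pickle. One common pattern for producing a single results file from work spread across processes is to split the task list by process ordinal, have each process pickle its own shard, and merge the shards once everything has finished. The standard-library sketch below illustrates that pattern only; it is not a diagnosis of the shell dying. All helper names and the main_file_{ordinal}.pickle naming are hypothetical, and the ordinal and world size are assumed to come from torch_xla (for example xm.get_ordinal() and xm.xrt_world_size(); check these names against your installed torch_xla version).

```python
import os
import pickle
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def shard_tasks(items: Sequence[T], ordinal: int, world_size: int) -> List[T]:
    """Strided split: process `ordinal` takes items ordinal, ordinal + world_size, ...

    Every item is assigned to exactly one process.
    """
    if world_size < 1 or not 0 <= ordinal < world_size:
        raise ValueError("ordinal must satisfy 0 <= ordinal < world_size")
    return list(items[ordinal::world_size])


def shard_path(out_dir: str, ordinal: int) -> str:
    # open() does not expand "~", so resolve it explicitly.
    return os.path.join(os.path.expanduser(out_dir), f"main_file_{ordinal}.pickle")


def save_shard(task_dict: Dict, out_dir: str, ordinal: int) -> str:
    # Each process writes only its own file, so no two writers share a path.
    path = shard_path(out_dir, ordinal)
    with open(path, "wb") as handle:
        pickle.dump(task_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return path


def merge_shards(out_dir: str, world_size: int,
                 merged_name: str = "main_file.pickle") -> Dict:
    # Run once, after every process has written its shard.
    merged: Dict = {}
    for ordinal in range(world_size):
        with open(shard_path(out_dir, ordinal), "rb") as handle:
            merged.update(pickle.load(handle))
    merged_path = os.path.join(os.path.expanduser(out_dir), merged_name)
    with open(merged_path, "wb") as handle:
        pickle.dump(merged, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return merged
```

Inside _mp_fn this would roughly mean looping over shard_tasks(list(zip(A_tasks, B_tasks, desired_output_lengths, keys)), ordinal, world_size) and calling save_shard(task_dict, "~", ordinal) at the end. Two caveats apply. Collectives such as xm.all_gather must be issued the same number of times on every process, so uneven shards combined with per-task collectives can leave some processes waiting. Also, with --worker=all each TPU VM host writes to its own local disk, so shards from different hosts would need to be copied to one place (for example a Cloud Storage bucket) before merge_shards runs.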