TFS19

Reputation: 13

How can I get Pytorch Lightning epoch progress bar to display when training on Google Cloud TPUs?

When I run my code on a GPU or CPU on my local machine, or even on a Google Colab TPU, I get a progress bar showing the epoch/steps. However, when I make the minimal adjustments to run the code on Google Cloud TPUs, the bar no longer appears. I get the following message:

warning_cache.warn(
WARNING:root:Unsupported nprocs (8), ignoring...

Based on TPU usage, the code is working and training is happening. The TPU VM is using Python 3.8.10, torch==2.0.0, torch-xla==2.0, torchmetrics==0.11.4, torchvision==0.15.1, pytorch-lightning==2.0.2, transformers==4.29.2.

Here's the end of my code for reference:

# Imports assumed by this snippet (IsaDataModule, IsaModel, train_df, val_df,
# test_df, tokenizer, BATCH_SIZE and N_EPOCHS are defined earlier in the script)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

if __name__ == '__main__':
    data_module = IsaDataModule(train_df, val_df, test_df, tokenizer, batch_size=BATCH_SIZE)
    data_module.setup()
    model = IsaModel()

    # Keep only the best checkpoint by validation loss
    checkpoint_callback = ModelCheckpoint(
        dirpath='spec1_ckpt',
        filename='best_checkpoint',
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    # 8 devices per TPU
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],
        max_epochs=N_EPOCHS,
        accelerator='tpu',
        devices=8
    )

    trainer.fit(model, data_module)

I've tried some of the fixes from this thread: https://github.com/Lightning-AI/lightning/issues/1112, but that thread's issue is with Colab rather than Cloud VMs. I've also tried using the XRT runtime instead of PJRT, but then training doesn't work at all. Any help would be appreciated, thanks.
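For completeness, by switching runtimes I mean the usual torch-xla environment variables, set before torch_xla is imported, roughly like this (the XRT value shown is the commonly documented single-VM default, not something specific to my setup):

    import os

    # PJRT runtime (the default path for torch-xla 2.0):
    os.environ['PJRT_DEVICE'] = 'TPU'

    # Legacy XRT runtime instead (commented out; the address is the commonly
    # documented single-VM default and may differ per setup):
    # os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'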

Upvotes: 0

Views: 1357

Answers (1)

Susie Sargsyan

Reputation: 191

Enabling the progress bar on TPUs is not recommended, since it triggers device-host communication that causes a significant slowdown. That said, it should still work. Can you try explicitly passing enable_progress_bar=True to the Trainer and see if that helps?
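For example, a minimal sketch of that change, reusing the Trainer arguments from your question (enable_progress_bar is a standard Lightning 2.x Trainer flag; it defaults to True, but passing it explicitly rules out anything overriding it):

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],
        max_epochs=N_EPOCHS,
        accelerator='tpu',
        devices=8,
        enable_progress_bar=True  # defaults to True; set explicitly to rule out an override
    )
    trainer.fit(model, data_module)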

Upvotes: 1
