Stepan Yakovenko

Reputation: 9216

How can I enable pytorch GPU support in Google Colab?

How can I enable pytorch to work on GPU?

I've installed PyTorch successfully in a Google Colab notebook, and TensorFlow reports that a GPU is in place (screenshots omitted).

But the torch.device function fails somehow (screenshot omitted).

How can I fix this?

Upvotes: 4

Views: 34817

Answers (4)

abdullahselek

Reputation: 8463

You can enable the GPU by clicking "Change Runtime Type" under the "Runtime" menu. "TPU" support is also available these days.


You can define the device using torch.device:

import torch

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
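Once the device is defined, tensors (and models) are moved onto it with .to(...). A minimal sketch of how that selection is then used:

```python
import torch

# Select the GPU if one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move a tensor onto the chosen device; models work the same way via model.to(DEVICE)
x = torch.ones(2, 3).to(DEVICE)
print(x.device)  # cuda:0 on a GPU runtime, cpu otherwise
```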

Upvotes: 9

Christian

Reputation: 31

In addition to enabling the GPU under the menu "Runtime" -> "Change Runtime Type", you select the device in code with:

import torch

if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
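With device defined this way, the key point is that a model and its inputs must live on the same device before a forward pass; a minimal sketch:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and its inputs must be on the same device
model = nn.Linear(4, 2).to(device)
inputs = torch.randn(8, 4, device=device)
outputs = model(inputs)
print(outputs.shape)  # torch.Size([8, 2])
```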

Upvotes: 3

Navid Rezaei

Reputation: 1041

You can use this tutorial: https://medium.com/@nrezaeis/pytorch-in-google-colab-640e5d166f13

For example, for CUDA 9.2 and Python 3.6:

!pip3 install http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
!pip3 install torchvision

Now, to check the GPU device name using PyTorch:

import torch

torch.cuda.get_device_name(0)

My result in Google Colab is Tesla K80.
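Note that torch.cuda.get_device_name(0) raises an error when no GPU is visible, so it is safer to guard the check; a hedged sketch:

```python
import torch

# Query GPU details only when CUDA is actually available,
# since get_device_name(0) errors out on a CPU-only runtime
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print("Visible GPUs:", torch.cuda.device_count())
else:
    print("No CUDA device visible -- check the Colab runtime type")
```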

Upvotes: 4

Adam Bittlingmayer

Reputation: 1277

I hit the same issue.

Try installing Torch like this:

# http://pytorch.org/
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

accelerator = 'cu80' #'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'
print('Platform:', platform, 'Accelerator:', accelerator)

!pip install --upgrade --force-reinstall -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision

import torch
print('Torch', torch.__version__, 'CUDA', torch.version.cuda)
print('Device:', torch.device('cuda:0'))

The output should be:

Platform: cp36-cp36m Accelerator: cu80 Torch 0.4.0 CUDA 8.0.61
Device: cuda:0

Some snippets floating around use torch-0.3.0.post4-{platform}-linux_x86_64.whl, which will lead to the same error, because torch.device is a PyTorch 0.4 feature. If you have already installed the wrong version, you may need to run !pip uninstall torch first.
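A quick way to confirm that the installed version is new enough for the device API (a generic sketch, parsing the version string by hand):

```python
import torch

# torch.device was introduced in PyTorch 0.4.0, so anything older lacks it
major, minor = (int(p) for p in torch.__version__.split(".")[:2])
assert (major, minor) >= (0, 4), "Reinstall: torch.device needs PyTorch >= 0.4"
print("OK:", torch.__version__)
```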

Also be sure to enable GPU under Edit > Notebook settings > Hardware accelerator.

Upvotes: 7
