Reputation: 1
I'm new to PyTorch and to computing on large amounts of data. I want to create embeddings for a large text corpus, so I wrote my own embedding function. It works, but it seems pretty slow, so I'm not sure whether it actually uses the GPU or still runs on the CPU.
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

model_name = 'facebook/bart-base'

def load_model():
    # load the pretrained bart model
    model = BartForConditionalGeneration.from_pretrained(model_name)
    # load the tokenizer
    tokenizer = BartTokenizer.from_pretrained(model_name)
    return model, tokenizer

def calculate_text_embeddings(text, model, tokenizer):
    # tokenize the text
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    # run only the encoder to generate the embedding
    with torch.no_grad():
        outputs = model.get_encoder()(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
    # get the encoder's last hidden state
    last_hidden_states = outputs.last_hidden_state
    return last_hidden_states
torch.cuda.is_available() returns True.
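One way I could check where things actually live is to print their devices (a quick sketch using the functions above):

model, tokenizer = load_model()
# device of the model weights; nothing is moved anywhere, so this prints cpu
print(next(model.parameters()).device)
# device of the tokenized inputs, also cpu
inputs = tokenizer("some example text", return_tensors="pt")
print(inputs.input_ids.device)

If I read that correctly, everything stays on the CPU.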
I want to calculate the embeddings on the GPU using CUDA.
Upvotes: 0
Views: 176
Reputation: 5283
You need to explicitly move the model and the model inputs to the GPU.
You can run nvidia-smi to verify things are running on the GPU.
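If you prefer checking from inside Python rather than with nvidia-smi, torch.cuda.memory_allocated() should go from zero to a non-zero value once the weights have been moved (a rough sketch, assuming a CUDA device is present):

import torch
from transformers import BartForConditionalGeneration

print(torch.cuda.memory_allocated())    # typically 0 before anything is moved
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').to('cuda')
print(torch.cuda.memory_allocated())    # non-zero once the weights sit on the GPU
print(next(model.parameters()).device)  # cuda:0

The adjusted version of your two functions: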
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

model_name = 'facebook/bart-base'
device = 'cuda'

def load_model():
    # load the pretrained bart model and move it to the GPU
    model = BartForConditionalGeneration.from_pretrained(model_name)
    model.to(device)
    # load the tokenizer
    tokenizer = BartTokenizer.from_pretrained(model_name)
    return model, tokenizer

def calculate_text_embeddings(text, model, tokenizer):
    # tokenize the text
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    # run the encoder, moving the input tensors to the same device as the model
    with torch.no_grad():
        outputs = model.get_encoder()(input_ids=inputs.input_ids.to(device),
                                      attention_mask=inputs.attention_mask.to(device))
    # get the encoder's last hidden state and move it back to the CPU
    last_hidden_states = outputs.last_hidden_state
    return last_hidden_states.cpu()
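Since you mention a large corpus: encoding one text per forward pass leaves the GPU mostly idle, so batching the texts usually gives the biggest speed-up. A sketch of what that could look like (the function name, batch_size and the masked mean-pooling are my own choices, not something from your code):

def calculate_corpus_embeddings(texts, model, tokenizer, batch_size=32):
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # pad the batch to a common length so it can go through the encoder in one pass
        inputs = tokenizer(batch, return_tensors="pt", max_length=512,
                           truncation=True, padding=True)
        with torch.no_grad():
            outputs = model.get_encoder()(input_ids=inputs.input_ids.to(device),
                                          attention_mask=inputs.attention_mask.to(device))
        # masked mean over the token states -> one vector per text, then move off the GPU
        mask = inputs.attention_mask.to(device).unsqueeze(-1)
        pooled = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
        embeddings.append(pooled.cpu())
    return torch.cat(embeddings)

Whether mean-pooling (or, say, the first token's state) is the right sentence representation depends on what you do with the embeddings downstream.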
Upvotes: 0