hayo
hayo

Reputation: 1

embedding dimension and tokenizer max length mismatch while using pretrained gpt model. RuntimeError target size mismatch

I want to evaluate pretrained gpt model. gpt model's embedding layer is (tokens_embed): Embedding(40478, 768) If I set tokenizer's max_length as 512, RuntimeError: Expected target size [2, 40478], got [2, 512] appears. How can I fix this?

Please help me.

This is my code.

from transformers import AutoTokenizer, OpenAIGPTLMHeadModel
tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
from torch.utils.data import DataLoader

def preprocess_function(examples):
        inputs = examples['text']
        tokenizer.pad_token = '</s>'
        model_inputs = tokenizer(inputs, max_length=512, padding='max_length', truncation=True)

        return model_inputs


column_names = ['text']
train_dataset = sm.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names
        )

train_dataset.set_format(type="torch")
train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=2)

model = OpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt")
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)


def evaluate(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            torch.cuda.nvtx.range_push("Batch iterations {}".format(i))
            input_data = batch['input_ids']
            target = batch['input_ids'].clone()

            output = model(input_data)

            loss = criterion(output.logits, target)

            total_loss += loss.item()


    return total_loss / len(data_loader)


i = 0
for epoch in tqdm(range(10)):
    val_loss = evaluate(model, train_dataloader, criterion, 'cuda')
    val_perplexity = math.exp(val_loss)
    print(f'Epoch: {epoch+1}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')
    i += 1

Upvotes: 0

Views: 73

Answers (0)

Related Questions