FOXAAFOX

Reputation: 61

PyTorch CNN model stops at loss.backward() without any prompt?

My aim is to perform five-category text classification.

I am fine-tuning BERT with a CNN-based model, but my program stops at loss.backward() without any prompt in cmd.

The program runs successfully with RNN-based models such as LSTM and RCNN.

But when I run a CNN-based model, this strange bug appears.

My CNN model code:

import torch
import torch.nn as nn
import torch.nn.functional as F
# from ..Models.Conv import Conv1d
from transformers.modeling_bert import BertPreTrainedModel, BertModel
n_filters = 200
filter_sizes = [2,3,4]
class BertCNN(BertPreTrainedModel):
    def __init__(self, config):
        super(BertPreTrainedModel, self).__init__(config)
        self.num_filters = n_filters
        self.filter_sizes = filter_sizes
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, self.num_filters, (k, config.hidden_size))
                for k in self.filter_sizes])
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.fc_cnn = nn.Linear(self.num_filters *
                                len(self.filter_sizes), config.num_labels)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, input_ids,
                attention_mask=None, token_type_ids=None, head_mask=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            head_mask=head_mask)
        encoder_out, text_cls = outputs
        out = encoder_out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv)
                         for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc_cnn(out)
        return out
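
As a sanity check, the CNN head can be exercised on its own without BERT. Below is a minimal sketch, assuming hidden_size=768 and num_labels=5, with the batch size 2 and sequence length 80 printed in the training log further down; it runs the same conv-and-pool path on random tensors on CPU:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Standalone check of the CNN head only (assumed shapes: batch=2, seq_len=80,
# hidden_size=768, num_labels=5); BERT is replaced by a random encoder output.
n_filters, filter_sizes, hidden_size, num_labels = 200, [2, 3, 4], 768, 5
convs = nn.ModuleList(
    [nn.Conv2d(1, n_filters, (k, hidden_size)) for k in filter_sizes])
fc = nn.Linear(n_filters * len(filter_sizes), num_labels)

encoder_out = torch.randn(2, 80, hidden_size)   # stands in for BERT's encoder_out
out = encoder_out.unsqueeze(1)                  # [2, 1, 80, 768]
pooled = []
for conv in convs:
    h = F.relu(conv(out)).squeeze(3)            # [2, 200, 80 - k + 1]
    h = F.max_pool1d(h, h.size(2)).squeeze(2)   # [2, 200]
    pooled.append(h)
logits = fc(torch.cat(pooled, 1))               # [2, 5]
logits.sum().backward()                         # returns immediately on CPU
print("CNN head backward finished:", logits.size())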

My training code:

        for step, batch in enumerate(data):
            self.model.train()
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            print("input_ids, input_mask, segment_ids, label_ids SIZE: \n")   
            print(input_ids.size(), input_mask.size(),segment_ids.size(), label_ids.size()) 
            # torch.Size([2, 80]) torch.Size([2, 80]) torch.Size([2, 80]) torch.Size([2])
            logits = self.model(input_ids, segment_ids, input_mask)
            print("logits and label ids size: ",logits.size(), label_ids.size())
            # torch.Size([2, 5]) torch.Size([2])
            loss = self.criterion(output=logits, target=label_ids)
            if len(self.n_gpu) >= 2:
                loss = loss.mean()
            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps
            if self.fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                clip_grad_norm_(amp.master_params(self.optimizer), self.grad_clip)
            else:
                loss.backward()  # debugging shows the program stops at this line without any error prompt

Changing the batch size to 1, the bug still occurs.

The step-1 logits:

logits tensor([[ 0.8831, -0.0368, -0.2206, -2.3484, -1.3595]], device='cuda:1', grad_fn=)

The step-1 loss:

tensor(1.5489, device='cuda:1', grad_fn=<NllLossBackward>)

But why does loss.backward() hang?
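
To try to narrow down where it hangs, one option is a diagnostic sketch only (CUDA_LAUNCH_BLOCKING, anomaly detection and torch.cuda.synchronize are generic PyTorch/CUDA switches, nothing specific to this model): force synchronous CUDA execution and enable autograd anomaly detection before training starts:

import os
import torch

# Force CUDA kernels to launch synchronously so an error or hang points at the real op.
# Must be set before CUDA is initialized (i.e. before any .to(device) call).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Record the forward op that produced each backward node; slower but more informative.
torch.autograd.set_detect_anomaly(True)

# ...then inside the training loop, right after loss.backward():
# torch.cuda.synchronize()   # blocks until all queued kernels have finished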

Upvotes: 3

Views: 863

Answers (2)

won5830

Reputation: 592

I also met the same problem. In my case, the issue originated from PyTorch version compatibility: it was resolved when I upgraded PyTorch to a newer release (1.5.1 -> 1.8.x). I think this kind of issue comes from PyTorch's nn.Conv classes, since my script ran well once I removed them.
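
A quick way to check which build you are on, and whether the conv path is actually involved (a sketch only; disabling cuDNN is purely a diagnostic and makes convolutions slower):

import torch

print(torch.__version__)                  # e.g. 1.5.1 vs 1.8.x
print(torch.version.cuda)                 # CUDA version the wheel was built against
print(torch.backends.cudnn.version())     # cuDNN used by nn.Conv layers on GPU

# Diagnostic only: route convolutions through the native (non-cuDNN) kernels,
# then retry loss.backward() to see whether the hang goes away.
torch.backends.cudnn.enabled = False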

Upvotes: 0

FOXAAFOX

Reputation: 61

I tried running my program on a Linux platform, and it ran successfully.

Therefore, it is very likely caused by the difference in operating systems.

Previous OS: Windows 10
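
For comparing the two machines, a short environment report (standard library plus torch only, nothing project-specific) is enough to record the difference:

import platform
import torch

print("OS:", platform.platform())        # e.g. Windows-10-... vs Linux-...
print("Python:", platform.python_version())
print("torch:", torch.__version__)
print("CUDA build:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")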

Upvotes: 3
