Aditta Das

Reputation: 36

PyTorch: RuntimeError: mat1 dim 1 must match mat2 dim 0

I'm new to PyTorch. I am using a pretrained resnet50 model and customized its last layer, and now I keep getting the runtime error "mat1 dim 1 must match mat2 dim 0".

This is my code for the network:

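import torch
from torch import nn, optim
from torchvision import models
from tqdm import tqdm

# device was not shown in the original post; this definition is assumed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet50(pretrained=True)

# freeze the pretrained backbone
for param in model.parameters():
    param.requires_grad = False

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

# replace the pooling layer and the classifier head
model.avgpool = Identity()
model.fc = nn.Linear(2048, 2, bias=True)

# train only the new classifier head
for param in model.fc.parameters():
    param.requires_grad = True
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)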



def train(num_epoch, model):
    for epoch in range(num_epoch):
        losses = []
        model.train()
        loop = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            scores = model(data)  # call the module rather than .forward()
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            losses.append(loss.item())  # store a float, not the graph-attached tensor
            loss.backward()
            optimizer.step()

            loop.set_description(f"Epoch {epoch+1}/{num_epoch} process: {int((batch_idx / len(train_loader)) * 100)}")
            loop.set_postfix(loss=loss.item())

train(1, model)

RuntimeError: mat1 dim 1 must match mat2 dim 0

Upvotes: 0

Views: 3839

Answers (1)

Shai

Reputation: 114786

This error comes from the nn.Linear layer you replaced.
nn.Linear computes a matrix product between its input and its weight matrix, so the number of features coming from the previous layer must equal the layer's in_features (you set it to 2048).
My guess is that because you replaced model.avgpool with an identity layer, the flattened input to model.fc now has more than 2048 features (for 224x224 images, 2048 * 7 * 7 = 100352), which produces the error you got.


BTW, you do not need to implement an "identity" layer yourself; PyTorch already has nn.Identity.
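
As a small illustration of the built-in layer (a sketch with an arbitrarily chosen tensor shape, not taken from the question), nn.Identity simply passes its input through unchanged:

```python
import torch
from torch import nn

identity = nn.Identity()

# Arbitrary example tensor shaped like a ResNet-50 feature map batch.
x = torch.randn(2, 2048, 7, 7)
same = identity(x)

# The output has the same shape and values as the input.
print(same.shape == x.shape, torch.equal(same, x))
```

So `model.avgpool = nn.Identity()` would behave like the hand-written Identity class in the question.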

Upvotes: 3
