Reputation: 21
I am trying to manage the checkpoints of my PyTorch model through torch.save(), using PyTorch 1.12.0 and Python 3.7:
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, full_path)
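For reference, the checkpoint would be restored later with the usual counterpart (a minimal sketch, assuming model and optimizer have already been constructed the same way as at save time):

# Rebuild the objects first, then load their saved states.
checkpoint = torch.load(full_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']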
But I am getting the following warning for model.state_dict():
/home/francesco/anaconda3/envs/env/lib/python3.7/site-packages/torch/nn/modules/module.py:1384: UserWarning: positional arguments and argument "destination" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
I had a look at the implementation of state_dict() here, but I still don't get why I am getting the warning, since len(args) should be 0:
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
    warn_msg = []
    if len(args) > 0:
        warn_msg.append('positional arguments')
        if destination is None:
            destination = args[0]
        if len(args) > 1 and prefix == '':
            prefix = args[1]
        if len(args) > 2 and keep_vars is False:
            keep_vars = args[2]

    if destination is not None:
        warn_msg.append('argument "destination"')
    else:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    if warn_msg:
        # DeprecationWarning is ignored by default
        warnings.warn(
            " and ".join(warn_msg) + " are deprecated. nn.Module.state_dict will not accept them in the future. "
            "Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.")

    return self._state_dict_impl(destination, prefix, keep_vars)
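If I'm reading this right, the warning may not come from my own call at all: _state_dict_impl recurses into every submodule, and in 1.12.0 it appears to pass destination and prefix to each child's state_dict() positionally, which would trip the deprecation check for any model that has children. A minimal sketch to test that hypothesis (assuming PyTorch 1.12.0; the Sequential here is just an arbitrary module with one submodule):

import warnings
import torch.nn as nn

# Any module containing at least one submodule should reproduce the warning,
# even though state_dict() is called with no arguments at the top level.
model = nn.Sequential(nn.Linear(4, 4))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.state_dict()

for w in caught:
    print(w.message)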
For the sake of completeness, here's the model:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=2)
        self.conv2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(16 * 16 * 16 * 64, 2)
        self.sig1 = nn.Sigmoid()

    def forward(self, x):
        x = F.relu(self.pool1(self.conv1(x)))
        x = F.relu(self.pool2(self.conv2(x)))
        x = x.view(-1, 16 * 16 * 16 * 64)
        x = self.dropout(x)
        x = self.sig1(self.fc1(x))
        return x
Does anyone know what I am missing? Thank you!
Upvotes: 2
Views: 1242
Reputation: 560
I have the same warning, and as a result it is not logging the training-logs dict. I'm training with PyTorch Lightning in DDP. It works on a single GPU but gives this warning on a multi-GPU system with DDP.
Upvotes: 0