Reputation: 6338
I am trying to create a copy of an nn.Sequential network. For example, the following is the most straightforward way to do it:
net = nn.Sequential(
    nn.Conv2d(16, 32, 3, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2),
    nn.ReLU(),
)
net_copy = nn.Sequential(
    nn.Conv2d(16, 32, 3, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2),
    nn.ReLU(),
)
However, defining the network a second time is not ideal. I tried the following approaches, but they didn't work:
net_copy = nn.Sequential(net): with this approach, it seems that net_copy is just a shared reference to net.
net_copy = nn.Sequential(*net.modules()): with this approach, net_copy contains many more layers.
Finally, I tried deepcopy in the following way, which worked fine:
from copy import deepcopy
net_copy = deepcopy(net)
However, I am wondering if it is the proper way. I assume it is fine because it works.
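The behaviour of the three attempts above can be checked directly. The following is a minimal sketch, assuming PyTorch is installed; the variable names (wrapped, shares_weights, and so on) are illustrative and not part of any API. It checks that wrapping net in a new nn.Sequential reuses the same parameter objects, that net.modules() yields the container itself in addition to its layers (which would explain the extra layers), and that the deepcopy result holds separate parameter tensors.

```python
import copy

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(16, 32, 3, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2),
    nn.ReLU(),
)

# Wrapping: the new Sequential holds `net` itself as its only child,
# so the parameters are the very same objects.
wrapped = nn.Sequential(net)
shares_weights = wrapped[0][0].weight is net[0].weight

# modules() yields `net` itself followed by each of its four layers.
n_modules = len(list(net.modules()))
# children() yields only the four direct layers (still the same objects).
n_children = len(list(net.children()))

# deepcopy: same structure and values, but separate parameter tensors.
net_copy = copy.deepcopy(net)
same_values = torch.equal(net_copy[0].weight, net[0].weight)
separate_storage = net_copy[0].weight.data_ptr() != net[0].weight.data_ptr()

# Modifying the copy in place should leave the original untouched.
with torch.no_grad():
    net_copy[0].weight.add_(1.0)
still_equal = torch.equal(net_copy[0].weight, net[0].weight)
```

If the copy is independent, still_equal ends up False while same_values (taken before the modification) is True.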
Upvotes: 5
Views: 4937
Reputation: 147
Well, I just use torch.load and torch.save with io.BytesIO:
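import io
import torch

# Write the model to an in-memory buffer
buffer = io.BytesIO()
torch.save(model, buffer)  # model is some nn.Module
print(buffer.tell())  # number of bytes written
del model
# Read the model back from the buffer
buffer.seek(0)  # must seek to the start every time before reading
# Recent PyTorch releases (2.6+) default to weights_only=True,
# which refuses to unpickle a full module, so it is disabled here
model = torch.load(buffer, weights_only=False)
del buffer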
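The snippet above replaces the model rather than keeping both objects. To get a second, independent network as the question asks, the same round trip could be wrapped in a small helper that leaves the original in place. This is only a sketch: copy_via_buffer is a hypothetical name, and passing weights_only=False assumes a PyTorch version whose torch.load accepts that argument.

```python
import io

import torch
import torch.nn as nn


def copy_via_buffer(module: nn.Module) -> nn.Module:
    """Hypothetical helper: clone a module by saving it to memory and loading it back."""
    buffer = io.BytesIO()
    torch.save(module, buffer)
    buffer.seek(0)  # rewind before reading
    # Assumption: weights_only=False is required on recent PyTorch releases
    # to load a whole module instead of just tensors.
    return torch.load(buffer, weights_only=False)


net = nn.Sequential(nn.Conv2d(16, 32, 3, stride=2), nn.ReLU())
net_copy = copy_via_buffer(net)

# The copy starts with identical values stored in separate memory.
values_match = torch.equal(net_copy[0].weight, net[0].weight)
independent = net_copy[0].weight.data_ptr() != net[0].weight.data_ptr()
```

For an in-memory copy this gives a result similar to deepcopy, at the cost of a serialization round trip.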
Upvotes: 1