Minh-Long Luu

Reputation: 2731

PyTorch: load weights from another model without saving

Assume that I have two models in PyTorch. How can I load the weights of model 1 from the weights of model 2 without saving them to disk?

Like this:

model1.weights = model2.weights

In TensorFlow I can do this:

variables1 = model1.trainable_variables
variables2 = model2.trainable_variables
for v1, v2 in zip(variables1, variables2):
    v1.assign(v2.numpy())

Upvotes: 3

Views: 3438

Answers (3)

Minh-Long Luu

Reputation: 2731

Adding another way: it achieves the same result as load_state_dict(), but it might be useful when load_state_dict() throws an error for any reason:

with torch.no_grad():
    # Both models must yield their parameters in the same order, with matching shapes.
    for source_param, target_param in zip(model_to_copy_from.parameters(), model.parameters()):
        target_param.data.copy_(source_param.data)
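
For instance, a minimal runnable sketch; the two nn.Linear modules below are hypothetical stand-ins for any pair of identically-structured models:

import torch
import torch.nn as nn

# Hypothetical models with identical architectures, for illustration only
model = nn.Linear(4, 2)
model_to_copy_from = nn.Linear(4, 2)

with torch.no_grad():
    for source_param, target_param in zip(model_to_copy_from.parameters(), model.parameters()):
        target_param.data.copy_(source_param.data)

# Every parameter of `model` now equals the corresponding one in `model_to_copy_from`
for p_src, p_dst in zip(model_to_copy_from.parameters(), model.parameters()):
    assert torch.equal(p_src, p_dst)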

Upvotes: 0

jodag

Reputation: 22184

Assuming you have two instances of the same model class (which must subclass nn.Module), you can use nn.Module.state_dict() and nn.Module.load_state_dict(). The PyTorch tutorials include a brief introduction to state dictionaries.

model1.load_state_dict(model2.state_dict())
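
If the two architectures only partially overlap, load_state_dict() also accepts strict=False, which skips mismatched keys and reports them instead of raising an error. A minimal sketch, where the nn.Linear modules are placeholders:

import torch.nn as nn

model1 = nn.Linear(4, 2)  # hypothetical models for illustration
model2 = nn.Linear(4, 2)

# strict=False tolerates missing/unexpected keys and returns them
result = model1.load_state_dict(model2.state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)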

Upvotes: 2

Alex Metsai

Reputation: 1950

Here are two ways to do that.

import copy

# Use load_state_dict
model_source = Model()
model_dest = Model()
model_dest.load_state_dict(model_source.state_dict())

# Use deep copy (no pre-built destination model needed)
model_source = Model()
model_dest = copy.deepcopy(model_source)
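
The two approaches differ slightly: load_state_dict() copies tensors into an already-constructed destination model, while copy.deepcopy() builds an entirely new model object (parameters, buffers, and all other attributes included), so no destination instance is needed beforehand.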

Upvotes: 1
