Sargam Modak

Reputation: 783

Change dtype of weights for pytorch pretrained model

I am using the YOLOv7 model. The shared pretrained weights are optimized and stored in float16.
How can I convert the dtype of a model's parameters in PyTorch? I want to convert the weights to float32.

import torch

weights = torch.load('yolov7-mask.pt')
model = weights['model']
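To check what you are starting from, you can inspect the dtypes of the model's parameters and cast them with Module.to(torch.float32). The sketch below does not load yolov7-mask.pt. It uses a small stand-in network (a hypothetical substitute for the YOLOv7 model) and calls .half() on it to mimic the float16 weights:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the loaded YOLOv7 model; .half() mimics float16 weights.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).half()

def param_dtypes(m: nn.Module) -> set:
    """Collect the distinct dtypes of a module's parameters."""
    return {p.dtype for p in m.parameters()}

print(param_dtypes(model))  # {torch.float16}

# Cast floating-point parameters and buffers to float32 (in place; returns the module).
model = model.to(torch.float32)
print(param_dtypes(model))  # {torch.float32}
```

Module.to(dtype) only touches floating-point tensors, so integer buffers such as BatchNorm's num_batches_tracked keep their int64 dtype.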

Upvotes: 1

Views: 5524

Answers (1)

Hayoung

Reputation: 537

Load the weights into your model and call .float() on it.

Example:

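import torch

cp = torch.load('yolov7-mask.pt')
model = cp['model']    # as in the question, the checkpoint stores the nn.Module under 'model'
model = model.float()  # cast floating-point parameters and buffers to float32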

This works as long as the model is an nn.Module (tested with PyTorch 1.8).
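A few details about .float() are worth knowing. It casts every floating-point parameter and buffer to float32, modifies the module in place, and returns the same object. Integer buffers are left alone. If you build the model yourself and load a plain state dict instead, you can either cast the state dict's tensors or rely on load_state_dict copying values into the existing float32 parameters. A minimal sketch, using a stand-in network and a hypothetical helper named state_dict_to_float32:

```python
import torch
import torch.nn as nn

def state_dict_to_float32(sd: dict) -> dict:
    """Hypothetical helper: cast floating-point tensors in a state dict to float32,
    leaving integer entries (e.g. num_batches_tracked) untouched."""
    return {k: v.float() if v.is_floating_point() else v for k, v in sd.items()}

# Stand-in network with float16 weights, mimicking the pretrained checkpoint.
net = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2)).half()

# .float() converts in place and returns the same module object.
same = net.float()
print(same is net)                       # True
print(net[0].weight.dtype)               # torch.float32
print(net[1].num_batches_tracked.dtype)  # torch.int64 (integer buffer unchanged)

# Option 1: cast the tensors of a float16 state dict directly.
half_sd = nn.Linear(4, 2).half().state_dict()
fp32_sd = state_dict_to_float32(half_sd)
print({v.dtype for v in fp32_sd.values()})  # {torch.float32}

# Option 2: load the float16 state dict into a freshly built float32 model;
# load_state_dict copies values into the existing parameters, keeping their dtype.
target = nn.Linear(4, 2)
target.load_state_dict(half_sd)
print(target.weight.dtype)  # torch.float32
```

Option 2 only holds when the target model is already float32 before loading. If the model object itself comes out of the checkpoint, as with YOLOv7, call .float() on it.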

Upvotes: 2
