I am using the YOLOv7 model. The shared pretrained weights are optimised and stored in float16.
How can I convert the dtype of the model's parameters in PyTorch? I want to convert the weights to float32.
weights = torch.load('yolov7-mask.pt')
model = weights['model']
Load the weights into your model and just call .float().
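Example (in a YOLOv7 checkpoint the 'model' entry is already the model object, as in the question):
cp = torch.load('yolov7-mask.pt')
model = cp['model']
model = model.float()
If you have a plain state dict instead, build the model first, call model.load_state_dict(state_dict), and then call model.float().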
This works as long as the model's class is an nn.Module subclass (checked with torch 1.8).
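A slightly fuller sketch of the same idea, assuming the checkpoint's 'model' entry is a pickled nn.Module stored in float16 (as the question's weights['model'] suggests). The helper name load_as_float32 and the tiny stand-in model are hypothetical, for illustration only. Unpickling a real YOLOv7 checkpoint typically also needs the YOLOv7 repository on the Python import path, because the model class itself is pickled.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_as_float32(path: str) -> nn.Module:
    """Hypothetical helper: load a checkpoint whose 'model' entry is a pickled
    nn.Module (e.g. saved in float16) and return it cast to float32."""
    try:
        # weights_only=False is required on recent PyTorch to unpickle a whole
        # module; only do this with checkpoints you trust.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch (< 1.13) has no weights_only argument.
        ckpt = torch.load(path, map_location="cpu")
    model = ckpt["model"]
    # Module.float() casts floating-point parameters and buffers in place and
    # returns the module; integer buffers (e.g. num_batches_tracked) are untouched.
    return model.float()


if __name__ == "__main__":
    # Hypothetical stand-in for YOLOv7: a tiny half-precision model saved the same way.
    tiny = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).half()
    path = os.path.join(tempfile.mkdtemp(), "tiny-fp16.pt")
    torch.save({"model": tiny}, path)

    model = load_as_float32(path)
    print({p.dtype for p in model.parameters()})
```

Calling .float() on the whole module, rather than converting tensors one by one, keeps parameters and floating-point buffers (such as BatchNorm running statistics) consistent with each other.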