spiridon_the_sun_rotator

Reputation: 1054

How to convert torchscript model in PyTorch to ordinary nn.Module?

I am loading the torchscript model in the following way:

model = torch.jit.load("model.pt").to(device)

The child modules of this model are identified as RecursiveScriptModule. I would like to fine-tune the loaded weights, and to make that simpler (and to cast them to torch.float32) it would be preferable to convert everything to an ordinary PyTorch nn.Module.

The official docs (https://pytorch.org/docs/stable/jit.html) explain how to convert an nn.Module to TorchScript, but I have not found any examples of doing this in the opposite direction. Is there a way to do this?
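The behaviour described above can be reproduced with a small sketch. TinyNet is a hypothetical stand-in for the real architecture, and an in-memory buffer replaces "model.pt". The module is scripted and saved (the direction the docs cover), then loaded back with torch.jit.load; the result and its children are script modules, not instances of the original Python class.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model standing in for the real architecture."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# nn.Module -> TorchScript (the documented direction), saved to an in-memory file
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(TinyNet()), buffer)
buffer.seek(0)

# Loading it back yields script modules, not TinyNet instances
loaded = torch.jit.load(buffer)
print(type(loaded).__name__)
print([type(child).__name__ for child in loaded.children()])
```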

P.S. An example of loading the pretrained model is given here: https://github.com/openai/CLIP/blob/main/notebooks/Interacting_with_CLIP.ipynb

Upvotes: 8

Views: 7980

Answers (1)

Jiarui Xu

Reputation: 86

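You may try to load it as is, e.g. state_dict = torch.load(src).state_dict(), and then manually iterate over every key and convert its value: new_v = state_dict[k].cpu().float(). The converted state_dict can then be loaded with load_state_dict() into an ordinary nn.Module that defines the same architecture.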
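A minimal sketch of this approach, under a few assumptions. TinyNet is a hypothetical stand-in whose parameter names must match the scripted model's state_dict keys. torch.jit.load is called directly instead of torch.load. Only floating-point tensors are cast to float32, so any integer buffers keep their dtype, which is a small deviation from calling .float() on every value.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical architecture; its parameter names must match the scripted model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def scripted_to_module(path_or_buffer, module: nn.Module) -> nn.Module:
    """Copy weights from a TorchScript archive into an ordinary nn.Module as float32."""
    scripted = torch.jit.load(path_or_buffer, map_location="cpu")
    state_dict = scripted.state_dict()
    # Move every tensor to CPU and cast only floating-point tensors to float32
    converted = {
        k: v.cpu().float() if v.is_floating_point() else v.cpu()
        for k, v in state_dict.items()
    }
    module.load_state_dict(converted)
    return module


# Simulate a half-precision TorchScript checkpoint in memory
source = TinyNet().half()
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(source), buffer)
buffer.seek(0)

model = scripted_to_module(buffer, TinyNet())
print(type(model).__name__, model.fc1.weight.dtype)  # TinyNet torch.float32
```

For the CLIP checkpoint, TinyNet would have to be replaced by the matching architecture from the CLIP repository. load_state_dict() raises an error if keys or shapes do not line up, which is a useful check that the two definitions agree.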

Upvotes: 3
