Reputation: 213
I trained a custom model with PyTorch in a Colab environment and saved the trained model to Google Drive as model_final.pth. I want to convert model_final.pth to model_final.pt so that it can be used on mobile devices.
The code I use to train the model is as follows:
import os
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("mouse_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.OUTPUT_DIR="drive/Detectron2/"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
The code I used to convert the model is as follows:
from detectron2.modeling import build_model
import torch
import torchvision
print("cfg.MODEL.WEIGHTS: ",cfg.MODEL.WEIGHTS) ## RETURNS : cfg.MODEL.WEIGHTS: drive/Detectron2/model_final.pth
model = build_model(cfg)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("drive/Detectron2/model-final.pt")
But I get this error, IndexError: too many indices for tensor of dimension 3:
cfg.MODEL.WEIGHTS: drive/Detectron2/model_final.pth
/usr/local/lib/python3.6/dist-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
'incorrect results).', category=RuntimeWarning)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-17-8e544c0f39c8> in <module>()
7 model.eval()
8 example = torch.rand(1, 3, 224, 224)
----> 9 traced_script_module = torch.jit.trace(model, example)
10 traced_script_module.save("drive/Detectron2/model_final.pt")
7 frames
/usr/local/lib/python3.6/dist-packages/detectron2/modeling/meta_arch/rcnn.py in <listcomp>(.0)
219 Normalize, pad and batch the input images.
220 """
--> 221 images = [x["image"].to(self.device) for x in batched_inputs]
222 images = [(x - self.pixel_mean) / self.pixel_std for x in images]
223 images = ImageList.from_tensors(images, self.backbone.size_divisibility)
IndexError: too many indices for tensor of dimension 3
Upvotes: 6
Views: 18866
Reputation: 101
This example may help. It uses the same tracing approach, but loads the weights from a .pth file first.
import torch
import torchvision
from unet import UNet
model = UNet(3, 2)
model.load_state_dict(torch.load("best_weights.pth"))
model.eval()
example = torch.rand(1, 3, 320, 480)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
Code from this site.
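After saving, a quick sanity check is to reload the file with torch.jit.load and compare its output with the original model. The following is a minimal sketch under stated assumptions: TinyNet is a hypothetical stand-in for the real network (UNet is not defined here), and the file is written to a temporary directory rather than a real output path.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network (e.g. UNet)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


model = TinyNet().eval()
example = torch.rand(1, 3, 32, 48)

# Trace and save, following the same steps as the snippet above.
traced = torch.jit.trace(model, example)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
traced.save(path)

# Reload the TorchScript file and compare it against the eager model.
loaded = torch.jit.load(path, map_location="cpu")
with torch.no_grad():
    same = torch.allclose(model(example), loaded(example), atol=1e-6)
print("outputs match:", same)
```

If the outputs differ noticeably, the trace probably captured input-dependent control flow, which torch.jit.trace cannot record.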
Upvotes: 2
Reputation: 510
Detectron2 models expect a dictionary or a list of dictionaries as input by default, so you cannot use the torch.jit.trace function on them directly. However, Detectron2 provides a wrapper called TracingAdapter that allows models to take a tensor or a tuple of tensors as input. You can see how to use it in their TorchScript tests.
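To illustrate the idea behind such an adapter, here is a minimal sketch with a toy module, not Detectron2 itself. DictInputModel and TensorInputWrapper are hypothetical names introduced for this example. The toy model reads x["image"] from each element of its input, as the traceback in the question shows Detectron2 doing. When a plain 4-D tensor is passed, each element is a 3-D tensor rather than a dict, which is consistent with the IndexError reported. The wrapper takes a plain tensor and packs it into the expected format, so torch.jit.trace only ever sees tensor inputs.

```python
import torch
from torch import nn


class DictInputModel(nn.Module):
    """Toy stand-in for a Detectron2-style model: takes a list of dicts."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, batched_inputs):
        # Each element is a dict holding one (C, H, W) image tensor.
        images = torch.stack([x["image"] for x in batched_inputs])
        return self.conv(images)


class TensorInputWrapper(nn.Module):
    """Accepts a plain (C, H, W) tensor and packs it into the dict format."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        return self.model([{"image": image}])


model = DictInputModel().eval()
example = torch.rand(3, 32, 32)  # a single image, no batch dimension

wrapper = TensorInputWrapper(model).eval()
traced = torch.jit.trace(wrapper, (example,))
out = traced(example)
print(out.shape)
```

TracingAdapter plays a similar role for real Detectron2 models, and it also flattens the structured outputs into tensors, which this toy wrapper does not need to do.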
The code for tracing your Mask RCNN model could be (I did not try it):
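import torch
import torchvision
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.export.flatten import TracingAdapter
from detectron2.modeling import build_model
def inference_func(model, image):
    # Wrap the plain tensor in the list-of-dicts format the model expects.
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]
print("cfg.MODEL.WEIGHTS: ", cfg.MODEL.WEIGHTS)  ## RETURNS : cfg.MODEL.WEIGHTS: drive/Detectron2/model_final.pth
model = build_model(cfg)
# build_model() only creates the architecture, so load the trained checkpoint explicitly.
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
model.eval()
# A single image in (C, H, W) format, without a batch dimension.
example = torch.rand(3, 224, 224)
wrapper = TracingAdapter(model, example, inference_func)
wrapper.eval()
traced_script_module = torch.jit.trace(wrapper, (example,))
traced_script_module.save("drive/Detectron2/model-final.pt")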
More info on detectron2 deployment with tracing can be found here.
Upvotes: 7