Reputation: 679
Based on the documentation here, https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#model-directory-structure, the model file saved from training is model.pth. I have also read that it can have a .pt extension or even a .bin extension. I have seen an example using pytorch_model.bin, but when I tried to serve the model with pytorch_model.bin, it warned me that a .pt or .pth file needs to exist. Has anyone run into this?
Upvotes: 1
Views: 846
Reputation: 56
Interesting question.
I'm assuming you're trying to use the PyTorch container from SageMaker in what we call "script mode", where you just provide the .py entrypoint.
Have you tried defining a model_fn() function that specifies how to load your model? The documentation talks about this here.
More details:
Before a model can be served, it must be loaded. The SageMaker PyTorch model server loads your model by invoking a model_fn function that you must provide in your script when you are not using Elastic Inference.
import torch
import os
import YOUR_MODEL_DEFINITION

def model_fn(model_dir):
    # Instantiate your model architecture
    model = YOUR_MODEL_DEFINITION()
    # model_dir is the directory where SageMaker extracts your model artifact
    with open(os.path.join(model_dir, 'YOUR-MODEL-FILE-HERE'), 'rb') as f:
        model.load_state_dict(torch.load(f))
    return model
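For the filename in model_fn() to match, the training script has to save the state dict under that same name inside the directory SageMaker uploads (exposed via the SM_MODEL_DIR environment variable). A minimal sketch of the saving side, where Net and its single linear layer are just placeholders for your real model definition:

```python
import os
import torch
import torch.nn as nn

class Net(nn.Module):
    """Placeholder model; substitute your actual architecture."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# SageMaker sets SM_MODEL_DIR inside the training container;
# fall back to the current directory for local testing.
model_dir = os.environ.get('SM_MODEL_DIR', '.')

model = Net()
# Save only the state dict, under the same filename model_fn() will open.
torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
```

Saving the state dict (rather than the whole pickled module) is what makes the load_state_dict() call in model_fn() work.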
Let me know if this works!
Upvotes: 1