I have trained and deployed a PyTorch model with SageMaker. I can call the endpoint and get a prediction. I am using the default input_fn() (i.e. it is not defined in my serve.py):
from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(model_data=trained_model_location,
                     role=role,
                     framework_version='1.0.0',
                     entry_point='serve.py',
                     source_dir='source')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
A prediction can be made as follows:
import numpy as np
from io import StringIO
data = "0.12787057, 1.0612601, -1.1081504"  # avoid shadowing the builtin input()
predictor.predict(np.genfromtxt(StringIO(data), delimiter=",").reshape(1, 3))
I want to serve the model through a REST API, using an HTTP POST via Lambda and API Gateway. I was able to do this with invoke_endpoint() for an XGBoost model in SageMaker, but I am not sure what to send in the body for the PyTorch model:
import boto3
client = boto3.client('sagemaker-runtime')
response = client.invoke_endpoint(EndpointName=ENDPOINT,
                                  ContentType='text/csv',
                                  Body=???)
I believe I need to understand how to write a custom input_fn that accepts and processes the kind of data I can send through invoke_endpoint(). Am I on the right track, and if so, how could input_fn be written to accept CSV from invoke_endpoint()?
Yes, you are on the right track. You can send CSV-serialized input to the endpoint without using the predictor from the SageMaker SDK, by using another SDK such as boto3, which is available in Lambda:
import json
import boto3

runtime = boto3.client('sagemaker-runtime')
payload = '0.12787057, 1.0612601, -1.1081504'
response = runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,  # name of your deployed endpoint
    ContentType='text/csv',
    Body=payload.encode('utf-8'))
result = json.loads(response['Body'].read().decode())
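Since you want to call this from Lambda behind API Gateway, here is a minimal sketch of a handler that forwards the POST body to the endpoint. It assumes a Lambda proxy integration (so the raw CSV line arrives in event['body']) and an environment variable named ENDPOINT_NAME; both are assumptions, not part of your setup. It also assumes the endpoint returns JSON. The extra client parameter is a hypothetical hook added only so the function can be exercised without AWS credentials.

```python
import base64
import json
import os


def lambda_handler(event, context, client=None):
    """Forward a CSV line from an API Gateway POST body to a SageMaker endpoint.

    Assumes a Lambda proxy integration, so the request body arrives as a
    string in event['body'], e.g. "0.12787057, 1.0612601, -1.1081504".
    """
    if client is None:
        # Only import and create the boto3 client when running for real.
        import boto3
        client = boto3.client('sagemaker-runtime')

    payload = event.get('body') or ''
    if event.get('isBase64Encoded'):
        # API Gateway may base64-encode the body depending on its binary settings.
        payload = base64.b64decode(payload).decode('utf-8')

    response = client.invoke_endpoint(
        EndpointName=os.environ['ENDPOINT_NAME'],  # assumed environment variable
        ContentType='text/csv',
        Accept='application/json',
        Body=payload.encode('utf-8'))

    # Assumes the endpoint's output_fn serializes the prediction as JSON.
    prediction = json.loads(response['Body'].read().decode('utf-8'))

    # Proxy integrations expect a statusCode and a string body.
    return {'statusCode': 200, 'body': json.dumps(prediction)}
```

The client would then POST the same CSV line you used locally as the raw request body.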
With invoke_endpoint, the endpoint receives the payload as CSV-formatted input, which you may need to reshape in input_fn into the dimensions the model expects. For example:
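import numpy as np
import torch
from io import StringIO
def input_fn(request_body, request_content_type):
    if request_content_type == 'text/csv':
        data = np.genfromtxt(StringIO(request_body), delimiter=',', dtype=np.float32)  # float32 to match default model weights
        return torch.from_numpy(data.reshape(1, 3))
    raise ValueError('Unsupported content type: {}'.format(request_content_type))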
Note: I wasn't able to test this specific input_fn with your input content and shape, but I have used the same approach with a scikit-learn RandomForest a couple of times, and based on the PyTorch SageMaker serving documentation the same reasoning should apply.
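If you later send several rows in one request (one CSV line per sample), a hard-coded reshape(1, 3) will fail. Below is a minimal numpy-only sketch of a more general parsing step, assuming every row has the same number of features. csv_to_float32 is a hypothetical helper name; inside input_fn you could return torch.from_numpy(csv_to_float32(request_body, 3)). It casts to float32 because a float64 tensor typically triggers a dtype mismatch error against default float32 PyTorch weights.

```python
from io import StringIO

import numpy as np


def csv_to_float32(request_body, n_features):
    """Parse one or more CSV lines into a float32 array of shape (n_rows, n_features)."""
    # genfromtxt returns shape (n_features,) for a single line and
    # (n_rows, n_features) for several lines; reshape(-1, ...) normalizes both.
    data = np.genfromtxt(StringIO(request_body), delimiter=',', dtype=np.float32)
    if np.isnan(data).any():
        # genfromtxt fills missing or unparsable fields with NaN instead of failing.
        raise ValueError('CSV payload contains missing or non-numeric values')
    return data.reshape(-1, n_features)
```

Rejecting NaNs early gives a clear error in the endpoint logs instead of a silent bad prediction.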
Don't hesitate to use the endpoint logs in CloudWatch (linked from the endpoint page in the SageMaker console) to diagnose inference errors; those logs are usually much more verbose than the high-level messages returned by the inference SDKs.
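If you prefer to pull those logs from a script, here is a minimal sketch using the CloudWatch Logs filter_log_events API. It assumes the default log group name /aws/sagemaker/Endpoints/<endpoint name> and only reads the first page of results; recent_endpoint_errors and its client parameter are hypothetical, the latter added only for testing.

```python
import time


def recent_endpoint_errors(endpoint_name, minutes=15, filter_pattern='Error', client=None):
    """Return recent log messages matching filter_pattern from an endpoint's CloudWatch logs.

    Assumes the default log group /aws/sagemaker/Endpoints/<endpoint_name>.
    Only the first page of results is returned (nextToken is not followed).
    """
    if client is None:
        import boto3
        client = boto3.client('logs')
    # CloudWatch Logs expects startTime in epoch milliseconds.
    start_ms = int((time.time() - minutes * 60) * 1000)
    response = client.filter_log_events(
        logGroupName='/aws/sagemaker/Endpoints/' + endpoint_name,
        filterPattern=filter_pattern,
        startTime=start_ms)
    return [event['message'] for event in response.get('events', [])]
```

For example, printing each line returned by recent_endpoint_errors(ENDPOINT_NAME) right after a failed invoke_endpoint call usually surfaces the Python traceback raised inside input_fn.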