arve

Reputation: 801

How to get batch predictions with JSON Lines data in SageMaker?

I have a PyTorch model that I have tested as a real-time endpoint in SageMaker, and now I want to test it with batch inference. I am using JSON Lines data and setting up a batch transform job as described in the AWS documentation; in addition, I'm using my own inference.py (see sample below). I'm getting a JSON decode error inside the input_fn function when I call json.loads(request_body).

The error is: raise JSONDecodeError("Extra data", s, end)

Has anyone tried this? I successfully tested this model and JSON input with a real-time endpoint in SageMaker, but now that I'm trying to switch to batch it is erroring out.
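For reference, json.loads parses exactly one JSON document per call, so feeding it two of my input lines at once reproduces the same error (this is roughly what a multi-record payload would look like after SplitType='Line' splitting):

import json

# Two JSON Lines records concatenated into one request body.
body = '{"input" : "input line one"}\n{"input" : "input line two"}'

json.loads(body)  # raises json.JSONDecodeError: Extra data: line 2 column 1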

inference.py

import json


def model_fn(model_dir):
    ...


def input_fn(request_body, request_content_type):
    data = json.loads(request_body)
    return data


def predict_fn(data, model):
    ...

Set-up for the batch job via Lambda:

response = client.create_transform_job(
    TransformJobName='some-job',
    ModelName='mypytorchmodel',
    ModelClientConfig={
        'InvocationsTimeoutInSeconds': 3600,
        'InvocationsMaxRetries': 1
    },
    BatchStrategy='MultiRecord',
    TransformInput={
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://inputpath'
            }
        },
        'ContentType': 'application/json',
        'SplitType': 'Line'
    },
    TransformOutput={
        'S3OutputPath': 's3://outputpath',
        'Accept': 'application/json',
        'AssembleWith': 'Line',
    },
    TransformResources={
        'InstanceType': 'ml.g4dn.xlarge',
        'InstanceCount': 1
    }
)

Input file:

{"input" : "input line one"}
{"input" : "input line two"}
{"input" : "input line three"}
{"input" : "input line four"}
{"input" : "input line five"}
...

Upvotes: 3

Views: 1523

Answers (1)

Ram Vegiraju

Reputation: 379

What is your client-side code where you are invoking the endpoint? You should also be properly serializing the data on the client side and handling it in your inference script. Example:

import json
import boto3

client = boto3.client('sagemaker-runtime')

# Round-trip through json to make sure the payload is valid JSON,
# then send the serialized string as the request body.
data = json.loads(json.dumps(request_body))
payload = json.dumps(data)
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=payload)
result = json.loads(response['Body'].read().decode())['Output']

Also make sure to specify your content_type appropriately: "application/jsonlines".
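For the batch transform itself, the input_fn also has to tolerate a body that holds several newline-delimited records. A minimal sketch (my own, assuming the JSON Lines input shown in the question, not an official AWS pattern) that handles both a single JSON document and a batched JSON Lines payload:

import json

def input_fn(request_body, request_content_type):
    # With SplitType='Line' and BatchStrategy='MultiRecord', the body may
    # contain several newline-delimited JSON records, so parse line by line
    # instead of calling json.loads on the whole body at once.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode('utf-8')
    records = [json.loads(line) for line in request_body.splitlines() if line.strip()]
    # A single-record request comes back as a one-element list.
    return records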

Upvotes: 2
