Explorer

Reputation: 1647

Sagemaker python sdk: accessing custom_attributes in inference job

I am using the SageMaker Python SDK for my inference job, following this guide. I trigger the SageMaker inference job from Airflow with the Python callable below:

from sagemaker.tensorflow import TensorFlowModel


def transform(sage_role, inference_file_local_path, **kwargs):
    """
    Python callable to execute a SageMaker SDK batch transform job. It takes batch_output, batch_input,
    model_artifact, instance_type and infer_file_name as runtime parameters.
    :param inference_file_local_path: Local entry_point path for the inference file.
    :param sage_role: SageMaker execution role.
    """
    model = TensorFlowModel(entry_point=infer_file_name,
                            source_dir=inference_file_local_path,
                            model_data=model_artifact,
                            role=sage_role,
                            framework_version="2.5.1")

    tensorflow_serving_transformer = model.transformer(
        instance_count=1,
        instance_type=instance_type,
        accept="text/csv",
        strategy="SingleRecord",
        max_payload=10,
        max_concurrent_transforms=10,
        output_path=batch_output)

    return tensorflow_serving_transformer.transform(data=batch_input, content_type='text/csv')
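
For reference, this callable is wired into the DAG with a standard Airflow PythonOperator; the task id, role ARN, and local path below are placeholders:

from airflow.operators.python import PythonOperator

transform_task = PythonOperator(
    task_id='sagemaker_batch_transform',
    python_callable=transform,
    op_kwargs={
        'sage_role': 'arn:aws:iam::123456789012:role/sagemaker-role',  # placeholder role ARN
        'inference_file_local_path': '/opt/airflow/dags/inference/',   # placeholder path
    },
    dag=dag)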

and my simple inference.py looks like:


import io
import json

import numpy as np
import pandas as pd


def input_handler(data, context):
    """ Pre-process request input before it is sent to TensorFlow Serving REST API
    Args:
        data (obj): the request data, in format of dict or string
        context (Context): an object containing request and configuration details
    Returns:
        (dict): a JSON-serializable dict that contains request body and headers
    """
    if context.request_content_type == 'application/x-npy':
        # very simple numpy handler
        # np.load needs a file-like object, so wrap the raw request bytes;
        # allow_pickle is required because the payload is a pickled dict
        payload = np.load(io.BytesIO(data.read()), allow_pickle=True)
        x_user_feature = np.asarray(payload.item().get('test').get('feature_a_list'))
        x_channel_feature = np.asarray(payload.item().get('test').get('feature_b_list'))
        examples = []
        for index, elem in enumerate(x_user_feature):
            examples.append({'feature_a_list': elem, 'feature_b_list': x_channel_feature[index]})
        return json.dumps({'instances': examples})

    if context.request_content_type == 'text/csv':
        payload = pd.read_csv(data)
        print("Model name is ..............")
        model_name = context.model_name
        print(model_name)
        examples = []
        row_ch = []
        # config_exists and get_s3_json_file are user-defined S3 helpers (sketched below)
        if config_exists(model_bucket, "{}{}".format(config_path, model_name)):
            config_keys = get_s3_json_file(model_bucket, "{}{}".format(config_path, model_name))
            feature_b_list = config_keys["feature_b_list"].split(",")
            row_ch = [float(ch_feature_str) for ch_feature_str in feature_b_list]
            if "column_names" in config_keys.keys():
                cols = config_keys["column_names"].split(",")
                payload.columns = cols
        for index, row in payload.iterrows():
            row_user = row['feature_a_list'].replace('[', '').replace(']', '').split()
            row_user = [float(x) for x in row_user]
            if not row_ch:
                row_ch = row['feature_b_list'].replace('[', '').replace(']', '').split()
                row_ch = [float(x) for x in row_ch]
            example = {'feature_a_list': row_user, 'feature_b_list': row_ch}
            examples.append(example)
        return json.dumps({'instances': examples})

    raise ValueError('{{"error": "unsupported content type {}"}}'.format(
        context.request_content_type or "unknown"))


def output_handler(data, context):
    """Post-process TensorFlow Serving output before it is returned to the client.
    Args:
        data (obj): the TensorFlow serving response
        context (Context): an object containing request and configuration details
    Returns:
        (bytes, string): data to return to client, response content type
    """
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))

    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type
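
The config_exists and get_s3_json_file helpers referenced in input_handler are user-defined. A minimal sketch of what they might look like using boto3 (the bucket and key layout are whatever your config setup uses):

import json

import boto3

s3_client = boto3.client('s3')


def config_exists(bucket, key):
    """Return True if an object exists at s3://bucket/key."""
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
        return True
    except s3_client.exceptions.ClientError:
        return False


def get_s3_json_file(bucket, key):
    """Download a JSON object from S3 and parse it into a dict."""
    obj = s3_client.get_object(Bucket=bucket, Key=key)
    return json.loads(obj['Body'].read())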

It is working fine, however I want to pass custom arguments to inference.py so that I can modify the input data according to the requirement. I thought of using a config file per requirement and downloading it from S3 based on the model name, but because I am using model_data and pass model.tar.gz at runtime, context.model_name is always None.

Is there a way I can pass a runtime argument to inference.py that I can use for customization? In the docs I see SageMaker provides custom_attributes, but I don't see any example of how to use it and access it in inference.py.

custom_attributes (string): content of ‘X-Amzn-SageMaker-Custom-Attributes’ header from the original request. For example, ‘tfs-model-name=half_plus_three,tfs-method=predict’

Upvotes: 1

Views: 615

Answers (1)

Marc Karp

Reputation: 1314

Currently, CustomAttributes is supported in the InvokeEndpoint API call when using a real-time endpoint.
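
For a real-time endpoint, a sketch of how that would work (the endpoint name is hypothetical; with the SageMaker TensorFlow Serving container, the header is exposed to the handlers as context.custom_attributes):

import boto3

runtime = boto3.client('sagemaker-runtime')

# CustomAttributes is delivered as the X-Amzn-SageMaker-Custom-Attributes header
response = runtime.invoke_endpoint(
    EndpointName='my-tf-endpoint',                   # hypothetical endpoint name
    ContentType='text/csv',
    CustomAttributes='max_rows=100,mode=debug',
    Body='1,2,3,4')

# and inside inference.py on the endpoint:
#
# def input_handler(data, context):
#     custom_attrs = context.custom_attributes      # 'max_rows=100,mode=debug'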

As an example, you can pass JSON Lines as input to your Transform Job, where each record contains both the input payload and some custom arguments which you can consume in your inference.py file.

For example,

{
   "input":"1,2,3,4",
   "custom_args":"my_custom_arg"
}
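
A sketch of how input_handler could consume those records (the 'application/jsonlines' content type and the 'input'/'custom_args' keys mirror the example record above; adapt the parsing to your payload):

import json


def input_handler(data, context):
    if context.request_content_type == 'application/jsonlines':
        instances = []
        for line in data.read().decode('utf-8').splitlines():
            record = json.loads(line)
            custom_args = record.get('custom_args')        # e.g. 'my_custom_arg'
            features = [float(x) for x in record['input'].split(',')]
            # use custom_args here to adjust pre-processing as needed
            instances.append(features)
        return json.dumps({'instances': instances})

    raise ValueError('{{"error": "unsupported content type {}"}}'.format(
        context.request_content_type or "unknown"))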

Upvotes: 1
