Mahesh

Reputation: 41

How to pass inputs for my Triton model using the tritonclient Python package?

My Triton model's config.pbtxt file looks like the one below. How can I pass inputs and outputs with tritonclient and perform an inference request?

name: "cifar10"
platform: "tensorflow_savedmodel"
max_batch_size: 10000
input [
  {
    name: "input_1"
    data_type: TYPE_FP32
    dims: [ 32, 32, 3 ]
  }
]
output [
  {
    name: "fc10"
    data_type: TYPE_FP32
    dims: [ 10 ]
  }
]

Any help is appreciated, as I am very new to the tritonclient Python package.

Upvotes: 0

Views: 418

Answers (1)

Sysanin

Reputation: 1803

Here is an example for your case. I have not tested it, because I don't have the script that creates your model, so it may raise some exceptions that need minor fixes. Use this snippet as a reference and a baseline for your solution.

Before running it, make sure a Triton server is running on localhost port 8000, or pass a different address with the -u/--url option.
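As a quick pre-flight check, something like the following can confirm that the server is up and that cifar10 is loaded. It is a minimal sketch that calls Triton's standard HTTP health endpoints (/v2/health/live and /v2/models/<name>/ready) using only the Python standard library; the helper name triton_is_ready is hypothetical. With tritonclient itself, InferenceServerClient.is_server_live() and is_model_ready() serve the same purpose.

```python
import urllib.request


def triton_is_ready(url: str, model_name: str, timeout: float = 2.0) -> bool:
    """Return True if the server is live and the model reports ready.

    Hypothetical helper. It queries the HTTP health endpoints Triton exposes:
      GET /v2/health/live
      GET /v2/models/<model_name>/ready
    A 200 status means "yes"; any HTTP error or connection failure means "no".
    """
    for path in ("/v2/health/live", f"/v2/models/{model_name}/ready"):
        try:
            with urllib.request.urlopen(f"http://{url}{path}", timeout=timeout) as resp:
                if resp.status != 200:
                    return False
        except OSError:
            # URLError, HTTPError and socket timeouts are all OSError subclasses
            return False
    return True


if __name__ == "__main__":
    # Same default address as the client script's --url option
    print(triton_is_ready("localhost:8000", "cifar10"))
```

If this prints False, check the server logs before debugging the inference code itself.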

import tritonclient.http as httpclient
import argparse

import numpy as np

def test_infer(model_name, input_data, output_data):
    r = triton_client.infer(
        model_name=model_name,
        inputs=input_data,
        outputs=output_data
    )
    return r


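payload_input = []
payload_output = []

# max_batch_size > 0 in config.pbtxt, so the request shape must include the batch dimension
payload_input.append(httpclient.InferInput("input_1", [1, 32, 32, 3], "FP32"))
# Explicitly request the "fc10" output (returned as JSON because binary_data=False)
payload_output.append(httpclient.InferRequestedOutput("fc10", binary_data=False))

print("-----------fill inputs with data------------")

# Dummy data: a batch of one 32x32x3 image filled with a numeric (not string) value
input1_data = np.full(shape=(1, 32, 32, 3), fill_value=0.5, dtype=np.float32)
payload_input[0].set_data_from_numpy(input1_data, binary_data=False)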


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-u',
        '--url',
        default='localhost:8000',
        help='Endpoint URL'
    )

    params = parser.parse_args()
    # setup client:
    triton_client = httpclient.InferenceServerClient(params.url, verbose=True)
    # send request:
    results = test_infer("cifar10", payload_input, payload_output)

    print("----------_RESPONSE_-----------")
    print(results.get_response())
    print(results.as_numpy("fc10"))
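If you want to send real images, possibly several per request, a small helper can build the float32 batch that input_1 expects and read the predicted class from fc10. This is a sketch using only NumPy; the helper names make_batch and top1 are hypothetical, and dividing pixel values by 255 is an assumption about the model's preprocessing, not something the config specifies. The shape checks follow the config: each item is 32x32x3, and at most max_batch_size (10000) items fit in one request.

```python
import numpy as np

# Values taken from the question's config.pbtxt
INPUT_DIMS = (32, 32, 3)
NUM_CLASSES = 10
MAX_BATCH_SIZE = 10000


def make_batch(images) -> np.ndarray:
    """Stack one or more 32x32x3 images into a float32 batch for "input_1".

    ASSUMPTION: pixels are scaled from [0, 255] to [0, 1]; replace this with
    whatever preprocessing the model was actually trained with.
    """
    batch = np.asarray(images, dtype=np.float32)
    if batch.shape == INPUT_DIMS:
        batch = batch[np.newaxis, ...]  # single image -> batch of one
    if batch.ndim != 4 or batch.shape[1:] != INPUT_DIMS:
        raise ValueError(f"expected shape (N, 32, 32, 3), got {batch.shape}")
    if not 1 <= batch.shape[0] <= MAX_BATCH_SIZE:
        raise ValueError(f"batch size must be between 1 and {MAX_BATCH_SIZE}")
    return batch / 255.0  # stays float32


def top1(logits) -> np.ndarray:
    """Return the index of the highest-scoring class for each row of "fc10"."""
    logits = np.asarray(logits)
    if logits.ndim != 2 or logits.shape[1] != NUM_CLASSES:
        raise ValueError(f"expected shape (N, {NUM_CLASSES}), got {logits.shape}")
    return np.argmax(logits, axis=1)
```

With the client script, you would create the input from the batch's own shape, e.g. httpclient.InferInput("input_1", list(batch.shape), "FP32") followed by set_data_from_numpy(batch), and then call top1(results.as_numpy("fc10")) on the response. For large batches, binary_data=True avoids encoding every value as JSON text.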

Upvotes: 0
