How do I convert a torch tensor to an image to be returned by FastAPI?

I have a torch tensor that I need to convert to a bytes object so that I can pass it to Starlette's StreamingResponse, which should return an image reconstructed from those bytes. I am trying to convert the tensor and return it like this:

def some_unimportant_function(params):
    return_image = io.BytesIO()
    torch.save(some_img, return_image)
    return_image.seek(0)
    return_img = return_image.read()
    
    return StreamingResponse(content=return_img, media_type="image/jpeg")
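To see why this first version cannot work, here is a small check. It is a sketch, and the helper name torch_save_bytes is hypothetical. It shows that the buffer filled by torch.save holds PyTorch's own serialized tensor, not JPEG data, so it never begins with the JPEG start-of-image marker FF D8:

```python
import io

import torch

# Every JPEG file begins with the SOI (start of image) marker 0xFF 0xD8.
JPEG_SOI = b"\xff\xd8"


def torch_save_bytes(tensor: torch.Tensor) -> bytes:
    """Hypothetical helper: return what torch.save writes into an in-memory buffer."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()


if __name__ == "__main__":
    data = torch_save_bytes(torch.rand(3, 4, 4))
    print(data[:4])  # PyTorch's zip-based format typically starts with b"PK"
    print(data.startswith(JPEG_SOI))  # False: not something a browser can decode as an image
```

These bytes are only meaningful to torch.load, which is why labelling them image/jpeg does not produce a viewable image.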

The following works fine with regular bytes objects, and my API returns the reconstructed image:

def some_unimportant_function(params):
    image = Image.open(io.BytesIO(some_image))

    return_image = io.BytesIO()
    image.save(return_image, "JPEG")
    return_image.seek(0)
    return StreamingResponse(content=return_image, media_type="image/jpeg")

I am using the PIL library for this.

What am I doing wrong here?

Answers (1)

Mazen

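torch.save writes PyTorch's own serialization format, not JPEG bytes, so the client cannot decode what your first function returns. Instead, convert the PyTorch tensor to a PIL Image with torchvision.transforms.ToPILImage(), then handle it exactly as your second function does. Here is an example: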

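import io

from fastapi.responses import StreamingResponse
from torchvision.transforms import ToPILImage

def some_unimportant_function(params):
    tensor = ...  # read the tensor from disk or whatever: shape (C, H, W) or (H, W)
    # ToPILImage accepts a 3-D (C, H, W) or 2-D (H, W) tensor; float values are
    # expected in [0, 1] and are scaled to 0-255, so no unsqueeze(0) is needed
    image = ToPILImage()(tensor.detach().cpu())
    return_image = io.BytesIO()
    image.save(return_image, "JPEG")  # encode as real JPEG bytes
    return_image.seek(0)  # rewind so the response streams from the first byte
    return StreamingResponse(content=return_image, media_type="image/jpeg")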
