Jinu

Reputation: 586

How to use ITK to convert PNG into tensor for PyTorch

I'm trying to run a model with the PyTorch C++ API and ran into the following problem.

I successfully scripted a model, and it is now ready to run. Next I need to feed a PNG image to the model.

I found someone on the internet with a similar issue. His idea was to use the ITK library to read a PNG file, convert it to an array, and then turn that array into a tensor.

PNG -> RGBPixel[] -> tensor
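To picture the middle step concretely: once the PNG's 8-bit RGB pixels sit in an array of pixel structs, that array is already an interleaved byte buffer (R0 G0 B0 R1 G1 B1 and so on). The sketch below uses a hypothetical `Rgb8` struct as a stand-in for `itk::RGBPixel<unsigned char>` and assumes the real pixel type likewise occupies exactly three bytes; the `static_assert` only checks the stand-in, not ITK's type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for itk::RGBPixel<unsigned char>: three 8-bit channels.
struct Rgb8 {
    unsigned char r, g, b;
};
// True on mainstream compilers; the byte-wise copy below relies on it.
static_assert(sizeof(Rgb8) == 3, "expected three packed bytes per pixel");

// Copy an array of RGB pixels into a flat byte vector, in memory order.
// The result is interleaved: R0 G0 B0 R1 G1 B1 ...
std::vector<unsigned char> flatten(const std::vector<Rgb8>& pixels) {
    const auto* bytes = reinterpret_cast<const unsigned char*>(pixels.data());
    std::vector<unsigned char> out(bytes, bytes + pixels.size() * sizeof(Rgb8));
    assert(out.size() == 3 * pixels.size());
    return out;
}
```

For example, flattening the two pixels `{84, 85, 83}` and `{10, 20, 30}` yields the six bytes 84 85 83 10 20 30.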

Here is what I am trying right now:

using PixelType = itk::RGBPixel<unsigned char>;
const unsigned int Dimension = 3;
typedef itk::Image<PixelType, Dimension>      ImageType;
typedef itk::ImageFileReader<ImageType>       ReaderType;
typedef itk::ImageRegionIterator<ImageType>   IteratorType;

typename ImageType::RegionType region = itk_img->GetLargestPossibleRegion();
const typename ImageType::SizeType size = region.GetSize();

int len = size[0] * size[1] * size[2]; // This ends up 1920 * 1080 * 1
PixelType rowdata[len];
int count = 0;
IteratorType iter(itk_img, itk_img->GetRequestedRegion());

// convert itk to array
for (iter.GoToBegin(); !iter.IsAtEnd(); ++iter) {
   rowdata[count] = iter.Get();
   count++;
} // count = 1920 * 1080

// convert array to tensor
tensor_img = torch::from_blob(rowdata, {3, (int)size[0], (int)size[1]}, torch::kShort).clone(); // Segmentation fault

When I print a pixel value, it holds three numbers such as 84 85 83, so I assume the PNG file is read successfully.

However, I can't get the last part of the code to work. What I need is a 3x1920x1080 tensor, but I don't think the function interprets the three RGBPixel values properly.

Apart from that, I don't understand why Dimension is set to 3.

I would appreciate any kind of help.

Upvotes: 0

Views: 403

Answers (1)

Dženan

Reputation: 3395

You don't need Dimension = 3; Dimension = 2 is sufficient for a 2D PNG. If the layout you need is RGBx1920x1080 (the RGB channel varying fastest, then x, then y), then PixelType* rowdata = itk_img->GetBufferPointer(); gives you that layout without further processing. Because torch::from_blob does not take ownership of the buffer, the code you found calls .clone(). You don't need to do that either, as long as you keep itk_img alive while the tensor is in use, or manage its reference count and pass a deleter to from_blob.
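In that interleaved layout the channel varies fastest, then x, then y, which in PyTorch's row-major shape notation is {1080, 1920, 3} (height, width, channels). If a channels-first tensor {3, 1080, 1920} is wanted instead, the data has to be reordered; in LibTorch that would typically be `permute({2, 0, 1})` on the tensor returned by `from_blob`, followed by `.contiguous()` if a packed buffer is needed. The plain C++ sketch below spells out the same reordering by hand so the index arithmetic is visible; `hwcToChw` is a hypothetical helper, not part of ITK or LibTorch, and it assumes 3 channels of 8-bit data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: reorder an interleaved image (element (y, x, c) stored at
// (y*width + x)*3 + c, i.e. shape {height, width, 3}) into planar channels-first
// order (element stored at c*height*width + y*width + x, i.e. shape {3, height, width}).
std::vector<unsigned char> hwcToChw(const std::vector<unsigned char>& hwc,
                                    std::size_t width, std::size_t height) {
    const std::size_t channels = 3;
    assert(hwc.size() == width * height * channels);
    std::vector<unsigned char> chw(hwc.size());
    for (std::size_t y = 0; y < height; ++y) {
        for (std::size_t x = 0; x < width; ++x) {
            for (std::size_t c = 0; c < channels; ++c) {
                chw[c * height * width + y * width + x] =
                    hwc[(y * width + x) * channels + c];
            }
        }
    }
    return chw;
}
```

For a 2x1 image whose interleaved bytes are 1 2 3 4 5 6, the planar result is 1 4 2 5 3 6: all red values first, then green, then blue.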

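The crash probably comes from declaring the buffer as holding short (16-bit) elements (torch::kShort) when it actually holds unsigned char (8-bit) values, so from_blob treats the buffer as twice as large as it really is. Use torch::kUInt8 instead.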
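The size mismatch can be quantified. `from_blob` derives how many bytes it covers from the shape and the element type alone, so declaring 2-byte elements over a buffer of 1-byte channels makes `.clone()` read past the end of the buffer. The sketch below only does that arithmetic; `bytesViewed` and `bytesInRgb8Image` are hypothetical helpers, and it assumes a contiguous layout (no explicit strides), with 1 byte per element for `kUInt8` and 2 for `kShort` (int16).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper: bytes spanned by a contiguous tensor of the given shape
// when each element occupies elemSize bytes.
std::size_t bytesViewed(std::initializer_list<std::size_t> shape, std::size_t elemSize) {
    assert(elemSize > 0);
    std::size_t n = 1;
    for (std::size_t d : shape) n *= d;
    return n * elemSize;
}

// Hypothetical helper: bytes actually held by a width x height image of
// 3-channel, 8-bit pixels.
std::size_t bytesInRgb8Image(std::size_t width, std::size_t height) {
    return width * height * 3 * sizeof(std::uint8_t);
}
```

For a 1920x1080 image the buffer holds 6,220,800 bytes, while a {3, 1920, 1080} view with 2-byte elements spans 12,441,600 bytes; with 1-byte elements the two counts match.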

Upvotes: 1
