Nikita Belooussov

PyTorch: "object too deep for desired array" when saving an image

I am trying to run the code from the following GitHub repo:

https://github.com/iamkrut/image_inpainting_resnet_unet

I haven't changed anything in the code, but it raises a ValueError ("object too deep for desired array") when it tries to save the image. The error seems to come from these two lines:

images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])

Here is the error statement

  File "train.py", line 205, in <module>
    data_dir=args.data_dir)
  File "train.py", line 94, in train_net
    plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
  File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\pyplot.py", line 2140, in imsave
    return matplotlib.image.imsave(fname, arr, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\image.py", line 1498, in imsave
    _png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array

Does anyone know what is causing this or how to fix it? Thank you.

Answers (2)

Seyed Sajad Ashrafi

Matplotlib does not understand PyTorch's data type (the tensor). Convert the tensor to a NumPy array with .numpy() first, and then pass the result to the matplotlib function.

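import torch
import matplotlib.pyplot as plt

a = torch.rand(10, 3, 20, 20)  # batch of 10 RGB images in NCHW layout
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...])  # Error: argument is still a torch.Tensor
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])  # OK: H x W x C NumPy array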
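The same idea can be wrapped in a small helper. This is a minimal sketch, assuming the tensor is an NCHW batch of float images with values in [0, 1]; the function names tensor_to_hwc_numpy and save_batch_image are hypothetical and not part of the linked repository. Like the original code, it saves one image per call and keeps at most three channels.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
import torch


def tensor_to_hwc_numpy(batch: torch.Tensor, index: int) -> np.ndarray:
    """Return image `index` of an NCHW batch as an H x W x C NumPy array (at most 3 channels)."""
    img = batch[index, :3]        # C x H x W, keep at most the first 3 (RGB) channels
    img = img.detach().cpu()      # drop autograd history and move to host memory
    img = img.permute(1, 2, 0)    # C x H x W -> H x W x C, the layout imsave expects
    return img.numpy()            # plain NumPy array that matplotlib can handle


def save_batch_image(batch: torch.Tensor, index: int, path: str) -> None:
    """Hypothetical helper: save one image of an NCHW float batch (values in [0, 1]) to `path`."""
    # imsave rejects float RGB values outside [0, 1], so clip defensively
    arr = np.clip(tensor_to_hwc_numpy(batch, index), 0.0, 1.0)
    plt.imsave(path, arr)
```

For example, save_batch_image(torch.rand(10, 3, 20, 20), 0, "test.png") writes the first image of a random batch to test.png.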

Nikita Belooussov

I managed to fix the code by changing those lines to:

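images = img_tensor.cpu().numpy()[0]            # first image in the batch, as a C x H x W NumPy array
images = np.transpose(images, (1, 2, 0))        # C x H x W -> H x W x C, the layout plt.imsave expects
plt.imsave(join(data_dir, 'samples', image), images)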

I'm still not sure what was wrong with the previous version, so if anyone knows, please tell me.
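A quick sketch comparing the two versions on random data, assuming img_tensor is an NCHW float batch (torch.rand stands in for it here). Once the tensor is converted with .numpy(), the original permute-based indexing and the NumPy-based fix yield the same array, which suggests the missing conversion, not the axis reordering, was the problem.

```python
import numpy as np
import torch

img_tensor = torch.rand(4, 3, 16, 16)  # stand-in for the repo's img_tensor (NCHW batch)

# Fix from this answer: convert to NumPy first, then reorder axes with NumPy
fixed = np.transpose(img_tensor.cpu().numpy()[0], (1, 2, 0))

# Original approach: reorder with permute, but convert to NumPy before indexing
original = img_tensor.cpu().detach().permute(0, 2, 3, 1).numpy()[0, :, :, :3]

print(type(fixed).__name__, fixed.shape)  # ndarray (16, 16, 3)
print(np.array_equal(fixed, original))    # True
```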
