I am trying to run the code from the following GitHub repository:
https://github.com/iamkrut/image_inpainting_resnet_unet
I haven't changed anything in the code, and it raises a ValueError ("object too deep for desired array") when it tries to save the image. The error seems to come from these two lines:
images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])
Here is the traceback:
File "train.py", line 205, in <module>
data_dir=args.data_dir)
File "train.py", line 94, in train_net
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\pyplot.py", line 2140, in imsave
return matplotlib.image.imsave(fname, arr, **kwargs)
File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\image.py", line 1498, in imsave
_png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array
Does anyone know what could be causing this or how to fix it? Thank you.
The matplotlib package does not understand the PyTorch tensor type. Convert the tensor to a NumPy array with .numpy() first, and then pass it to the matplotlib functions:
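import matplotlib.pyplot as plt
import torch
a = torch.rand(10, 3, 20, 20)  # batch of 10 RGB images, NCHW layout
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...])  # Error: still a torch.Tensor
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])  # Works: NumPy array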
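A slightly more defensive version of that conversion, as a sketch only: the helper name tensor_batch_to_hwc is hypothetical and not part of the repository. It assumes the batch is laid out as N x C x H x W (which is what permute(0, 2, 3, 1) in the question implies), detaches the tensor and moves it to the CPU before calling .numpy(), reorders the axes to N x H x W x C, picks one image, keeps the first three channels like the original [:3] slice, and clips float data to [0, 1], the range matplotlib expects for float RGB images.

```python
import numpy as np


def tensor_batch_to_hwc(batch, index=0, channels=3):
    """Return image `index` of an NCHW batch as an HWC NumPy array.

    `batch` can be a NumPy array or a PyTorch-style tensor exposing
    .detach(), .cpu() and .numpy() (assumption: duck typing, no torch import).
    """
    # matplotlib works on NumPy arrays, so convert tensors first.
    # detach() drops autograd history; cpu() copies GPU data to host memory.
    if hasattr(batch, "detach"):
        batch = batch.detach().cpu().numpy()
    batch = np.asarray(batch)

    # NCHW -> NHWC, the same axis order as permute(0, 2, 3, 1)
    images = np.transpose(batch, (0, 2, 3, 1))

    # Select one image and keep at most the first `channels` channels
    image = images[index, :, :, :channels]

    # matplotlib expects float RGB values in [0, 1], so clip float data;
    # integer (e.g. uint8) images are returned unchanged
    if np.issubdtype(image.dtype, np.floating):
        image = np.clip(image, 0.0, 1.0)
    return image
```

In train.py this could be used as plt.imsave(join(data_dir, 'samples', image), tensor_batch_to_hwc(img_tensor, index)), with plt, join, data_dir, image, index and img_tensor as in the question.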
I managed to fix the code by changing the lines to:
images = img_tensor.cpu().numpy()[0]
images = np.transpose(images, (1, 2, 0))
plt.imsave(join(data_dir, 'samples', image), images)
I'm still not sure what was wrong with the previous version, so if anyone knows, please tell me.
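Both versions produce the same pixel data: indexing image 0 and then transposing CHW to HWC with (1, 2, 0) gives the same array as permuting the whole batch with (0, 2, 3, 1) and then indexing image 0. So the meaningful difference is that the fixed version hands plt.imsave a NumPy ndarray (via .numpy()) instead of a torch.Tensor. The NumPy-only check below illustrates this; the random array is a hypothetical stand-in for img_tensor.

```python
import numpy as np

# Random NCHW batch standing in for img_tensor (hypothetical data)
batch = np.random.rand(4, 3, 8, 8)

# Original approach: reorder the whole batch, then index one image
original = np.transpose(batch, (0, 2, 3, 1))[0]
# Fixed approach: index one image, then reorder CHW -> HWC
fixed = np.transpose(batch[0], (1, 2, 0))

# Same shape and the same pixel values
print(original.shape, fixed.shape)  # (8, 8, 3) (8, 8, 3)
print(np.array_equal(original, fixed))  # True
```

One caveat: .numpy() raises a RuntimeError on a tensor that requires grad, so if img_tensor comes straight out of the network's forward pass, img_tensor.detach().cpu().numpy() is the safer spelling.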