For a nice output in Tensorboard, I want to show a batch of input images, the corresponding target masks and the output masks in a grid. The input images have a different size than the masks. Also, the images are RGB, while the masks have a single (grayscale) channel. From a batch of, e.g., 32 or 64, I only want to show the first 4 images.
After some fiddling around I came up with the following example code. The good thing: it works. But I am really not sure whether I missed something in PyTorch, because it looks much longer than I expected. The upsampling and the conversion to RGB in particular seem clumsy. However, the other transforms I found would not work on a whole batch.
import time
import torch
import torch.nn.functional as FN
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
batch = 32
i_size = 192  # input image height/width (square)
o_size = 112  # mask height/width (square)
nr_imgs = 4   # number of samples to show
# Tensorboard init
writer = SummaryWriter('runs/' + time.strftime('%Y%m%d_%H%M%S'))
# dummy data; Variable is no longer needed since PyTorch 0.4, plain tensors work
input_image = torch.rand(batch, 3, i_size, i_size)
target_mask = torch.rand(batch, o_size, o_size)
output_mask = torch.rand(batch, o_size, o_size)
# upsample target_mask; [:nr_imgs, None] adds the channel dim -> (N, 1, H, W)
# (FN.interpolate is the current name for the deprecated FN.upsample)
tm = FN.interpolate(target_mask[:nr_imgs, None], size=(i_size, i_size), mode='bilinear', align_corners=False)
tm = torch.cat((tm, tm, tm), dim=1)  # grayscale plane to RGB
# upsample output_mask the same way
om = FN.interpolate(output_mask[:nr_imgs, None], size=(i_size, i_size), mode='bilinear', align_corners=False)
om = torch.cat((om, om, om), dim=1)  # grayscale plane to RGB
# stack all images (no .data needed) and make a grid with one row per group
imgs = torch.cat((input_image[:nr_imgs], tm, om))
x = vutils.make_grid(imgs, nrow=nr_imgs, normalize=True, scale_each=True)
# Tensorboard img output
writer.add_image('Image', x, 0)
writer.close()
EDIT: I found this on PyTorch's issues list. It's about batch support for transforms. It seems there are no plans to add batch transforms in the future, so my current code might be the best solution for the time being anyway?
Maybe you can just convert your tensors to NumPy arrays (.data.cpu().numpy()) and use OpenCV to do the upsampling? OpenCV's implementation should be quite fast.