Vinny Jacobsen

Reputation: 137

How can I apply a transformation to a torch tensor?

I have a torch tensor z, and I would like to apply a transformation mat to it so that the output is exactly the same size as z. Here is the code I am running:

def trans(z):
    print(z)
    mat = transforms.Compose([transforms.ToPILImage(),transforms.RandomRotation(90),transforms.ToTensor()])
    z = Variable(mat(z.cpu()).cuda())
    z = nnf.interpolate(z, size=(28, 28), mode='linear', align_corners=False)
    return z
z = trans(z)

However, I get this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-12-e2fc36889ba5> in <module>()
      3 inputs,targs=next(iter(tst_loader))
      4 recon, mean, var = vae.predict(model, inputs[img_idx])
----> 5 out = vae.generate(model, mean, var)

4 frames
/content/vae.py in generate(model, mean, var)
     90     z = trans(z)
     91     z = Variable(z.cpu().cuda())
---> 92     out = model.decode(z)
     93     return out.data.cpu()
     94 

/content/vae.py in decode(self, z)
     56 
     57     def decode(self, z):
---> 58         out = self.z_develop(z)
     59         out = out.view(z.size(0), 64, self.z_dim, self.z_dim)
     60         out = self.decoder(out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py in forward(self, input)
     89 
     90     def forward(self, input: Tensor) -> Tensor:
---> 91         return F.linear(input, self.weight, self.bias)
     92 
     93     def extra_repr(self) -> str:

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1674         ret = torch.addmm(bias, input, weight.t())
   1675     else:
-> 1676         output = input.matmul(weight.t())
   1677         if bias is not None:
   1678             output += bias

RuntimeError: mat1 dim 1 must match mat2 dim 0

How can I successfully apply this rotation transform mat and not get any errors doing so?

Thanks, Vinny

Upvotes: 4

Views: 6077

Answers (1)

Berriel

Reputation: 13601

The problem is that interpolate expects a batch dimension, and it looks like your data does not have one, judging by the error message and the fact that the transforms were applied successfully. Since your input is spatial (given size=(28, 28)), you can fix this by adding the batch dimension and changing the mode, because linear is not implemented for spatial input:

z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False)

If you want z to keep a shape like (C, H, W), remove the batch dimension again afterwards:

z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False).squeeze(0)
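
To check the shapes at each step, here is a minimal sketch. It assumes a single-channel 32x32 tensor as a hypothetical stand-in for the rotated image (the real z comes from the question's transform pipeline), and the names batched, resized and restored are only illustrative:

```python
import torch
import torch.nn.functional as nnf

# Hypothetical stand-in for the rotated image: one channel, 32x32 pixels, laid out as (C, H, W).
z = torch.rand(1, 32, 32)

# interpolate reads dim 0 as the batch and dim 1 as the channels,
# so a (C, H, W) tensor needs an explicit batch dimension first.
batched = z.unsqueeze(0)  # (1, C, H, W)

# 'bilinear' is the mode for 4-D (batched spatial) input.
resized = nnf.interpolate(batched, size=(28, 28), mode='bilinear', align_corners=False)

# Drop the batch dimension again to get back to (C, H, W).
restored = resized.squeeze(0)

print(batched.shape, resized.shape, restored.shape)
```

Whether the decoder then accepts the result depends on the input shape that model.decode expects, which is not shown in the question.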

Upvotes: 2
