I have a Torch tensor z, and I would like to apply a transformation, mat, to z and have the output be exactly the same size as z. Here is the code I am running:
import torch.nn.functional as nnf
from torch.autograd import Variable
from torchvision import transforms

def trans(z):
    print(z)
    mat = transforms.Compose([transforms.ToPILImage(), transforms.RandomRotation(90), transforms.ToTensor()])
    z = Variable(mat(z.cpu()).cuda())
    z = nnf.interpolate(z, size=(28, 28), mode='linear', align_corners=False)
    return z

z = trans(z)
However, I get this error:
RuntimeError Traceback (most recent call last)
<ipython-input-12-e2fc36889ba5> in <module>()
3 inputs,targs=next(iter(tst_loader))
4 recon, mean, var = vae.predict(model, inputs[img_idx])
----> 5 out = vae.generate(model, mean, var)
/content/vae.py in generate(model, mean, var)
90 z = trans(z)
91 z = Variable(z.cpu().cuda())
---> 92 out = model.decode(z)
93 return out.data.cpu()
94
/content/vae.py in decode(self, z)
56
57 def decode(self, z):
---> 58 out = self.z_develop(z)
59 out = out.view(z.size(0), 64, self.z_dim, self.z_dim)
60 out = self.decoder(out)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py in forward(self, input)
89
90 def forward(self, input: Tensor) -> Tensor:
---> 91 return F.linear(input, self.weight, self.bias)
92
93 def extra_repr(self) -> str:
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
1674 ret = torch.addmm(bias, input, weight.t())
1675 else:
-> 1676 output = input.matmul(weight.t())
1677 if bias is not None:
1678 output += bias
RuntimeError: mat1 dim 1 must match mat2 dim 0
How can I successfully apply this rotation transform mat and not get any errors doing so?
Thanks, Vinny
The problem is that interpolate expects a leading batch dimension (input of shape (N, C, ...)), and it looks like your data does not have one, based on the error message and the fact that transforms, which work on single (C, H, W) images, were applied successfully. Since your input is spatial (based on size=(28, 28)), you can fix it by adding the batch dimension and changing the mode to bilinear, since linear is only implemented for 3D (N, C, L) input, not for spatial input:
z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False)
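A minimal sketch of the shape handling, assuming a CPU tensor and a hypothetical single-channel 32x32 input standing in for z (the question does not give its real shape). It shows that the unsqueezed 4D tensor resizes cleanly with bilinear, while mode='linear' is rejected for the same 4D input:

```python
import torch
import torch.nn.functional as nnf

# Hypothetical stand-in for the question's tensor: one 1-channel 32x32 image, shape (C, H, W).
z = torch.rand(1, 32, 32)

# interpolate reads dim 0 as the batch and dim 1 as channels, so a 3D tensor
# is treated as (N, C, L) with only one spatial dimension.
batched = z.unsqueeze(0)  # (1, 1, 32, 32) -> (N, C, H, W)
out = nnf.interpolate(batched, size=(28, 28), mode='bilinear', align_corners=False)
print(tuple(batched.shape), tuple(out.shape))  # (1, 1, 32, 32) (1, 1, 28, 28)

# 'linear' only handles 3D (N, C, L) input, so it is refused for a 4D batch.
linear_failed = False
try:
    nnf.interpolate(batched, size=(28, 28), mode='linear', align_corners=False)
except (NotImplementedError, ValueError) as err:
    linear_failed = True
    print(type(err).__name__)
```

The sketch catches both NotImplementedError and ValueError as a precaution, since the exact exception type may vary between PyTorch versions.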
If you want z to still have a shape like (C, H, W), then:
z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False).squeeze(0)
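If the same resize is applied to both single images and batches, a small wrapper can add and remove the batch dimension only when it is missing. This is a sketch: resize_to is a hypothetical helper name, not a PyTorch function, and it assumes 3D inputs are (C, H, W) and 4D inputs are (N, C, H, W):

```python
import torch
import torch.nn.functional as nnf


def resize_to(z: torch.Tensor, size=(28, 28)) -> torch.Tensor:
    """Resize a (C, H, W) or (N, C, H, W) tensor and return the same rank it was given.

    resize_to is a hypothetical helper for illustration, not part of PyTorch.
    """
    unbatched = z.dim() == 3
    if unbatched:
        z = z.unsqueeze(0)  # add the batch dimension interpolate expects
    out = nnf.interpolate(z, size=size, mode='bilinear', align_corners=False)
    # Drop only the batch dimension we added; a real batch of size 1 is left alone.
    return out.squeeze(0) if unbatched else out


single = resize_to(torch.rand(3, 40, 40))
batch = resize_to(torch.rand(8, 3, 40, 40))
print(tuple(single.shape), tuple(batch.shape))  # (3, 28, 28) (8, 3, 28, 28)
```

It uses squeeze(0) rather than a bare squeeze() on purpose: squeeze() would also drop a channel dimension of size 1, turning a (1, 28, 28) grayscale image into (28, 28).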