Here is what I want to do. Each data sample has shape (20, 20, 20), and I want to split it into 20 tensors of shape (1, 20, 20), each used as the input to one of 20 separate CNNs. Here's the code I have so far:
class MyModel(torch.nn.Module):
    def __init__(self, ...):
        ...
        self.features = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(10, 14, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(14, 18, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(28*28*18, 256)
        ) for _ in range(20)])
        self.fc_module = nn.Sequential(
            nn.Linear(256*n_selected, cnn_output_dim),
            nn.Softmax(dim=n_classes)
        )

    def forward(self, input_list):
        concat_fusion = cat([cnn(x) for x, cnn in zip(input_list, self.features)], dim=0)
        output = self.fc_module(concat_fusion)
        return output
The shape of input_list in the forward function is torch.Size([100, 20, 20, 20]), where 100 is the batch size. However, this line
    concat_fusion = cat([cnn(x) for x, cnn in zip(input_list, self.features)], dim=0)
raises the following error:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [10, 1, 3, 3], but got 3-dimensional input of size [20, 20, 20] instead
First, I don't understand why it expects a 4-dimensional weight [10, 1, 3, 3]. I've seen the question "RuntimeError: Expected 4-dimensional input for 4-dimensional weight 32 3 3, but got 3-dimensional input of size [3, 224, 224] instead?", but I'm not sure where those specific numbers come from.
input_list is a batch of 100 samples. I'm not sure how to handle each sample of shape (20, 20, 20) so that I can split it into 20 pieces and use each piece as an independent input to one of the 20 CNNs.
Answer:
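The following log means that the nn.Conv2d layer, whose weight has shape (10, 1, 3, 3), i.e. (out_channels, in_channels, kernel_height, kernel_width), requires a 4-dimensional input of shape (batch, channels, height, width):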
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [10, 1, 3, 3]
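To see where the numbers in the message come from, the sketch below builds the first layer of one branch on its own and inspects it, assuming a reasonably recent PyTorch. The weight shape reads as (out_channels, in_channels, kernel_height, kernel_width), a (batch, 1, 20, 20) tensor is accepted, and a single (20, 20, 20) sample is rejected. The exact error text differs between PyTorch versions, so the sketch only checks that a RuntimeError is raised.

```python
import torch
import torch.nn as nn

# First layer of each branch in the question: 1 input channel -> 10 output channels.
conv = nn.Conv2d(1, 10, kernel_size=3, padding=1)

# Weight layout is (out_channels, in_channels, kernel_height, kernel_width).
print(conv.weight.shape)  # torch.Size([10, 1, 3, 3])

# A batched input must be (batch, channels, height, width) with channels == 1.
x = torch.randn(100, 1, 20, 20)
y = conv(x)
print(y.shape)  # torch.Size([100, 10, 20, 20]); padding=1 keeps the 20x20 size

# One whole sample, which is what iterating over the batch hands to each CNN.
failed = False
try:
    conv(torch.randn(20, 20, 20))
except RuntimeError as err:
    # Message wording depends on the PyTorch version.
    failed = True
    print("RuntimeError:", err)
```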
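Iterating over input_list, which has shape (100, 20, 20, 20), walks its first (batch) dimension and yields 100 tensors of shape (20, 20, 20). So zip(input_list, self.features) pairs the first 20 whole samples with the 20 CNNs instead of giving each CNN one slice of every sample, and each CNN receives a 3-dimensional tensor.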
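A small sketch of what the loop in the question actually receives. The integers in branches are stand-ins for the 20 CNNs, an assumption made only to keep the example light.

```python
import torch

input_list = torch.randn(100, 20, 20, 20)  # (batch, 20 slices, height, width)
branches = list(range(20))  # hypothetical stand-ins for the 20 CNNs in self.features

# Iterating over a tensor walks its first dimension, so each x is a whole sample.
pairs = list(zip(input_list, branches))
print(len(pairs))         # 20: zip stops at the shorter iterable
print(pairs[0][0].shape)  # torch.Size([20, 20, 20])
```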
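If you want to split the input along the channel dimension, slice input_list along its second dimension (dim 1). Using i:i+1 rather than i keeps that dimension, so each CNN gets a 4-dimensional tensor of shape (100, 1, 20, 20), and concatenating along dim 1 joins the 20 feature vectors per sample: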
concat_fusion = torch.cat([cnn(input_list[:, i:i+1]) for i, cnn in enumerate(self.features)], dim = 1)
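Putting it together, here is a minimal sketch of the whole model with the slicing fix applied. Two further changes beyond the slicing are needed for it to run: with padding=1 every convolution keeps the 20x20 spatial size, so the flattened size is 18 * 20 * 20 rather than 28 * 28 * 18, and nn.Softmax takes the dimension to normalize over (1 for a (batch, classes) tensor), not the number of classes. The constructor arguments n_branches and n_classes are hypothetical replacements for the elided arguments and for n_selected / cnn_output_dim in the question.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical arguments: n_branches and n_classes stand in for the
    # elided "..." and for n_selected / cnn_output_dim in the question.
    def __init__(self, n_branches=20, n_classes=10):
        super().__init__()
        self.features = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(10, 14, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(14, 18, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            # padding=1 keeps the 20x20 spatial size: 18 * 20 * 20 features
            nn.Linear(18 * 20 * 20, 256),
        ) for _ in range(n_branches)])
        self.fc_module = nn.Sequential(
            nn.Linear(256 * n_branches, n_classes),
            nn.Softmax(dim=1),  # normalize over the class dimension
        )

    def forward(self, input_list):
        # input_list: (batch, n_branches, 20, 20); i:i+1 keeps dim 1, so each
        # branch receives a 4-D tensor of shape (batch, 1, 20, 20)
        branch_outputs = [cnn(input_list[:, i:i+1])
                          for i, cnn in enumerate(self.features)]
        # Concatenate along the feature dimension: (batch, 256 * n_branches)
        concat_fusion = torch.cat(branch_outputs, dim=1)
        return self.fc_module(concat_fusion)


model = MyModel()
out = model(torch.randn(4, 20, 20, 20))
print(out.shape)  # torch.Size([4, 10])
```

torch.split(input_list, 1, dim=1) yields the same 20 slices of shape (batch, 1, 20, 20) if you prefer it to indexing. If you train with nn.CrossEntropyLoss, drop the final nn.Softmax, since that loss expects raw logits.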