Chiam Yuwei

Reputation: 21

How to set the input channels size?

This is an assignment that uses PyTorch for hand gesture recognition. Code:

X = np.array(Images).astype('float32')
Y = np.array(Labels).astype(int)

for i in tqdm(range(X.shape[0])):
    train_data.append(X[i]) # original image
    train_data.append(rotate(X[i], angle = 45, mode = 'wrap')) 
    train_data.append(np.fliplr(X[i])) 
    train_data.append(np.flipud(X[i]))
    train_data.append(random_noise(X[i], var = 0.2 ** 2)) 
    
    for j in range(5):
        target_train.append(Y[i]) 

class Net(nn.Module):
    
    def __init__(self, num_classes = 5):
        
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 12, kernel_size = 3, stride = 1, padding = 1)
        self.conv2 = nn.Conv2d(in_channels = 12, out_channels = 24, kernel_size = 3, stride = 1, padding = 1)
        self.pool = nn.MaxPool2d(kernel_size = 2)
        self.drop = nn.Dropout2d(p = 0.2)
        self.fc1 = nn.Linear(in_features = 19.5 * 19.5 * 24, out_features = 120)
        self.fc2 = nn.Linear(in_features = 120, out_features = num_classes)
    
    def forward(self, x):
        
        x = F.relu(self.pool(self.conv1(x)))
        x = F.relu(self.pool(self.conv2(x)))
        x = F.dropout(self.drop(x), training = self.training)
        x = x.view(-1, 19.5 * 19.5 * 24)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

Error:

RuntimeError: Given groups=1, weight of size [8, 1, 3, 3], expected input[1, 32, 78, 78] to have 1 channels, but got 32 channels instead

Size of X: (2080, 300, 300, 3)

Size of y: (2080,)

How do I set the input size for fc1 (fully connected layer 1)?

Upvotes: 0

Views: 244

Answers (1)

shivarama23

Reputation: 242

PyTorch expects image input in the format [batch_size, channels, height, width], so your input should have shape (2080, 1, 300, 300) instead of (2080, 300, 300, 3). Your first layer, conv1, is defined with in_channels = 1, so the network expects single-channel input, not 3-channel input.
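As a rough sketch of that reshaping, the snippet below converts a channels-last RGB array to a single-channel, channels-first array with NumPy. The helper name to_nchw_grayscale and the plain channel average used for the grayscale conversion are assumptions made for illustration. A weighted luminance conversion, or keeping all three channels and setting in_channels = 3 instead, would be equally valid choices.

```python
import numpy as np

def to_nchw_grayscale(images_nhwc):
    """Convert (N, H, W, 3) images to (N, 1, H, W).

    Assumption: grayscale is taken as the plain mean of the colour channels.
    """
    gray = images_nhwc.mean(axis=-1, keepdims=True)             # (N, H, W, 1)
    return np.transpose(gray, (0, 3, 1, 2)).astype('float32')   # (N, 1, H, W)

# Small stand-in batch with the same layout as X in the question
X_small = np.random.rand(4, 300, 300, 3).astype('float32')
X_gray = to_nchw_grayscale(X_small)
print(X_gray.shape)  # (4, 1, 300, 300)
```

The resulting array can then be turned into a tensor with torch.from_numpy(X_gray) before being passed to the model.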

Also,

x = x.view(-1, 19.5 * 19.5 * 24)

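will throw an error if the total number of elements in x is not divisible by 19.5 * 19.5 * 24 (i.e. 9126). The size passed to view, and in_features of fc1, must also be an integer, whereas 19.5 * 19.5 * 24 evaluates to the float 9126.0.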
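One way to get an integer size is to trace the spatial dimensions through the layers instead of hard-coding them. The sketch below assumes a 300 x 300 input and the layer settings from the question (3 x 3 convolutions with stride 1 and padding 1, 2 x 2 max pooling with its default stride). The helper names conv_out and pool_out are made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    # Max pooling with stride equal to the kernel size and no padding
    return size // kernel

side = 300                                        # assumed input height/width
side = conv_out(side, kernel=3, stride=1, padding=1)  # conv1 keeps 300
side = pool_out(side, kernel=2)                        # pool -> 150
side = conv_out(side, kernel=3, stride=1, padding=1)  # conv2 keeps 150
side = pool_out(side, kernel=2)                        # pool -> 75

flat_features = side * side * 24                  # 24 = out_channels of conv2
print(side, flat_features)  # 75 135000
```

Under these assumptions, fc1 would take in_features = 75 * 75 * 24 = 135000, and the flattening step could be written as x = x.view(x.size(0), -1) so that the batch dimension is preserved regardless of the feature count.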

Fixing these two things should solve the problem.

Upvotes: 0
