Reputation: 43639
I have:
import torch.nn as nn

class BaselineModel(nn.Module):
    def __init__(self, feature_dim=15, hidden_size=5, num_layers=2):
        super(BaselineModel, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=feature_dim,
                            hidden_size=hidden_size, num_layers=num_layers)
and then I get an error:
RuntimeError: The size of tensor a (5) must match the size of tensor b (15) at non-singleton dimension 2
If I set the two sizes to be the same, the error goes away. But if my input_size is some large number, say 15, and I want to reduce the number of hidden features to 5, why shouldn't that work?
Upvotes: 6
Views: 2319
Reputation: 3234
The short answer is: yes, input_size can be different from hidden_size.
For a more detailed answer, take a look at the LSTM formulae in the PyTorch documentation, for instance:
i_t = σ(W_ii x_t + b_ii + W_hi h_(t-1) + b_hi)
This is the formula to compute i_t, the input gate activation at the t-th time step for one layer. Here the matrix W_ii has the shape (hidden_size x input_size). Similarly, in the other formulae the matrices W_if, W_ig, and W_io all have the same shape. These matrices project the input tensor into the same space as the hidden states, so that the two terms can be added together. (This holds for the first layer; in deeper layers the input is the previous layer's hidden state, so input_size is replaced by hidden_size.)
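To make the shapes concrete, here is a minimal sketch using the question's configuration (15 input features, hidden size 5, two layers). It prints the stacked weight shapes PyTorch stores (weight_ih_l0 holds W_ii|W_if|W_ig|W_io stacked row-wise, hence 4 x hidden_size rows), then recomputes one time step by hand from the gate formulae and compares it with nn.LSTM. The helper manual_step is a hypothetical name for illustration only, and the random input is just for the check.

```python
import torch
import torch.nn as nn

feature_dim, hidden_size, num_layers, batch_size = 15, 5, 2, 3
torch.manual_seed(0)

# Same configuration as the question: input_size (15) != hidden_size (5).
lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size, num_layers=num_layers)

# The four input matrices are stacked row-wise, hence 4 * hidden_size = 20 rows.
print(lstm.weight_ih_l0.shape)  # torch.Size([20, 15]): layer 1 reads the 15 input features
print(lstm.weight_hh_l0.shape)  # torch.Size([20, 5]): recurrent weights read h_(t-1)
print(lstm.weight_ih_l1.shape)  # torch.Size([20, 5]): layer 2 reads layer 1's hidden state


def manual_step(x_t, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical helper: one LSTM time step for one layer, from the gate formulae."""
    # W_i* x_t maps input_size -> hidden_size, so it can be added to W_h* h_(t-1).
    gates = x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh
    i, f, g, o = gates.chunk(4, dim=-1)  # PyTorch's gate order: i, f, g, o
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_t = f * c_prev + i * g
    h_t = o * torch.tanh(c_t)
    return h_t, c_t


x = torch.rand(1, batch_size, feature_dim)  # (seq_len=1, batch, input_size)
with torch.no_grad():
    out, (h_n, c_n) = lstm(x)  # initial h and c default to zeros
    h0 = torch.zeros(batch_size, hidden_size)
    c0 = torch.zeros(batch_size, hidden_size)
    # Layer 1 consumes the 15-wide input.
    h1, _ = manual_step(x[0], h0, c0, lstm.weight_ih_l0, lstm.weight_hh_l0,
                        lstm.bias_ih_l0, lstm.bias_hh_l0)
    # Layer 2 consumes layer 1's 5-wide hidden state.
    h2, _ = manual_step(h1, h0, c0, lstm.weight_ih_l1, lstm.weight_hh_l1,
                        lstm.bias_ih_l1, lstm.bias_hh_l1)

ok_layer1 = torch.allclose(h1, h_n[0], atol=1e-5)  # layer 1's final hidden state
ok_layer2 = torch.allclose(h2, out[0], atol=1e-5)  # top layer's output at t=0
print(ok_layer1, ok_layer2)  # True True
```

Nothing in the computation requires input_size == hidden_size: the input term and the recurrent term both end up hidden_size wide before they are added.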
Back to your specific problem: as the other answer pointed out, the error probably comes from another part of your code. Without seeing your forward implementation, it's hard to say exactly what the problem is.
Upvotes: 2
Reputation: 895
It should work; the error probably comes from elsewhere. This works, for example:
import numpy as np
import torch
import torch.nn as nn
feature_dim = 15
hidden_size = 5
num_layers = 2
seq_len = 5
batch_size = 3
lstm = nn.LSTM(input_size=feature_dim,
               hidden_size=hidden_size, num_layers=num_layers)
# Input shape is (seq_len, batch, input_size) because batch_first=False by default
t1 = torch.from_numpy(np.random.uniform(0, 1, size=(seq_len, batch_size, feature_dim))).float()
output, states = lstm(t1)  # call the module itself rather than .forward()
hidden_state, cell_state = states
print("output: ", output.size())
print("hidden_state: ", hidden_state.size())
print("cell_state: ", cell_state.size())
and it prints:
output: torch.Size([5, 3, 5])
hidden_state: torch.Size([2, 3, 5])
cell_state: torch.Size([2, 3, 5])
Are you using the output somewhere after the LSTM? Note that its last dimension equals hidden_size, i.e. 5. It looks like you're using it afterwards as if it had size 15 instead.
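The question doesn't show forward, so this is only a guess at how such an error can arise: a minimal sketch assuming the LSTM output is combined elementwise with the original input (for example a residual connection, which is hypothetical here). That reproduces the same kind of mismatch on dimension 2 (5 vs 15). The extra nn.Linear projection is likewise an assumption about what the model needs, shown as one possible fix rather than the only one.

```python
import torch
import torch.nn as nn

seq_len, batch_size, feature_dim, hidden_size = 5, 3, 15, 5
x = torch.rand(seq_len, batch_size, feature_dim)
lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size, num_layers=2)
out, _ = lstm(x)  # out: (5, 3, 5), last dim is hidden_size, not feature_dim

# Hypothetical misuse: treating `out` as if it still had 15 features.
try:
    out + x
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# One possible fix (assumed): project the hidden state back up to feature_dim
# with a hypothetical extra linear layer before combining it with the input.
proj = nn.Linear(hidden_size, feature_dim)
fixed = proj(out) + x
print(fixed.shape)  # torch.Size([5, 3, 15])
```

Alternatively, if the later code only needs the 5 hidden features, the fix is to size the downstream layers to hidden_size instead.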
Upvotes: 3