Reputation: 456
The default non-linear activation function in the LSTM class is tanh. I wish to use ReLU for my project. Browsing through the documentation and other resources, I'm unable to find a simple way to do this. The only way I could find was to define my own custom LSTMCell, but here the author says that custom LSTMCells don't support GPU acceleration capabilities (or has that changed since the article was published?). I need to use CUDA to speed up my training. Any help would be appreciated.
Upvotes: 8
Views: 14424
Reputation: 37701
The statement that custom LSTMCells don't support GPU acceleration probably means that GPU acceleration becomes limited when you use your own LSTMCell: a hand-written cell still runs on CUDA, but it cannot use the fused cuDNN kernels behind PyTorch's built-in LSTM. You can certainly write your own LSTM implementation, but you have to sacrifice runtime.
For example, I once implemented an LSTM (based on linear layers) as follows. When used as part of a deep neural model, it took 2 to 3 times longer than the LSTM provided in PyTorch.

    import torch
    import torch.nn as nn


    class LSTMCell(nn.Module):
        def __init__(self, input_size, hidden_size, nlayers, dropout):
            """Constructor of the class"""
            super(LSTMCell, self).__init__()
            self.nlayers = nlayers
            self.dropout = nn.Dropout(p=dropout)
            ih, hh = [], []
            for i in range(nlayers):
                # Layers after the first receive the previous layer's hidden state.
                in_size = input_size if i == 0 else hidden_size
                ih.append(nn.Linear(in_size, 4 * hidden_size))
                hh.append(nn.Linear(hidden_size, 4 * hidden_size))
            self.w_ih = nn.ModuleList(ih)
            self.w_hh = nn.ModuleList(hh)

        def forward(self, input, hidden):
            """Defines the forward computation of the LSTMCell"""
            hy, cy = [], []
            for i in range(self.nlayers):
                hx, cx = hidden[0][i], hidden[1][i]
                gates = self.w_ih[i](input) + self.w_hh[i](hx)
                i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
                i_gate = torch.sigmoid(i_gate)
                f_gate = torch.sigmoid(f_gate)
                c_gate = torch.tanh(c_gate)
                o_gate = torch.sigmoid(o_gate)
                ncx = (f_gate * cx) + (i_gate * c_gate)
                nhx = o_gate * torch.tanh(ncx)
                # Collect the new hidden and cell state of this layer.
                hy.append(nhx)
                cy.append(ncx)
                # The (dropped-out) hidden state feeds the next layer.
                input = self.dropout(nhx)
            hy, cy = torch.stack(hy, 0), torch.stack(cy, 0)
            return hy, cy

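Since the question asks for ReLU, the two tanh calls in a cell like the one above are the only places that need to change. The following is a minimal sketch, not a tuned implementation: the class name `ActivationLSTMCell`, the `activation` argument, and the sizes used are illustrative assumptions. It also shows that a hand-written cell can be moved to CUDA with `.to(device)` like any other module. Note that ReLU is unbounded, so the cell state may grow larger than it would with tanh.

```python
import torch
import torch.nn as nn


class ActivationLSTMCell(nn.Module):
    """Single-layer LSTM cell with a swappable activation (hypothetical name)."""

    def __init__(self, input_size, hidden_size, activation=torch.relu):
        super().__init__()
        self.activation = activation  # replaces tanh; an assumption of this sketch
        self.w_ih = nn.Linear(input_size, 4 * hidden_size)
        self.w_hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        hx, cx = state
        gates = self.w_ih(x) + self.w_hh(hx)
        i_gate, f_gate, c_gate, o_gate = gates.chunk(4, dim=1)
        # The three gates stay sigmoid; only the candidate and output use the activation.
        i_gate = torch.sigmoid(i_gate)
        f_gate = torch.sigmoid(f_gate)
        o_gate = torch.sigmoid(o_gate)
        c_gate = self.activation(c_gate)
        ncx = f_gate * cx + i_gate * c_gate
        nhx = o_gate * self.activation(ncx)
        return nhx, ncx


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cell = ActivationLSTMCell(input_size=10, hidden_size=20).to(device)

seq = torch.randn(5, 3, 10, device=device)  # (seq_len, batch, input_size)
h = torch.zeros(3, 20, device=device)
c = torch.zeros(3, 20, device=device)

outputs = []
for x_t in seq:  # step through the sequence one time step at a time
    h, c = cell(x_t, (h, c))
    outputs.append(h)
outputs = torch.stack(outputs, 0)  # (seq_len, batch, hidden_size)
```

Passing `activation=torch.tanh` gives back the standard LSTM equations, which makes it easy to compare the two variants on the same data.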
I would be happy to know if the runtime of a custom LSTM implementation can be improved!
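One option that may be worth trying is compiling a custom cell with TorchScript via `torch.jit.script`, which can let PyTorch fuse some of the elementwise gate operations. The sketch below is an assumption-laden illustration: the class name `ScriptableReLULSTMCell` and the sizes are made up, ReLU is hard-coded so the module stays scriptable, and whether scripting actually speeds anything up depends on the PyTorch version and hardware, so it should be measured rather than assumed.

```python
from typing import Tuple

import torch
import torch.nn as nn


class ScriptableReLULSTMCell(nn.Module):
    """Single-layer ReLU LSTM cell written so TorchScript can compile it (hypothetical name)."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.w_ih = nn.Linear(input_size, 4 * hidden_size)
        self.w_hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(
        self, x: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hx, cx = state
        gates = self.w_ih(x) + self.w_hh(hx)
        i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
        ncx = torch.sigmoid(f_gate) * cx + torch.sigmoid(i_gate) * torch.relu(c_gate)
        nhx = torch.sigmoid(o_gate) * torch.relu(ncx)
        return nhx, ncx


torch.manual_seed(0)
eager_cell = ScriptableReLULSTMCell(8, 16)
scripted_cell = torch.jit.script(eager_cell)  # compiled version sharing the same weights

x = torch.randn(3, 8)
state = (torch.zeros(3, 16), torch.zeros(3, 16))

h_eager, c_eager = eager_cell(x, state)
h_script, c_script = scripted_cell(x, state)
```

The scripted module is a drop-in replacement for the eager one, so the two can be timed side by side inside the same training loop.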
Upvotes: 7