An Ignorant Wanderer

Reputation: 1612

Testing an implementation of an LSTM in Pytorch

I'm trying to use the PyTorch implementation of an LSTM from here, which I'm including below for reference. It consists of two classes, LSTMCell and LSTM, where LSTMCell is a single unit and LSTM stacks multiple units together to build a full LSTM model.

import math
import torch as th
import torch.nn as nn

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        if hidden is None:
            hidden = self._init_hidden(x)

        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)

        # Linear mappings
        preact = self.i2h(x) + self.h2h(h)

        # activations
        gates = preact[:, :3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size:].tanh()
        i_t = gates[:, :self.hidden_size]
        f_t = gates[:, self.hidden_size:2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size:]

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

        h_t = th.mul(o_t, c_t.tanh())

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)

    @staticmethod
    def _init_hidden(input_):
        h = th.zeros_like(input_.view(1, input_.size(1), -1))
        c = th.zeros_like(input_.view(1, input_.size(1), -1))
        return h, c

class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

    def forward(self, input_, hidden=None):
        # input_ is of dimensionality (1, time, input_size, ...)

        outputs = []
        for x in th.unbind(input_, dim=1):
            hidden = self.lstm_cell(x, hidden)
            outputs.append(hidden[0].clone())

        return th.stack(outputs, dim=1)

I'm doing the following simple test:

x = torch.randn(1, 3, 2, 4)
model = LSTM(4, 5, False)
model(x)

and I get the following error. What exactly is the problem here?

TypeError                                 Traceback (most recent call last)
<ipython-input-33-09e5544a61fc> in <module>
----> 1 model = LSTM(4, 5, False)

<ipython-input-30-9ad06cd4b768> in __init__(self, input_size, hidden_size, bias)
      3     def __init__(self, input_size, hidden_size, bias=True):
      4         super().__init__()
----> 5         self.lstm_cell = LSTMCell(input_size, hidden_size, bias)
      6 
      7     def forward(self, input_, hidden=None):

<ipython-input-29-c91ddfb9dfae> in __init__(self, input_size, hidden_size, bias)
      6 
      7     def __init__(self, input_size, hidden_size, bias=True):
----> 8         super(LSTM, self).__init__()
      9         self.input_size = input_size
     10         self.hidden_size = hidden_size

TypeError: super(type, obj): obj must be an instance or subtype of type

Upvotes: 1

Views: 714

Answers (1)

Michael Jungo

Reputation: 32972

The first argument to super() should be the class itself, not a different class.

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM, self).__init__()
#             ^^^^ self is not an instance of LSTM but LSTMCell
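The same error can be reproduced in isolation with two unrelated classes (a minimal sketch, separate from the posted code):

class A:
    pass

class B:
    def __init__(self):
        # A is not a base class of B, so self (a B instance) is rejected
        super(A, self).__init__()

B()
# TypeError: super(type, obj): obj must be an instance or subtype of type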

In LSTMCell it should therefore be:

super(LSTMCell, self).__init__()

Since Python 3, you can omit the arguments to super() entirely and get the same result (as you already do in the LSTM class):

super().__init__()
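
Applied to the code in the question, the corrected constructor would then look like this; only the super() call changes, everything else stays as posted:

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        # no arguments needed: super() resolves to LSTMCell automatically
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

With that change, LSTM(4, 5, False) constructs without the TypeError; whether the forward pass then behaves as intended with your test input is a separate question.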

Upvotes: 1
