Reza

Reputation: 130

Pytorch RNN with no nonlinearity

Is it possible to implement an RNN layer with no nonlinearity in PyTorch, like in Keras where one can set the activation to linear? By removing the nonlinearity, I want to implement a first-order infinite-impulse-response (IIR) filter with a differentiable parameter and integrate it into my model for end-to-end learning. I can obviously implement the filter in PyTorch myself, but I thought using a built-in module might be more efficient.
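For concreteness, the kind of filter I mean is the first-order recursion y[t] = x[t] + a·y[t−1] with a learnable coefficient a. A minimal sketch of a manual implementation (the class name, initial value, and exact recursion here are just illustrative):

    import torch
    import torch.nn as nn

    class FirstOrderIIR(nn.Module):
        """First-order IIR filter y[t] = x[t] + a * y[t-1] with a learnable coefficient a."""
        def __init__(self):
            super().__init__()
            self.a = nn.Parameter(torch.tensor(0.5))  # differentiable filter coefficient

        def forward(self, x):
            # x: (batch, time)
            y_prev = torch.zeros(x.shape[0], device=x.device)
            outputs = []
            for t in range(x.shape[1]):
                y_prev = x[:, t] + self.a * y_prev  # the IIR recursion
                outputs.append(y_prev)
            return torch.stack(outputs, dim=1)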

Upvotes: 1

Views: 2522

Answers (3)

iacob

Reputation: 24351

No, the PyTorch nn.RNN module only supports tanh or ReLU:

nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'

You could, however, implement this yourself by writing your own for loop over the sequence, as in this example.
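A minimal sketch of such a loop, assuming shared weights, an identity (i.e. no) activation, and batch-first input of shape (batch, time, features); the class and variable names are illustrative:

    import torch
    import torch.nn as nn

    class LinearRNN(nn.Module):
        """RNN unrolled by hand with no non-linearity (identity activation)."""
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.U = nn.Linear(input_size, hidden_size)                # input-to-hidden
            self.W = nn.Linear(hidden_size, hidden_size, bias=False)  # hidden-to-hidden

        def forward(self, x):
            # x: (batch, time, input_size)
            h = torch.zeros(x.shape[0], self.W.in_features, device=x.device)
            hs = []
            for t in range(x.shape[1]):
                h = self.U(x[:, t]) + self.W(h)  # h(t) = U.x(t) + W.h(t-1) + b, no tanh
                hs.append(h)
            return torch.stack(hs, dim=1), h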

Upvotes: 1

Akshay Sehgal

Reputation: 19307

Removing the non-linearity from an RNN turns it into a linear dense layer without any activation.

If that is what you want, then simply use nn.Linear, which applies no activation at all (it is just a learned affine map Wx + b).
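For example (the sizes here are illustrative):

    import torch.nn as nn

    # nn.Linear computes y = x @ W.T + b with no activation,
    # so there is nothing to "turn off".
    layer = nn.Linear(in_features=16, out_features=32)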

Explanation

Here is why this happens. Fundamentally, an RNN over timesteps works as below:


h(t) = tanh(U.x(t) + W.h(t−1) + b)

h(0) = tanh(U.x(0) + b)
h(1) = tanh(U.x(1) + W.h(0) + b)
h(2) = tanh(U.x(2) + W.h(1) + b)

#... and so on.

If you remove the non-linearity by dropping the tanh, here is what happens (note that U, W and b are shared across timesteps):

h(0) = U.x(0) + b
h(1) = U.x(1) + W.h(0) + b
     = U.x(1) + W.(U.x(0) + b) + b    # expanding h(0)
     = U.x(1) + W.U.x(0) + W.b + b
     = V1.x(1) + V0.x(0) + C          # can be rewritten with new weights
     = V . x + C                      # general affine form over the stacked inputs

So after 2 timesteps the state of such an RNN is simply of the form V.x + C, i.e. an affine map of the inputs, just like a linear layer without activation.

In other words, removing the non-linearity from an RNN turns it into a linear dense layer without any activation, completely removing the notion of time-steps.
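A quick numerical check of this derivation (with arbitrary sizes, assuming shared weights U, W and bias b as above):

    import torch

    torch.manual_seed(0)
    in_size, hidden = 4, 3
    U = torch.randn(hidden, in_size)
    W = torch.randn(hidden, hidden)
    b = torch.randn(hidden)
    x0, x1 = torch.randn(in_size), torch.randn(in_size)

    # Linear recursion (no tanh)
    h0 = U @ x0 + b
    h1 = U @ x1 + W @ h0 + b

    # Collapsed affine form: V1 = U, V0 = W.U, C = W.b + b
    h1_affine = U @ x1 + (W @ U) @ x0 + (W @ b + b)
    print(torch.allclose(h1, h1_affine))  # True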

Upvotes: 2

Theodor Peifer

Reputation: 3506

I don't think so: you can choose between tanh and relu, but it has to be one of them when using nn.RNN, as far as I know (and I don't think there is a workaround). But you could implement the RNN yourself quite easily without using the built-in module and then use whatever activation you want. They show an example of that in this PyTorch tutorial.

Upvotes: 2
