senek

Reputation: 73

PyTorch: how to use a linear activation function

In Keras, I can create any network layer with a linear activation function as follows (a fully-connected layer is used as an example):

model.add(keras.layers.Dense(outs, input_shape=(160,), activation='linear'))

But I can't find the linear activation function in the PyTorch documentation. ReLU is not suitable, because there are negative values in my sample. How do I create a layer with a linear activation function in PyTorch?

Upvotes: 0

Views: 19332

Answers (4)

Yonas Kassa

Reputation: 3690

As already answered, you don't need a linear activation layer in PyTorch. But if you want to include one explicitly, you can write a custom module that simply passes its input through, as follows.

import torch

class linear(torch.nn.Module):
    # identity "activation": returns the input unchanged (y = x)
    def forward(self, output):
        return output

Then you can call it like any other activation function.

linear()(torch.tensor([1, 2, 3])) == torch.nn.ReLU()(torch.tensor([1, 2, 3]))  # tensor([True, True, True]); they only agree for non-negative inputs
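For illustration, a minimal sketch of how such a passthrough module could slot into a model where an activation is expected (the layer sizes here are made up; recent PyTorch versions also provide torch.nn.Identity, which behaves the same way):

import torch
import torch.nn as nn

# Hypothetical model: the layer sizes (160 -> 64 -> 10) are made up for illustration.
model = nn.Sequential(
    nn.Linear(160, 64),
    linear(),            # the custom passthrough defined above (nn.Identity() would also work)
    nn.Linear(64, 10),
)

x = torch.randn(8, 160)  # dummy batch of 8 samples
print(model(x).shape)    # torch.Size([8, 10])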

Upvotes: 2

momo668

Reputation: 433

If I understand correctly, you are doing this because you are plugging it into some module you built that takes an 'activation' as an argument, but now you don't want any activation. You can just pass a linear layer like:

nn.Linear(1,1)

which actually adds learnable parameters (a weight and a bias), so it is not the same as the passthrough above, but it works too. Note that nn.Linear(1, 1) expects the last dimension of its input to be 1.
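As a rough sketch of that situation (the MyBlock class and its sizes are hypothetical, not from the question), "no activation" can also be expressed by passing nn.Identity():

import torch
import torch.nn as nn

class MyBlock(nn.Module):
    # Hypothetical block that takes an activation module as an argument.
    def __init__(self, in_features, out_features, activation):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.activation = activation

    def forward(self, x):
        return self.activation(self.fc(x))

# nn.Identity() acts as the "no activation" choice here.
block = MyBlock(160, 10, activation=nn.Identity())
print(block(torch.randn(4, 160)).shape)  # torch.Size([4, 10])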

Upvotes: 0

Lior Cohen

Reputation: 5745

activation='linear' is equivalent to no activation at all.

As can be seen here, it is also called "passthrough", meaning that it does nothing.

So in PyTorch you can simply apply no activation at all, and you are in parity with Keras.

However, as already noted by @Minsky, a hidden layer without a real (i.e. non-linear) activation is useless. It amounts to just changing the weights, which is anyway done during network training.
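A minimal sketch illustrating that point (the layer sizes are arbitrary): two stacked linear layers with no activation in between collapse to a single linear layer.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two linear layers with no activation in between (sizes are arbitrary).
fc1 = nn.Linear(160, 64)
fc2 = nn.Linear(64, 10)

x = torch.randn(8, 160)
out_stacked = fc2(fc1(x))

# The same mapping as one linear layer: W = W2 @ W1, b = W2 @ b1 + b2.
combined = nn.Linear(160, 10)
with torch.no_grad():
    combined.weight.copy_(fc2.weight @ fc1.weight)
    combined.bias.copy_(fc2.weight @ fc1.bias + fc2.bias)

print(torch.allclose(out_stacked, combined(x), atol=1e-5))  # True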

Upvotes: 1

Ivan

Reputation: 40648

If you take a look at the Keras documentation, you will see that tf.keras.layers.Dense's activation='linear' corresponds to the identity function a(x) = x, which means no non-linearity.

So in PyTorch, you just define the linear layer without adding any activation:

torch.nn.Linear(160, outs)
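For instance, a rough PyTorch equivalent of the Keras line from the question could look like this (outs is whatever output size you need; 10 is used here purely as a placeholder):

import torch
import torch.nn as nn

outs = 10  # placeholder for the output size from the question

# Equivalent of keras.layers.Dense(outs, input_shape=(160,), activation='linear')
model = nn.Sequential(
    nn.Linear(160, outs),  # no activation afterwards == 'linear' in Keras
)

x = torch.randn(32, 160)   # dummy batch
print(model(x).shape)      # torch.Size([32, 10])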

Upvotes: 3
