Reputation: 776
I have a quick (and possibly silly) question about how TensorFlow defines its Linear layer. In PyTorch, a Linear (or Dense) layer is defined as y = x A^T + b, where A and b are the layer's weight matrix and bias vector (see here).
However, I can't find an equivalent equation for TensorFlow! Is it the same as PyTorch, or is it just y = x A + b?
Thank you in advance!
Upvotes: 26
Views: 47326
Reputation: 17239
If we set activation to None in the Dense layer of the Keras API, then the two are technically equivalent.
For TensorFlow's tf.keras.layers.Dense(..., activation=None), the doc says (more study here):
activation: Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
And PyTorch's source for torch.nn.Linear says it applies a linear transformation to the incoming data: y = x*W^T + b. At this point the two layers are equal.
Here is a more concrete equivalent implementation of the two. In PyTorch, we do
import torch

class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Linear layer: 5 input features -> 30 output features
        self.fc1 = torch.nn.Linear(5, 30)

    def forward(self, state):
        return self.fc1(state)
or,
trd = torch.nn.Linear(in_features = 3, out_features = 30)
y = trd(torch.ones(5, 3))
print(y.size())
# torch.Size([5, 30])
The equivalent TensorFlow implementation would be
import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30, input_shape=(5,), activation=None))
or,
tfd = tf.keras.layers.Dense(30, input_shape=(3,), activation=None)
x = tfd(tf.ones(shape=(5, 3)))
print(x.shape)
# (5, 30)
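To connect this back to the two formulas in the question, here is a short worked statement for the 5 x 3 input used above. It relies on the documented weight layouts: torch.nn.Linear stores its weight with shape (out_features, in_features), while Dense stores its kernel with shape (input_dim, units). The symbols W and K are just labels chosen here for those two parameters.

```latex
\begin{aligned}
\text{PyTorch:}\quad & y = x W^{\top} + b, & W &\in \mathbb{R}^{30 \times 3} \\
\text{Keras:}\quad   & y = x K + b,        & K &\in \mathbb{R}^{3 \times 30} \\
\text{with}\quad     & x \in \mathbb{R}^{5 \times 3},\; b \in \mathbb{R}^{30},\; y \in \mathbb{R}^{5 \times 30}, & K &= W^{\top}
\end{aligned}
```

So both layers compute the same affine map; they differ only in whether the stored matrix is the transpose of the one that multiplies x.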
Upvotes: 27
Reputation: 239
tf.keras.layers.Dense is defined here in the TensorFlow source code:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/core.py#L1081
If you follow the references in its call function, you reach the definition of the operation used, which is indeed a matrix multiplication of the inputs and weights plus a bias vector, as expected:
outputs = gen_math_ops.MatMul(a=inputs, b=kernel)
...
outputs = nn_ops.bias_add(outputs, bias)
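As a rough, framework-free sketch of what those two lines amount to, the snippet below uses plain Python lists in place of tensors. The helper names (matmul, transpose, add_bias, dense_forward, linear_forward) are made up for illustration and are not part of TensorFlow or PyTorch. It assumes the kernel has shape (in_features, units), as in Dense, and compares the result with the PyTorch-style formula y = x W^T + b, where the weight has shape (out_features, in_features).

```python
# Framework-free sketch: nested lists stand in for tensors.
# All helper names here are hypothetical, chosen only for this example.

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def add_bias(rows, bias):
    """Add the bias vector to every row of the output."""
    return [[v + b for v, b in zip(row, bias)] for row in rows]

def dense_forward(inputs, kernel, bias):
    # Keras-style convention: kernel is (in_features, units), no transpose.
    return add_bias(matmul(inputs, kernel), bias)

def linear_forward(inputs, weight, bias):
    # PyTorch-style convention: weight is (out_features, in_features).
    return add_bias(matmul(inputs, transpose(weight)), bias)

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]         # batch of 2, 3 features
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]   # (out=2, in=3)
bias = [0.1, -0.1]
kernel = transpose(weight)                     # (in=3, out=2)

y_keras_style = dense_forward(x, kernel, bias)
y_torch_style = linear_forward(x, weight, bias)
print(y_keras_style == y_torch_style)  # True
```

In other words, copying weights between the two frameworks should only need a transpose of the weight matrix; the bias carries over unchanged.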
Upvotes: 7