What is the counterpart in TensorFlow to the PyTorch forward function?
I am trying to translate some PyTorch code to TensorFlow.
The `forward` method of `nn.Module` in PyTorch corresponds to `__call__()` in `tf.Module` in TensorFlow, or to `call()` in `tf.keras.layers.Layer` in Keras. In both frameworks you invoke the instance directly (`layer(x)`), just as you write `model(x)` rather than `model.forward(x)` in PyTorch. Here is an example of a simple dense layer written both ways:
TensorFlow:
```python
import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, input_dim, output_size, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([input_dim, output_size]), name='w')
        self.b = tf.Variable(tf.zeros([output_size]), name='b')

    def __call__(self, x):
        # Plays the role of forward() in a PyTorch nn.Module
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
```
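When porting a larger PyTorch codebase it can be convenient to keep the `forward` name on every module. A minimal sketch, assuming a small hypothetical base class (`ForwardModule` and `Scale` are illustrative names, not TensorFlow APIs): `__call__` simply delegates to `forward`, so call sites such as `layer(x)` stay unchanged.

```python
import tensorflow as tf

class ForwardModule(tf.Module):
    """Hypothetical base class: routes instance calls to a PyTorch-style forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

class Scale(ForwardModule):
    """Hypothetical example module: multiplies its input by one trainable scalar."""

    def __init__(self, factor, name=None):
        super().__init__(name=name)
        self.factor = tf.Variable(factor, dtype=tf.float32, name='factor')

    def forward(self, x):
        return x * self.factor

layer = Scale(2.0)
out = layer(tf.constant([[1.0, 2.0]]))  # goes through __call__ -> forward
```

Because `Scale` is still a `tf.Module`, its variables are tracked as usual and are available through `layer.trainable_variables`.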
Keras:
```python
import tensorflow as tf

class Dense(tf.keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the input shape is known
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs):
        # Equivalent of PyTorch's forward(); invoked by Layer.__call__
        return tf.matmul(inputs, self.w) + self.b
```
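Why override `call()` rather than `__call__()` in Keras? The base class owns `__call__`, does some setup (such as building the weights on the first call), and then delegates to your `call()`. The following is a heavily simplified, hypothetical pure-Python model of that pattern (`MiniLayer` and `OnesDense` are illustrative names, not Keras classes), with fixed all-ones weights so the output is easy to check.

```python
class MiniLayer:
    """Hypothetical, simplified stand-in for a Keras-style base layer."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Subclasses create their weights here once the input shape is known.
        pass

    def __call__(self, inputs):
        # Framework-owned entry point: build lazily on the first call only,
        # then delegate to the user-defined call().
        if not self.built:
            self.build((len(inputs), len(inputs[0])))
            self.built = True
        return self.call(inputs)

    def call(self, inputs):
        raise NotImplementedError

class OnesDense(MiniLayer):
    """Hypothetical dense layer over lists of rows, with all weights set to 1.0."""

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.build_calls = 0

    def build(self, input_shape):
        self.build_calls += 1
        in_dim = input_shape[-1]
        self.w = [[1.0] * self.units for _ in range(in_dim)]

    def call(self, inputs):
        # Plain matrix product: inputs @ w
        return [
            [sum(row[i] * self.w[i][j] for i in range(len(row))) for j in range(self.units)]
            for row in inputs
        ]

layer = OnesDense(units=2)
out1 = layer([[1.0, 2.0, 3.0]])  # first call triggers build(), then call()
out2 = layer([[0.5, 0.5, 0.0]])  # already built, so only call() runs
```

The real Keras `Layer.__call__` does considerably more (masks, the `training` argument, name scopes, dtype handling), and PyTorch's `nn.Module.__call__` similarly runs registered hooks around `forward`. The shared idea is that you override the inner method and call the instance.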
You can check the following links for more details: