Tim Bretschneider

PyTorch's forward function in TensorFlow

What is the TensorFlow counterpart to PyTorch's forward function?

I am trying to translate some PyTorch code to TensorFlow.

Answers (1)

A.Mounir

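In PyTorch you implement forward() in a subclass of nn.Module. In TensorFlow the counterpart is __call__() in a subclass of tf.Module; in Keras it is call() in a subclass of tf.keras.layers.Layer (Keras's own __call__ creates the weights if needed and then dispatches to call()). In every case you run the computation by calling the object, e.g. layer(x), just as with a PyTorch module. Here is an example of a simple dense layer in TensorFlow and in Keras: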

TensorFlow:

import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, input_dim, output_size, name=None):
        super().__init__(name=name)
        # Variables assigned as attributes are tracked by tf.Module
        self.w = tf.Variable(tf.random.normal([input_dim, output_size]), name='w')
        self.b = tf.Variable(tf.zeros([output_size]), name='b')

    def __call__(self, x):
        # Counterpart of forward(): runs when the module is called as layer(x)
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
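A PyTorch model usually chains several submodules inside forward(). The following is a minimal sketch, assuming TensorFlow 2.x, of the same pattern with tf.Module: submodules are stored as attributes and chained inside __call__. The class names Linear and MLP and the layer sizes are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

class Linear(tf.Module):
    """Hypothetical affine layer: y = x @ w + b."""
    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([in_dim, out_dim]), name="w")
        self.b = tf.Variable(tf.zeros([out_dim]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

class MLP(tf.Module):
    """Hypothetical two-layer network; __call__ plays the role of forward()."""
    def __init__(self, in_dim, hidden_dim, out_dim, name=None):
        super().__init__(name=name)
        # Submodules assigned as attributes are tracked recursively
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_dim)

    def __call__(self, x):
        h = tf.nn.relu(self.fc1(x))
        return self.fc2(h)

model = MLP(in_dim=3, hidden_dim=8, out_dim=2)
y = model(tf.ones([4, 3]))             # like model(x) in PyTorch
print(y.shape)                         # (4, 2)
print(len(model.trainable_variables))  # 4: w and b of each Linear
```

Because tf.Module collects variables from nested modules, model.trainable_variables can be passed to an optimizer much like model.parameters() in PyTorch.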

Keras:

import tensorflow as tf

class Dense(tf.keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Called once, on the first call, when the input shape is known
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs):
        # Counterpart of forward(): invoked via layer(inputs)
        return tf.matmul(inputs, self.w) + self.b
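For a whole model, the Keras analogue of an nn.Module with a forward() method is a tf.keras.Model subclass that overrides call(). The following is a minimal sketch, assuming TensorFlow 2.x with its bundled tf.keras; the class name SmallNet and the layer sizes are hypothetical. The training argument loosely plays the role of PyTorch's model.train()/model.eval() switch for layers such as Dropout.

```python
import tensorflow as tf

class SmallNet(tf.keras.Model):
    """Hypothetical model; call() is the counterpart of forward()."""
    def __init__(self, hidden_units=8, out_units=2):
        super().__init__()
        self.fc1 = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.fc2 = tf.keras.layers.Dense(out_units)

    def call(self, inputs, training=False):
        x = self.fc1(inputs)
        x = self.dropout(x, training=training)  # active only when training=True
        return self.fc2(x)

model = SmallNet()
x = tf.ones([4, 3])
y_eval = model(x)                    # weights are created here (lazy build)
y_train = model(x, training=True)    # dropout is applied
print(y_eval.shape)                  # (4, 2)
print(len(model.trainable_weights))  # 4: kernel and bias of each Dense
```

As in PyTorch, you never call call() directly; calling model(x) goes through Keras's __call__, which builds the layers on first use and then dispatches to call().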

You can check the following links for more details:

  1. https://www.tensorflow.org/api_docs/python/tf/Module
  2. https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer
