Reza Ghoddosian

What is the difference between the 'call' and '__call__' methods of RNN cells in TensorFlow?

I know what __call__ is, but what confuses me is that some classes, like BasicRNNCell or tf.nn.rnn_cell.MultiRNNCell, have a 'call' method instead of __call__. What is this plain call method? It seems to do the same thing, and if it does not, I never see it being called. I found the explanation below somewhere, but it didn't help me. Can you clarify, please?

"The call function is where the logic of your cell will live. RNNCell’s __call__ method will wrap your call method and help with scoping and other logistics." Sample:

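import tensorflow as tf
from tensorflow.python.ops import array_ops

def call(self, inputs, state):
    # Total width of the h_above inputs, summed over all layers
    total_hidden_size = sum(c._h_above_size for c in self._cells)

    # Split out the part of the input that stores the h_above (ha) values
    raw_inp = inputs[:, :-total_hidden_size]                # [B, I]
    raw_h_aboves = inputs[:, -total_hidden_size:]           # [B, sum(ha_l)]

    # One h_above tensor per layer l, each of width ha_l
    ha_splits = [c._h_above_size for c in self._cells]
    h_aboves = array_ops.split(value=raw_h_aboves,
                               num_or_size_splits=ha_splits, axis=1)

    # Append a column of ones (z_below) to the raw input
    z_below = tf.ones([tf.shape(inputs)[0], 1])             # [B, 1]
    raw_inp = array_ops.concat([raw_inp, z_below], axis=1)  # [B, I + 1]
    # ... (the rest of the method is not shown in the question)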
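To see what the sample's slicing does to the shapes, here is a minimal NumPy sketch of the same steps outside of any RNNCell. The function name split_cell_input and the example sizes are hypothetical; it assumes, as the sample's comments indicate, that the last sum(ha_l) columns of inputs hold the h_above values. Note one difference: tf.split takes a list of sizes, while np.split takes cut indices, so the sizes are converted to cumulative offsets.

```python
import numpy as np


def split_cell_input(inputs, h_above_sizes):
    """Mirror the slicing in the sample call() using NumPy arrays.

    inputs: array of shape [B, I + sum(h_above_sizes)]
    h_above_sizes: hypothetical stand-in for [c._h_above_size for c in cells]
    """
    total_hidden_size = sum(h_above_sizes)
    raw_inp = inputs[:, :-total_hidden_size]        # [B, I]
    raw_h_aboves = inputs[:, -total_hidden_size:]   # [B, sum(ha_l)]

    # np.split wants cut positions, not sizes: [4, 5] -> cut at column 4
    cut_points = np.cumsum(h_above_sizes)[:-1]
    h_aboves = np.split(raw_h_aboves, cut_points, axis=1)

    # Append a column of ones, like tf.ones([batch, 1]) in the sample
    z_below = np.ones((inputs.shape[0], 1), dtype=inputs.dtype)
    raw_inp = np.concatenate([raw_inp, z_below], axis=1)  # [B, I + 1]
    return raw_inp, h_aboves


# Batch of 2, raw input width 3, two layers with h_above widths 4 and 5
x = np.zeros((2, 3 + 4 + 5), dtype=np.float32)
raw_inp, h_aboves = split_cell_input(x, [4, 5])
print(raw_inp.shape, [h.shape for h in h_aboves])  # (2, 4) [(2, 4), (2, 5)]
```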

Answers (1)

Alen Chen

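In TensorFlow 2.0, if you define your network by subclassing tf.keras.Model, you implement the model's forward pass in call(). You do not usually invoke call() directly: you call the model itself, model(inputs), and the __call__ method inherited from tf.keras.layers.Layer does the bookkeeping (for example, building the weights on first use) and then runs your call(). The TF 1.x RNN cells such as BasicRNNCell follow the same split: RNNCell.__call__ sets up the variable scope and then runs the cell's call().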

https://www.tensorflow.org/api_docs/python/tf/keras/Model
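The relationship between the two methods is the template-method pattern: the base class's __call__ does framework bookkeeping and then delegates to the subclass's call. Below is a minimal pure-Python sketch of that pattern. BaseCell, AddCell and the log list are hypothetical stand-ins, not TensorFlow code; the real __call__ does considerably more (variable scoping, input conversion, weight creation).

```python
class BaseCell:
    """Hypothetical stand-in for a framework base class such as RNNCell."""

    def __init__(self, name):
        self.name = name
        self.built = False
        self.log = []  # records what the wrapper did, for illustration

    def build(self):
        # Real frameworks create weights here, once, on the first call
        self.built = True
        self.log.append("build")

    def __call__(self, inputs, state):
        # Framework-side bookkeeping happens here...
        if not self.built:
            self.build()
        self.log.append(f"enter scope {self.name}")
        outputs = self.call(inputs, state)  # ...then the subclass logic runs
        self.log.append(f"exit scope {self.name}")
        return outputs

    def call(self, inputs, state):
        raise NotImplementedError("subclasses implement call()")


class AddCell(BaseCell):
    """User subclass: only call() is written, never __call__."""

    def call(self, inputs, state):
        new_state = inputs + state
        return new_state, new_state


cell = AddCell("add_cell")
output, state = cell(2, 3)  # runs BaseCell.__call__, which runs AddCell.call
print(output, cell.log)
```

This is why call rarely appears at the call site: writing cell(inputs, state) or model(inputs) goes through __call__, which runs call for you.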
