I know what `__call__` is, but what confuses me is that some classes, like `BasicRNNCell` or `tf.nn.rnn_cell.MultiRNNCell`, have a plain `call` method instead of `__call__`. What is this `call` method? It seems to do the same thing, but if it is something different, I never saw where it gets called. I found the explanation below somewhere, but I don't understand it. Can you clarify, please?
"The `call` function is where the logic of your cell will live. `RNNCell`’s `__call__` method will wrap your `call` method and help with scoping and other logistics." Here is a sample `call` method from such a cell:
```python
import tensorflow as tf
from tensorflow.python.ops import array_ops

# Excerpt: method of a custom RNNCell subclass whose sub-cells are in self._cells
def call(self, inputs, state):
    total_hidden_size = sum(c._h_above_size for c in self._cells)

    # split off the part of the input that stores the h_above values
    raw_inp = inputs[:, :-total_hidden_size]                # [B, I]
    raw_h_aboves = inputs[:, -total_hidden_size:]           # [B, sum(ha_l)]

    ha_splits = [c._h_above_size for c in self._cells]
    h_aboves = array_ops.split(value=raw_h_aboves,
                               num_or_size_splits=ha_splits, axis=1)

    z_below = tf.ones([tf.shape(inputs)[0], 1])             # [B, 1]
    raw_inp = array_ops.concat([raw_inp, z_below], axis=1)  # [B, I + 1]
    # ... (the rest of the method is not shown in the original excerpt)
```
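To make the quoted sentence concrete, here is a minimal, framework-free sketch of the same pattern (often called the template method pattern). `BaseCell` and `AddCell` are hypothetical names, not TensorFlow classes. The base class's `__call__` does shared bookkeeping and then delegates to the subclass's `call`. In TensorFlow that bookkeeping is variable scoping, building weights and so on; the sketch reduces it to a simple call counter.

```python
class BaseCell:
    """Hypothetical stand-in for a framework base class such as RNNCell."""

    def __call__(self, inputs, state):
        # Shared "logistics" handled once for every subclass.
        # Here it is only a call counter; a real framework does scoping, etc.
        self.num_calls = getattr(self, "num_calls", 0) + 1
        return self.call(inputs, state)

    def call(self, inputs, state):
        raise NotImplementedError("subclasses put the cell logic in call()")


class AddCell(BaseCell):
    """Hypothetical cell: new state = old state + input."""

    def call(self, inputs, state):
        new_state = state + inputs
        return new_state, new_state  # (output, new_state), like an RNN cell


cell = AddCell()
output, state = cell(3, 10)  # runs BaseCell.__call__, which runs AddCell.call
print(output, state, cell.num_calls)  # 13 13 1
```

You write `cell(...)`, never `cell.call(...)`. That is why you never see `call` being invoked explicitly: the inherited `__call__` does it for you.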
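In TensorFlow 2.0, if you define your network by subclassing `tf.keras.Model`, you need to implement the model's forward pass in `call()`. You don't call `call()` yourself. Instead, you invoke the model like a function (`model(x)`), and the inherited `__call__` method does the setup work (such as building the weights on first use) before delegating to your `call()`.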
https://www.tensorflow.org/api_docs/python/tf/keras/Model
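Here is a minimal sketch of that in TensorFlow 2.x, assuming a recent `tf.keras`. The class name `TinyModel` and the layer size of 4 are hypothetical choices for illustration. Only `call()` is overridden. The model is then used as `model(x)`, so Keras's `__call__` builds the `Dense` weights on the first call and then runs `call()`.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical subclassed model: one Dense layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)  # output size chosen arbitrarily

    def call(self, inputs):
        # The forward pass lives here; Keras's __call__ invokes it.
        return tf.nn.relu(self.dense(inputs))


model = TinyModel()
x = tf.ones([2, 3])  # batch of 2 examples with 3 features each
y = model(x)         # goes through __call__, which builds weights and runs call()
print(y.shape)
```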