xiaoming Li

Reputation: 73

Where is the "call" function used in TensorFlow?

I am writing a ResNet, but I cannot understand where the "call" function is used.

Maybe it is called automatically by TensorFlow, which would mean we must write a function named "call"? If so, what exactly are the requirements for this "call" function? Thank you!!

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential

class BasicBlock(layers.Layer):
    def __init__(self, filter_num, strides=1):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
        self.bn2 = layers.BatchNormalization()

        if strides != 1:
            self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
        else:
            self.downsample = lambda x:x

    def call(self, inputs, training=None): 
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)  # training goes to BatchNormalization, not Conv2D

        identity = self.downsample(inputs)

        output = layers.add([out, identity])
        output = tf.nn.relu(output)

        return output

Upvotes: 1

Views: 1333

Answers (3)

amran hossen

Reputation: 400

In TensorFlow, the call() method of a custom layer or model is called automatically when you invoke the instance as a function with input data, i.e. with the parentheses () operator. Python turns that into a call to the instance's __call__() method, which Keras layers inherit from their base class and which in turn runs call().

For example, in the code provided in the question:

basicblock = BasicBlock(filter_num)
basicblock(input_data) # this will invoke the call method 

Or, equivalently, in one line:

BasicBlock(filter_num)(input_data)
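To see the mechanism without TensorFlow, here is a minimal pure-Python sketch. `ToyLayer` and `ToyBlock` are hypothetical stand-ins for `tf.keras.layers.Layer` and the question's `BasicBlock`, not TensorFlow classes; the real Keras `__call__` does a lot more before and after it delegates to `call`.

```python
# Hypothetical toy classes (NOT TensorFlow): they only mimic how a base class's
# __call__ can hand the real work to a subclass's call method.
class ToyLayer:
    def __call__(self, inputs):
        # Python runs this method whenever an instance is used with ()
        return self.call(inputs)

    def call(self, inputs):
        raise NotImplementedError("subclasses must implement call()")


class ToyBlock(ToyLayer):
    def __init__(self, factor):
        self.factor = factor

    def call(self, inputs):
        # Only the layer-specific computation lives here
        return inputs * self.factor


block = ToyBlock(3)
print(block(5))                        # 15, via ToyLayer.__call__ -> ToyBlock.call
print(type(block).__call__(block, 5))  # 15, what block(5) expands to
print(ToyBlock(3)(5))                  # 15, the one-liner form
```

The subclass never defines `__call__`; it only fills in `call`, and the base class decides when to run it.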

Upvotes: 0

M. Perier--Dulhoste

Reputation: 1049

When you define a custom layer, you extend the base class tensorflow.keras.layers.Layer and use it as follows:

import tensorflow as tf

class BasicBlock(tf.keras.layers.Layer):
   ...

basic_block = BasicBlock()
basic_block(inputs)

The last line of the snippet above calls the magic method __call__ of the class (if you are interested, there is more on magic methods in "A Guide to Python's Magic Methods").

Since you did not define a __call__ method in BasicBlock (you defined call, which is different), the __call__ method inherited from tensorflow.keras.layers.Layer is used.

According to the TensorFlow documentation, this method:

Wraps call, applying pre- and post-processing steps.

Roughly speaking, it looks like this (the actual source code is much more complex; have a look at it if you are interested):

class Layer:
    ...
    def __call__(self, *args, **kwargs):
        # preprocessing steps
        outputs = self.call(*args, **kwargs)
        # post-processing steps
        return outputs
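A slightly more concrete version of that outline, still a hypothetical pure-Python toy rather than the Keras source: the pre-processing step here converts the inputs to floats (loosely standing in for the input conversion Keras does before `call`) and the post-processing step just counts calls. `ToyLayer`, `ToyDouble` and `num_calls` are made-up names for illustration.

```python
# Hypothetical toy (NOT the real Keras source): __call__ wraps call with a
# pre-processing step and a post-processing step, as the docs describe.
class ToyLayer:
    def __init__(self):
        self.num_calls = 0

    def __call__(self, inputs, *args, **kwargs):
        # Pre-processing: normalize the inputs (a stand-in for the input
        # conversion Keras performs before running call)
        inputs = [float(x) for x in inputs]
        outputs = self.call(inputs, *args, **kwargs)
        # Post-processing: simple bookkeeping after call returns
        self.num_calls += 1
        return outputs

    def call(self, inputs, *args, **kwargs):
        raise NotImplementedError("subclasses must implement call()")


class ToyDouble(ToyLayer):
    def call(self, inputs):
        # Only the actual computation lives here
        return [2 * x for x in inputs]


layer = ToyDouble()
print(layer((1, 2, 3)))  # [2.0, 4.0, 6.0]: the tuple of ints became floats first
print(layer.num_calls)   # 1
```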

If you are familiar with inheritance, you can work out the steps that happen when you run basic_block(inputs):

  1. Check if BasicBlock has a method named __call__ => No
  2. Check if the base class Layer has a method named __call__ => Yes, use it and go inside this method
  3. Apply the preprocessing steps
  4. Check if BasicBlock has a method named call => Yes, use it and apply it to the inputs
  5. Apply the post-processing steps
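The same lookup can be checked with plain Python introspection. This sketch assumes hypothetical `ToyLayer`/`ToyBlock` classes in place of `Layer`/`BasicBlock`; `vars()` shows which class actually defines each method, and `__mro__` is the order in which Python searches the classes.

```python
# Hypothetical toy classes standing in for Layer and BasicBlock
class ToyLayer:
    def __call__(self, inputs):
        return self.call(inputs)


class ToyBlock(ToyLayer):
    def call(self, inputs):
        return inputs + 1


# Step 1: the subclass itself defines no __call__
print("__call__" in vars(ToyBlock))  # False
# Step 2: the base class does, so Python uses that one
print("__call__" in vars(ToyLayer))  # True
print([cls.__name__ for cls in ToyBlock.__mro__])  # ['ToyBlock', 'ToyLayer', 'object']
# Step 4: inside ToyLayer.__call__, self.call resolves to ToyBlock.call
print("call" in vars(ToyBlock))      # True
print(ToyBlock()(41))                # 42
```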

Regarding the requirements for the call method you implement, the best resource is the official TensorFlow documentation of Layer.call, which explains the expected input data structures and the supported keyword arguments (such as training and mask).
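As a rough illustration of why the signature of `call` matters: Keras checks whether your `call` accepts a `training` argument and only passes it if it does, which is why the question's `call(self, inputs, training=None)` receives the flag. The toy below imitates that idea with `inspect.signature`; it is a simplified assumption about the mechanism, not the actual Keras code, and `ToyDropoutLike`/`ToyIdentity` are hypothetical.

```python
import inspect


# Hypothetical toy base class: it forwards `training` to call() only when
# call() declares that parameter (a simplified imitation, not Keras code).
class ToyLayer:
    def __call__(self, inputs, training=None):
        params = inspect.signature(self.call).parameters
        if "training" in params:
            return self.call(inputs, training=training)
        return self.call(inputs)


class ToyDropoutLike(ToyLayer):
    # A call() that accepts `training`, like the question's BasicBlock.call
    def call(self, inputs, training=None):
        # Toy rule: zero everything in training mode, pass through otherwise
        return [0.0] * len(inputs) if training else list(inputs)


class ToyIdentity(ToyLayer):
    # A call() without `training`; the base class simply does not pass it
    def call(self, inputs):
        return list(inputs)


x = [1.0, 2.0]
print(ToyDropoutLike()(x, training=True))   # [0.0, 0.0]
print(ToyDropoutLike()(x, training=False))  # [1.0, 2.0]
print(ToyIdentity()(x, training=True))      # [1.0, 2.0]
```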

Upvotes: 1

Dr. Prof. Patrick

Reputation: 1374

The call function is used in the following manner:

basic_block = BasicBlock()
basic_block(args)

so you write that instead of:

basic_block.call(args)
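One practical difference worth knowing: calling `call` directly bypasses whatever `__call__` does around it. In Keras, one of those steps is building the layer's weights on first use; the hypothetical toy below (not TensorFlow) mimics that with a lazily created `bias`, so `call` fails on a fresh instance but works once the layer has been built through `()`.

```python
# Hypothetical toy (NOT TensorFlow): the weight is created lazily in build(),
# which only the wrapping __call__ triggers.
class ToyLayer:
    def __init__(self):
        self.built = False

    def __call__(self, inputs):
        if not self.built:  # pre-processing that block.call(...) skips
            self.build(len(inputs))
            self.built = True
        return self.call(inputs)


class ToyBlock(ToyLayer):
    def build(self, input_size):
        # One bias value per input element, created on first use
        self.bias = [1.0] * input_size

    def call(self, inputs):
        return [x + b for x, b in zip(inputs, self.bias)]


fresh = ToyBlock()
try:
    fresh.call([1.0, 2.0])  # bypasses __call__, so build() never ran
except AttributeError as err:
    print("call() alone failed:", err)

print(fresh([1.0, 2.0]))       # [2.0, 3.0]: __call__ builds, then runs call()
print(fresh.call([1.0, 2.0]))  # [2.0, 3.0]: works now that the layer is built
```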

Upvotes: 0
