Reputation: 73
I am writing a ResNet, but I cannot understand where the "call" function is used.
Maybe it is called automatically by TensorFlow, which would mean we must write a method named "call"? If so, what exactly are the requirements for this "call" method? Thank you!
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential

class BasicBlock(layers.Layer):
    def __init__(self, filter_num, strides=1):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
        self.bn2 = layers.BatchNormalization()
        # Shortcut path: a 1x1 convolution when the block downsamples, identity otherwise.
        if strides != 1:
            self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)  # BatchNormalization, not Conv2D, needs the training flag
        # Add the shortcut to the main path and apply the final ReLU.
        identity = self.downsample(inputs)
        output = layers.add([out, identity])
        output = tf.nn.relu(output)
        return output
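To see where call comes in, here is a minimal sketch that builds the block and runs it on a dummy batch. It assumes TensorFlow 2.x; the input shape, filter count and stride are arbitrary choices for illustration, and the class repeats the question's block (with training routed to both BatchNormalization layers) so the snippet runs on its own. Note that nothing in it calls block.call explicitly: the line block(x, training=False) is what ends up running call.

```python
import tensorflow as tf
from tensorflow.keras import layers


class BasicBlock(layers.Layer):
    """The residual block from the question, with `training` routed to both BatchNormalization layers."""

    def __init__(self, filter_num, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation("relu")
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
        self.bn2 = layers.BatchNormalization()
        if strides != 1:
            self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        out = self.relu(self.bn1(self.conv1(inputs), training=training))
        out = self.bn2(self.conv2(out), training=training)
        identity = self.downsample(inputs)
        return tf.nn.relu(layers.add([out, identity]))


block = BasicBlock(64, strides=2)
x = tf.random.normal((1, 32, 32, 3))  # dummy batch: one 32x32 image with 3 channels
y = block(x, training=False)          # Layer.__call__ handles setup, then runs call()
print(y.shape)                        # (1, 16, 16, 64)
```

With strides=2 both the main path and the 1x1 shortcut halve the spatial size, which is why the two tensors can be added.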
Upvotes: 1
Views: 1333
Reputation: 400
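In TensorFlow, the call() method of a custom layer or model is invoked automatically when you call the instance like a function on input data, i.e. with the parentheses () operator. The parentheses trigger the __call__ method inherited from tf.keras.layers.Layer, which does some bookkeeping (such as building the layer the first time it sees an input) and then calls your call() method.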
For example, in the code provided in the question:
basicblock = BasicBlock(filter_num)
basicblock(input_data) # this will invoke the call method
This is equivalent to the one-liner:
BasicBlock(filter_num)(input_data)
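A quick way to confirm this is a layer that counts how often its call method runs. CountingLayer is a hypothetical name used only for this sketch, which assumes TensorFlow 2.x.

```python
import tensorflow as tf


class CountingLayer(tf.keras.layers.Layer):
    """Hypothetical layer that counts how many times its call() method runs."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_calls = 0

    def call(self, inputs):
        self.num_calls += 1  # only executes when call() is actually invoked
        return inputs * 2.0


layer = CountingLayer()
x = tf.constant([[1.0, 2.0]])

y1 = layer(x)            # instance used as a function: call() runs
y2 = CountingLayer()(x)  # the same thing as a one-liner, on a fresh instance
print(layer.num_calls)   # 1
print(y1.numpy())        # [[2. 4.]]
```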
Upvotes: 0
Reputation: 1049
When you define a custom layer, you extend the base class tensorflow.keras.layers.Layer and use it as follows:
import tensorflow as tf
class BasicBlock(tf.keras.layers.Layer):
    ...
basic_block = BasicBlock()
basic_block(inputs)
The last line of the snippet above calls the magic method __call__ of the class (if you are interested, "A Guide to Python's Magic Methods" covers magic methods in more detail).
Since you did not define a __call__ method in BasicBlock (you defined call, which is different), the __call__ inherited from tensorflow.keras.layers.Layer is used.
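You can check this lookup directly. MyBlock is a hypothetical stand-in for BasicBlock; the sketch assumes TensorFlow 2.x and only inspects the classes, so no data is involved.

```python
import tensorflow as tf


class MyBlock(tf.keras.layers.Layer):
    """Hypothetical stand-in for BasicBlock: defines call() but not __call__()."""

    def call(self, inputs):
        return inputs + 1.0


# __call__ is not defined on MyBlock itself...
print("__call__" in vars(MyBlock))                         # False
# ...so Python finds it further up the hierarchy, on Layer.
print(MyBlock.__call__ is tf.keras.layers.Layer.__call__)  # True
# call, on the other hand, is MyBlock's own method.
print("call" in vars(MyBlock))                             # True
```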
According to the TensorFlow documentation, Layer.__call__:
"Wraps call, applying pre- and post-processing steps."
Roughly speaking, Layer.__call__ looks like this (you can check the source code if you are interested, but the real thing is much more complex):
class Layer(...):
    ...
    def __call__(self, ...):
        # pre-processing steps
        outputs = self.call(...)
        # post-processing steps
        return outputs
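The same wrapping pattern can be reproduced in plain Python, without TensorFlow. ToyLayer and ToyBlock are hypothetical names, and the pre- and post-processing steps are just log entries standing in for whatever the real Layer.__call__ does.

```python
class ToyLayer:
    """Toy stand-in for tf.keras.layers.Layer (not the real implementation)."""

    def __init__(self):
        self.log = []

    def __call__(self, *args, **kwargs):
        self.log.append("pre-processing")    # placeholder for setup work
        outputs = self.call(*args, **kwargs)  # dispatch to the subclass's call()
        self.log.append("post-processing")   # placeholder for bookkeeping
        return outputs

    def call(self, inputs):
        raise NotImplementedError("subclasses must implement call()")


class ToyBlock(ToyLayer):
    def call(self, inputs):
        self.log.append("call")
        return inputs * 2


block = ToyBlock()
print(block(21))   # 42
print(block.log)   # ['pre-processing', 'call', 'post-processing']
```

ToyBlock never defines __call__, so block(21) resolves to ToyLayer.__call__, which in turn finds ToyBlock.call through self.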
If you are familiar with inheritance, you can work out the steps Python goes through when you write basic_block(inputs):
1. Does BasicBlock have a method named __call__? No.
2. Does Layer have a method named __call__? Yes, so Python uses it and runs that method.
3. Inside it, self.call is looked up: does BasicBlock have a method named call? Yes, so BasicBlock.call is applied to the inputs.
Regarding the requirements of the call method you implement, the best resource is the official TensorFlow documentation, which explains the expected input data structures and the keyword arguments (such as training and mask).
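As one concrete requirement: the input arrives as the first positional argument of call, and a keyword argument named training is treated specially, so the built-in training and evaluation loops such as fit() can switch the layer between training and inference behaviour (this is why the question's block forwards training to BatchNormalization). The sketch assumes TensorFlow 2.x; NoisyIdentity and its stddev parameter are hypothetical.

```python
import tensorflow as tf


class NoisyIdentity(tf.keras.layers.Layer):
    """Hypothetical layer: adds Gaussian noise in training mode, identity at inference."""

    def __init__(self, stddev=1.0, **kwargs):
        super().__init__(**kwargs)
        self.stddev = stddev

    def call(self, inputs, training=None):
        # `inputs` is the first positional argument; `training` is a keyword
        # argument that Keras recognizes by name.
        if training:
            return inputs + tf.random.normal(tf.shape(inputs), stddev=self.stddev)
        return inputs


layer = NoisyIdentity()
x = tf.zeros((2, 3))
print(layer(x, training=False).numpy())  # all zeros: inference branch
print(layer(x, training=True).numpy())   # random values: training branch
```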
Upvotes: 1
Reputation: 1374
The call function is used in the following manner:
basic_block = BasicBlock()
basic_block(args)
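so you invoke the layer this way rather than calling the method directly (which would skip the pre- and post-processing that Layer.__call__ performs):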
basic_block.call(args)
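One practical difference is that calling call directly skips the setup done by Layer.__call__, such as running build() the first time the layer sees an input. Scale is a hypothetical layer, and the sketch assumes TensorFlow 2.x: it creates its weight in build(), so calling .call() on a fresh instance fails because the weight does not exist yet, while calling the instance works.

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical layer that multiplies its input by one trainable scalar."""

    def build(self, input_shape):
        # Created lazily: Layer.__call__ runs build() before the first call().
        self.w = self.add_weight(name="w", shape=(), initializer="ones")

    def call(self, inputs):
        return inputs * self.w


x = tf.ones((2, 3))

fresh = Scale()
try:
    fresh.call(x)  # bypasses __call__, so build() never ran and self.w is missing
    direct_call_failed = False
except AttributeError:
    direct_call_failed = True

layer = Scale()
y = layer(x)       # __call__ runs build() and then call()
print(direct_call_failed, layer.built)  # True True
```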
Upvotes: 0