Notice that when I created my model, I defined the call function with the argument something=False. When I use the model in the function train_step, I pass something=True, training=True; training is not defined in my call, but it is in the default tf.keras.Model call.
Why am I able to execute this with no error? The output just prints a bunch of 'my call's.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fl = Flatten()
        self.d = Dense(10)

    ###### My problem ######
    def call(self, x, something=False):
        if something:
            tf.print('my call')
        x = self.fl(x)
        return self.d(x)

model = MyModel()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(X, Y):
    with tf.GradientTape() as tape:
        ###### My problem ######
        predictions = model(X, something=True, training=True)
        loss = loss_object(Y, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(3):
    for X, Y in train_ds:
        train_step(X, Y)
From the call method documentation in the Model class:
To call a model on an input, always use the __call__() method, i.e. model(inputs), which relies on the underlying call() method.
And indeed, __call__ can take arbitrary arguments: def __call__(self, *args, **kwargs): (in the Model class source code).
The base __call__ handles the training keyword itself (it sets the learning phase), while any keyword it does not consume, like your something, is simply forwarded on to your call(). That is why passing both raises no error.
You can find a more detailed answer here
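The forwarding behaviour can be mimicked without TensorFlow. This is a minimal sketch with hypothetical class names (Base, Child), showing how a wrapper __call__(*args, **kwargs) can pop one "framework" keyword and pass everything else to a user-defined call() that never declared it:

```python
class Base:
    """Stand-in for a Keras-style Model base class."""

    def __call__(self, *args, **kwargs):
        # Consume the framework-level keyword ourselves,
        # the way Keras handles `training`.
        training = kwargs.pop("training", None)
        # Forward whatever is left to the subclass's call().
        return self.call(*args, **kwargs)


class Child(Base):
    """Stand-in for MyModel: call() only knows about `something`."""

    def call(self, x, something=False):
        if something:
            print("my call")
        return x


m = Child()
# `training` is absorbed by Base.__call__, `something` reaches call():
result = m(41, something=True, training=True)  # no TypeError
```

Because __call__ intercepts training before dispatching, the subclass's call() signature never needs to mention it, which matches what you observed with model(X, something=True, training=True).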