Kevin O'Hara

Reputation: 11

Tensorflow 2 variable not trainable

I've made a simple model in tf2 which multiplies the input 'a' by a variable 'b' (initialized to 1) and returns the output 'c'. I then try to train it on the simple dataset a=1, c=5. I expect it to learn b=5.

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

a = Input(shape=(1,))
b = tf.Variable(1., trainable=True)
c = a*b
model = Model(a,c)

loss = tf.keras.losses.MeanAbsoluteError()
model.compile(optimizer='adam', loss=loss)

model.fit([1.],[5.],batch_size=1, epochs=1)

However, tf2 does not see the variable 'b' as being trainable. The summary shows no trainable parameters.

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
tf_op_layer_mul (TensorFlowO [(None, 1)]               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

Why is the variable 'b' not training?

Upvotes: 1

Views: 1854

Answers (1)

Vivek Mehta

Reputation: 2642

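A Keras model is a wrapper around the Layer class, and it only tracks variables that belong to one of its layers. The bare tf.Variable used in a*b is not registered as a weight of any layer, so the model reports 0 trainable params, and model.fit, which only updates model.trainable_variables, never touches it. You'll have to wrap this variable in a Keras layer for it to show up as a trainable parameter of the model.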
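To see that the variable itself is trainable and that the problem is only Keras's bookkeeping, here is a minimal sketch that trains a bare tf.Variable directly with tf.GradientTape, outside any Keras model. The plain gradient-descent update and the learning rate of 0.1 are my own assumptions (chosen so 100 steps are enough), not part of the question or the answer.

```python
import tensorflow as tf

# A bare variable, just like 'b' in the question, initialised to 1.
b = tf.Variable(1.0, trainable=True)

a = tf.constant([[1.0]])       # input a = 1
c_true = tf.constant([[5.0]])  # target c = 5

learning_rate = 0.1  # assumption: much larger than Adam's default of 0.001

for step in range(100):
    with tf.GradientTape() as tape:
        c_pred = a * b
        # Mean absolute error, the same loss the question uses.
        loss = tf.reduce_mean(tf.abs(c_pred - c_true))
    # The tape watches trainable variables, so the gradient does reach b.
    grad = tape.gradient(loss, b)
    # Plain gradient-descent step applied to b by hand.
    b.assign_sub(learning_rate * grad)

print(b.numpy())
```

With MAE the gradient has constant magnitude 1, so b moves by 0.1 per step and ends up within about one step of 5.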

You can create a tiny custom layer for that like this:

class MyLayer(tf.keras.layers.Layer):
  def __init__(self):
    super(MyLayer, self).__init__()

    #your variable goes here
    self.variable = tf.Variable(1., trainable=True, dtype=tf.float64)

  def call(self, inputs, **kwargs):

    # your mul operation goes here
    x = inputs * self.variable

    return x
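A variant of the same idea uses self.add_weight, which is the method Keras layers use to create and register their own weights. The class name ScaleLayer and the weight name "b" are hypothetical; this is a sketch, not a drop-in replacement tested against the original setup.

```python
import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical layer computing inputs * b for one trainable scalar b."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # add_weight creates the variable and registers it with the layer,
        # so it appears in trainable_weights and in model.summary().
        self.b = self.add_weight(
            name="b",
            shape=(),
            initializer="ones",  # start at 1, as in the question
            trainable=True,
        )

    def call(self, inputs):
        return inputs * self.b

layer = ScaleLayer()
print(len(layer.trainable_weights))          # 1
print(layer(tf.constant([[2.0]])).numpy())   # [[2.]]
```

Since no dtype is passed, add_weight uses the layer's default dtype (float32).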

Here the call method performs the multiplication. We can use this layer just like any other layer in our model. Here I am creating a Sequential model and adding MyLayer to it as a model layer.

model = tf.keras.models.Sequential()
mylayer_object = MyLayer()
model.add(mylayer_object)

loss = tf.keras.losses.MeanAbsoluteError()
model.compile("adam", loss)

model.fit([1.],[5.],batch_size=1, epochs=1)
model.summary()
'''
Train on 1 samples
1/1 [==============================] - 0s 426ms/sample - loss: 4.0000
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
my_layer (MyLayer)           multiple                  1         
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
'''

After this, you can list the model's trainable variables:

print(model.trainable_variables)
# [<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=1.0009999968852092>]
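The printed value shows that one epoch of Adam at its default learning rate of 0.001 moves b only from 1.0 to about 1.001, so reaching b≈5 needs many more update steps and/or a larger learning rate. A minimal sketch, assuming plain SGD with learning_rate=0.1 and 100 epochs (both my choices), and swapping MyLayer for a built-in Dense(1, use_bias=False) layer, which computes the same c = a*b with a single trainable scalar:

```python
import numpy as np
import tensorflow as tf

# Dense(1) without bias on a 1-feature input is exactly c = a * b,
# with b stored as a (1, 1) kernel and initialised to 1.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(1, use_bias=False, kernel_initializer="ones"),
])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),  # assumption
    loss=tf.keras.losses.MeanAbsoluteError(),
)

x = np.array([[1.0]], dtype=np.float32)  # a = 1
y = np.array([[5.0]], dtype=np.float32)  # c = 5
model.fit(x, y, batch_size=1, epochs=100, verbose=0)

b_learned = float(model.get_weights()[0][0, 0])
print(b_learned)
```

The summary of this model reports 1 trainable param, and after 100 single-sample steps of size 0.1 the learned b should sit within about 0.1 of 5.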

Upvotes: 1
