unter_983

Reputation: 155

Why does tape.gradient return all None in my Sequential model?

I have to compute the gradients of this model:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model=Sequential()
model.add(Dense(40, activation='relu',input_dim=12))
model.add(Dense(60, activation='relu'))
model.add(Dense(units=3, activation='softmax'))
opt=tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss="mse", optimizer=opt)

model_q=Sequential()
model_q.add(Dense(40, activation='relu',input_dim=15))
model_q.add(Dense(60, activation='relu'))
model_q.add(Dense(units=1, activation='linear'))
opt=tf.keras.optimizers.Adam(learning_rate=0.001)
model_q.compile(loss="mse", optimizer=opt)

x=np.random.random(12)
x2=model.predict(x.reshape(-1,12))
with tf.GradientTape() as tape:
    value = model_q([tf.convert_to_tensor(np.append(x,x2).reshape(-1,15))])
    loss = -tf.reduce_mean(value)
grad = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grad, model.trainable_variables))

But grad is all None, so opt can't apply the gradients to the model. Why is this happening? I know it's quite a strange loss, but it's what I want to compute.

Upvotes: 0

Views: 1976

Answers (1)

Andrea Angeli

Reputation: 745

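Your model's forward pass is not recorded by the tape. model.predict runs outside the GradientTape context and returns a NumPy array (and np.append produces another NumPy array), so the tape has no record of how loss depends on the weights of model, and tape.gradient returns None for every one of them. Run the computations inside the tape's context if you want gradients: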
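As a minimal sketch of that recording rule (assuming TensorFlow 2.x in eager mode; the single Dense(3) layer named `layer` is a hypothetical stand-in for `model`), compare a forward pass whose result is converted to NumPy outside the tape with the same pass run inside it:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for `model`: one Dense layer with a trainable kernel and bias.
layer = tf.keras.layers.Dense(3)
x = np.random.random((1, 12)).astype("float32")

# Forward pass OUTSIDE the tape, converted to NumPy (like model.predict in the question).
y_numpy = layer(x).numpy()
with tf.GradientTape() as tape:
    loss_outside = tf.reduce_sum(tf.convert_to_tensor(y_numpy) * 2.0)
grads_outside = tape.gradient(loss_outside, layer.trainable_variables)

# Forward pass INSIDE the tape: the ops that read the weights are recorded.
with tf.GradientTape() as tape:
    loss_inside = tf.reduce_sum(layer(x) * 2.0)
grads_inside = tape.gradient(loss_inside, layer.trainable_variables)

print(grads_outside)  # every entry is None
print([tuple(g.shape) for g in grads_inside])  # [(12, 3), (3,)]
```

The loss value is the same in both cases; only the second one carries a recorded path back to the kernel and bias.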

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model=Sequential()
model.add(Dense(40, activation='relu',input_dim=12))
model.add(Dense(60, activation='relu'))
model.add(Dense(units=3, activation='softmax'))
opt=tf.keras.optimizers.Adam(learning_rate=0.001)

model_q=Sequential()
model_q.add(Dense(40, activation='relu',input_dim=15))
model_q.add(Dense(60, activation='relu'))
model_q.add(Dense(units=1, activation='linear'))

x=np.random.random(12).reshape(-1,12).astype('float32')
with tf.GradientTape() as tape:
  # Both forward passes run inside the tape, so the path
  # from model's weights to loss is recorded.
  x2 = model(x)
  value = model_q(tf.concat((x,x2), -1))
  loss = -tf.reduce_mean(value)
# Only model's variables are differentiated and updated; model_q stays fixed.
grad = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grad, model.trainable_variables))
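For repeated updates, the same fix can be wrapped in a training-step function. This is a sketch assuming TensorFlow 2.x; `train_step` is a hypothetical helper name, the tf.keras.Input layers replace input_dim, and the explicit None check is an addition that fails with a readable message instead of handing None gradients to the optimizer:

```python
import numpy as np
import tensorflow as tf

# Same architectures as in the answer, declared with explicit Input layers.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(12,)),
    tf.keras.layers.Dense(40, activation="relu"),
    tf.keras.layers.Dense(60, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model_q = tf.keras.Sequential([
    tf.keras.Input(shape=(15,)),
    tf.keras.layers.Dense(40, activation="relu"),
    tf.keras.layers.Dense(60, activation="relu"),
    tf.keras.layers.Dense(1, activation="linear"),
])
opt = tf.keras.optimizers.Adam(learning_rate=0.001)


def train_step(x):
    """One update of `model` that maximizes model_q's score; model_q is not updated."""
    with tf.GradientTape() as tape:
        x2 = model(x)                                 # recorded on the tape
        value = model_q(tf.concat([x, x2], axis=-1))  # shape (batch, 1)
        loss = -tf.reduce_mean(value)
    grads = tape.gradient(loss, model.trainable_variables)
    # Report any variable the tape could not reach before touching the optimizer.
    missing = [v.name for v, g in zip(model.trainable_variables, grads) if g is None]
    if missing:
        raise ValueError(f"No gradient recorded for: {missing}")
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss, grads


batch = np.random.random((4, 12)).astype("float32")
loss, grads = train_step(batch)
```

Only the variables of model are passed to apply_gradients, so model_q acts as a fixed scorer here. Decorating train_step with @tf.function is a common next step for speed but is not needed for the gradients to exist.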

Upvotes: 1
