I am creating a custom algorithm with a custom dataloader in Keras. I understand that when you access tensors inside a model's internal methods and print their shape, you often get `None`, usually on the batch axis, because the batch size can vary. I have written a custom method for updating the gradients, and as a sanity check I want to print the actual size of this axis while the program is executing. I cannot figure out how to do it.
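The `None` described above is the static shape, which is all TensorFlow knows while it traces a function into a graph (Keras does this with `train_step` unless the model is compiled with `run_eagerly=True`). The real size is only available at run time through the dynamic shape, `tf.shape(...)`, and printing it needs `tf.print`, because a plain Python `print` runs only during tracing. A minimal sketch, assuming TensorFlow 2.x; `show_shapes` is a hypothetical function name used only for illustration:

```python
import tensorflow as tf


# The batch axis is declared as None, similar to how Keras traces train_step.
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)])
def show_shapes(x):
    # Python print runs only while the function is traced: prints (None, 4).
    print("static shape:", x.shape)
    # tf.shape is computed every time the graph runs, so it holds the real batch size.
    dynamic = tf.shape(x)
    # tf.print runs inside the graph on every call (writes to stderr by default).
    tf.print("dynamic shape:", dynamic)
    return dynamic


result = show_shapes(tf.zeros((8, 4)))
```

A second call such as `show_shapes(tf.zeros((3, 4)))` prints only the dynamic shape, because the fixed `input_signature` means the function is not traced again.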
Here is the code; see the line marked `THIS LINE`. It prints the output below, which shows the batch axis as `None`. For debugging purposes I would like to see the actual value of this dimension when the code runs. How do I do that?
```
(None, 4, 100)
(None, 100)
(None, 100, 100)
(None, 100)
(None, 100, 100)
(None, 100)
(None, 100, 1)
(None, 1)
```
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        tao = 1  # threshold on sign agreement between the two halves

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients (standard alternative: tape.gradient(loss, trainable_vars))
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)

        new_gradients = []
        for grad in gradients:
            print(grad.shape)  # <--- THIS LINE
            # env_siz is defined elsewhere in the original code (not shown here)
            q1 = K.mean(grad[:env_siz], axis=0)
            q2 = K.mean(grad[env_siz:], axis=0)
            # 1 means all gradients point in the same direction on that axis
            Q = K.mean(K.stack((K.sign(q1), K.sign(q2))), axis=0)
            P = tf.where(tf.abs(Q) >= tao, K.mean(K.stack((q1, q2)), axis=0), 0)
            new_gradients.append(P)

        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```
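A minimal, self-contained sketch of the same gradient loop, assuming TensorFlow 2.x. It replaces the Keras model with a hypothetical linear model (`w`, `b`), uses a hypothetical `SPLIT` in place of `env_siz` and `TAU` in place of `tao`, and computes a per-sample loss of shape `(batch,)` so that the Jacobian gets a leading batch axis. The relevant change is on the debugging line: `print(grad.shape)` runs only while the function is traced and shows the static shape, whereas `tf.print(tf.shape(grad))` runs on every call and shows the actual batch size.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's model, env_siz and tao.
w = tf.Variable(tf.zeros((3, 1)))  # weights of a tiny linear model
b = tf.Variable(tf.zeros((1,)))    # bias
SPLIT = 2   # hypothetical replacement for env_siz
TAU = 1.0   # hypothetical replacement for tao


@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
])
def agreed_gradients(x, y):
    with tf.GradientTape() as tape:
        y_pred = tf.matmul(x, w) + b
        # Per-sample loss with shape (batch,), so the Jacobian has a batch axis.
        per_sample_loss = tf.reduce_mean(tf.square(y_pred - y), axis=-1)
    grads = tape.jacobian(per_sample_loss, [w, b])

    new_grads = []
    for grad in grads:
        # Runs only while tracing: static shape, batch axis unknown.
        print("static shape:", grad.shape)
        # Runs on every call: the actual shape, including the batch size.
        tf.print("dynamic shape:", tf.shape(grad))
        q1 = tf.reduce_mean(grad[:SPLIT], axis=0)
        q2 = tf.reduce_mean(grad[SPLIT:], axis=0)
        # |Q| == 1 where both halves agree on the sign of the gradient.
        Q = tf.reduce_mean(tf.stack([tf.sign(q1), tf.sign(q2)]), axis=0)
        P = tf.where(tf.abs(Q) >= TAU,
                     tf.reduce_mean(tf.stack([q1, q2]), axis=0),
                     tf.zeros_like(q1))
        new_grads.append(P)
    return new_grads


x = tf.ones((4, 3))
g_w, g_b = agreed_gradients(x, tf.ones((4, 1)))
```

Inside `train_step`, the same idea would mean replacing `THIS LINE` with `tf.print(tf.shape(grad))`. Alternatively, compiling the model with `run_eagerly=True` makes `train_step` run eagerly, so a plain `print(grad.shape)` shows concrete sizes, at the cost of slower training.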