jeffery_the_wind

Print shape of tensor even when it is `None`

I am creating a custom algorithm with a custom data loader in Keras. I understand that when you access tensors inside the model's internal methods, printing their shape often shows None, typically on the batch axis, because the batch size can be variable. I have created a custom method for updating gradients, and as a sanity check I want to print the actual value of that axis while the program is executing, but I cannot figure out how to do it.

Here is some code; see the line marked THIS LINE. It prints the following output, which shows the batch axis as None. For debugging purposes I would like to see what this value actually is when the code runs. How do I do that?

(None, 4, 100)
(None, 100)
(None, 100, 100)
(None, 100)
(None, 100, 100)
(None, 100)
(None, 100, 1)
(None, 1)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        
        tao = 1  # threshold on sign agreement between the two gradient groups

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            
        gradients = tape.jacobian(loss, self.trainable_variables)
        
        new_gradients = []
        for grad in gradients:

            print(grad.shape) # <--- THIS LINE

            # env_siz is defined elsewhere in the setup; it splits the batch
            # into two groups whose mean gradients are compared below
            q1 = K.mean(grad[:env_siz], axis=0)
            q2 = K.mean(grad[env_siz:], axis=0)

            Q = K.mean( K.stack((K.sign(q1), K.sign(q2))), axis=0 ) # 1 means all gradients in same direction on that axis
            P = tf.where( tf.abs(Q) >= tao, K.mean( K.stack((q1, q2)), axis=0 ), 0)
#             print(P)
            new_gradients.append( P )

        # Compute gradients
        trainable_vars = self.trainable_variables
#         gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

Upvotes: 1


Answers (1)

javidcf

You can use tf.print instead of print to see the values of tensors inside "graphed" functions. And instead of accessing the .shape attribute, which always gives the statically known shape, use tf.shape to read the actual runtime shape of the tensor:

tf.print(tf.shape(grad))
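
For example, inside the loop from the question (a minimal sketch reusing the question's variable names), the debugging line could become:

for grad in gradients:
    # .shape is the static shape known at trace time, so variable
    # dimensions such as the batch axis show up as None
    print(grad.shape)
    # tf.shape returns a tensor holding the runtime shape; tf.print
    # evaluates it when the graph executes, so the actual batch size
    # is shown, e.g. [32 4 100]
    tf.print(tf.shape(grad))

If only the batch dimension is of interest, tf.print(tf.shape(grad)[0]) prints just that value.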

Upvotes: 2
