ClementWalter

Reputation: 5272

How to loop over batch_size in a Keras custom layer

I want to create a custom layer that takes an internal tensor and a custom dot function in __init__, so that for a given batch it computes the dot function over all possible pairs formed from the batch rows and the rows of the internal tensor.

If I were using the natural inner product, I could write tf.matmul(inputs, self.internal_tensor, transpose_b=True) directly, but I want to be able to pass in other kernel functions.

MWE:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer


class CustomLayer(Layer):

    def __init__(self, internal_tensor, kernel, **kwargs):
        super().__init__(**kwargs)
        self.internal_tensor = tf.Variable(0., shape=tf.TensorShape((None, 10)), validate_shape=False, name='internal_tensor')
        self.internal_tensor.assign(internal_tensor)
        self.kernel = kernel

    @tf.function
    def call(self, inputs, **kwargs):
        return self.kernel([
            tf.reshape(tf.tile(inputs, [1, self.internal_tensor.shape[0]]), [-1, inputs.shape[1]]),  # because no tf.repeat
            tf.tile(self.internal_tensor, [inputs.shape[0], 1]),
        ])


custom_layer = CustomLayer(
    internal_tensor=tf.convert_to_tensor(np.random.rand(30, 10), tf.float32),
    kernel=lambda inputs: inputs[0] + inputs[1],
)
x = np.random.rand(15, 10).astype(np.float32)
custom_layer(x)

# TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [1, None]. Consider casting elements to a supported type.

For the sake of clarity, here is the target behaviour as a working layer in NumPy:

class NumpyLayer:

    def __init__(self, internal_tensor, kernel):
        self.internal_tensor = internal_tensor
        self.kernel = kernel

    def __call__(self, inputs):
        return self.kernel([
            np.repeat(inputs, len(self.internal_tensor), axis=0),
            np.tile(self.internal_tensor, (len(inputs), 1)),
        ])

internal_tensor = np.random.rand(30, 10)  # same shape as the internal tensor used above
numpy_layer = NumpyLayer(
    internal_tensor=internal_tensor,
    kernel=lambda inputs: inputs[0] + inputs[1],
)
numpy_layer(x)
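
To make the intended pairing concrete, here is a minimal sketch with tiny hand-picked arrays (the values and the names `left`, `right`, and `pairs` are illustrative only). It uses the same np.repeat / np.tile combination as the layer above, with the additive kernel from the example:

```python
import numpy as np

inputs = np.array([[1.0], [2.0]])               # batch of 2 rows
internal = np.array([[10.0], [20.0], [30.0]])   # 3 internal rows

# Each batch row is repeated once per internal row: 1, 1, 1, 2, 2, 2
left = np.repeat(inputs, len(internal), axis=0)
# The internal rows are cycled once per batch row: 10, 20, 30, 10, 20, 30
right = np.tile(internal, (len(inputs), 1))

# Additive kernel applied to every (batch row, internal row) pair
pairs = left + right
print(pairs.ravel())
```

The result has len(inputs) * len(internal) rows, ordered so that all pairs for the first batch row come first.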

Upvotes: 0

Views: 2427

Answers (1)

ClementWalter

Reputation: 5272

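The trouble came from using the static shape, tf.Tensor.shape, instead of the dynamic shape, tf.shape(tensor). Because the variable is declared with an unknown first dimension, self.internal_tensor.shape[0] is None, so tf.tile receives [1, None] and cannot convert it to a tensor. inputs.shape[0] can likewise be None inside a tf.function when the batch size is not fixed.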
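
As a minimal sketch of the difference (the variable `v` and the tensor sizes here are illustrative, not taken from the layer), a variable declared with an unknown first dimension reports None through its static shape, while tf.shape returns the actual size as a tensor that tf.tile accepts:

```python
import tensorflow as tf

# Variable whose first dimension is declared unknown (illustrative sizes).
v = tf.Variable(tf.zeros((30, 10)), shape=tf.TensorShape((None, 10)))

static_rows = v.shape[0]        # declared shape: None for the unknown dimension
dynamic_rows = tf.shape(v)[0]   # scalar int32 tensor computed from the current value

# A list containing a tensor is fine as `multiples`; a list containing None is not.
tiled = tf.tile(tf.ones((2, 10)), [1, dynamic_rows])
print(static_rows, int(dynamic_rows), tiled.shape)
```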

Here is a working solution:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer


class CustomLayer(Layer):

    def __init__(self, internal_tensor, kernel, **kwargs):
        super().__init__(**kwargs)
        self.internal_tensor = tf.Variable(0., shape=tf.TensorShape((None, None)), validate_shape=False, name='internal_tensor')
        self.internal_tensor.assign(internal_tensor)
        self.kernel = kernel

    @tf.function
    def call(self, inputs, **kwargs):
        batch_size = tf.shape(inputs)[0]
        return self.kernel([
            tf.reshape(tf.tile(inputs, [1, tf.shape(self.internal_tensor)[0]]), [-1, inputs.shape[1]]),  # because no tf.repeat
            tf.tile(self.internal_tensor, [batch_size, 1]),
        ])


internal_tensor = np.random.rand(30, 10)
custom_layer = CustomLayer(
    internal_tensor=tf.convert_to_tensor(internal_tensor, tf.float32),
    kernel=lambda inputs: inputs[0] + inputs[1],
)
x = np.random.rand(10, 10).astype(np.float32)
custom_layer(x)

This runs and returns the expected result, though it still emits a warning:

WARNING:tensorflow:Entity <bound method CustomLayer.call of <tensorflow.python.eager.function.TfMethodTarget object at 0x7f8e7e2d8400>> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method CustomLayer.call of <tensorflow.python.eager.function.TfMethodTarget object at 0x7f8e7e2d8400>>: ValueError: Unable to locate the source code of <bound method CustomLayer.call of <tensorflow.python.eager.function.TfMethodTarget object at 0x7f8e7e2d8400>>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
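
The inline comment in the solution notes that tf.repeat was not available at the time. Assuming a TensorFlow release that does provide tf.repeat, the pairing step could probably be written more directly, as sketched below. `all_pairs` is a hypothetical helper introduced for illustration, not part of the layer above, and the small tensors are made-up values:

```python
import tensorflow as tf


def all_pairs(inputs, internal):
    """Return two aligned tensors holding every (input row, internal row) pair."""
    n_internal = tf.shape(internal)[0]   # dynamic size, safe when the static shape is None
    batch_size = tf.shape(inputs)[0]
    left = tf.repeat(inputs, n_internal, axis=0)   # each input row repeated n_internal times
    right = tf.tile(internal, [batch_size, 1])     # internal rows cycled batch_size times
    return left, right


inputs = tf.constant([[1.0], [2.0]])
internal = tf.constant([[10.0], [20.0], [30.0]])
left, right = all_pairs(inputs, internal)
result = left + right   # additive kernel from the example
```

This mirrors the np.repeat / np.tile formulation of the NumPy layer, so a kernel written for one should carry over to the other.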

Upvotes: 1
