HankY

Reputation: 123

batch axis in keras custom layer

I want to make a custom layer that does the following, given a batch of input vectors.
For each vector a in the batch, multiply a by its first component a[0].

So if the batch is

[[ 1.,  2.,  3.],
 [ 4.,  5.,  6.],
 [ 7.,  8.,  9.],
 [10., 11., 12.]]

This should be a batch of 4 vectors, each with dimension 3 (or am I wrong here?). Then my layer should transform the batch to the following:

[[  1.,   2.,   3.],
 [ 16.,  20.,  24.],
 [ 49.,  56.,  63.],
 [100., 110., 120.]]
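In other words, each row is scaled by its own first element. A quick NumPy check of that arithmetic (NumPy is used here only to verify the numbers; the layer itself uses TensorFlow, whose broadcasting follows the same rules):

```python
import numpy as np

batch = np.array([[ 1.,  2.,  3.],
                  [ 4.,  5.,  6.],
                  [ 7.,  8.,  9.],
                  [10., 11., 12.]])

# Scale each row by its own first element: keeping the first column as a
# (4, 1) slice makes it broadcast across each row.
expected = batch * batch[:, :1]
print(expected)
```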

Here is my implementation for the layer:

class MyLayer(keras.layers.Layer):
    def __init__(self, activation=None, **kwargs):            
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)        

    def call(self, a):
        scale = a[0]
        return self.activation(a * scale)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config,
                "activation": keras.activations.serialize(self.activation)}

But the output is different from what I expected:

batch = tf.Variable([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9],
                     [10, 11, 12]], dtype=tf.float32)
layer = MyLayer()
print(layer(batch))

Output:

tf.Tensor(
[[ 1.  4.  9.]
 [ 4. 10. 18.]
 [ 7. 16. 27.]
 [10. 22. 36.]], shape=(4, 3), dtype=float32)
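This output can be reproduced with plain NumPy broadcasting, which hints at what is happening: a[0] is the first row, shape (3,), and multiplying a (4, 3) array by a (3,) vector scales each column rather than each row.

```python
import numpy as np

a = np.array([[ 1.,  2.,  3.],
              [ 4.,  5.,  6.],
              [ 7.,  8.,  9.],
              [10., 11., 12.]])

# a[0] is the first ROW, shape (3,); broadcasting aligns it with the
# last axis, so each COLUMN is multiplied by 1, 2, 3 respectively.
print(a * a[0])
```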

It looks like the implementation actually treats each column as a vector. That is strange to me, because pre-built models such as the Sequential model specify the input shape as (batch_size, ...), which means each row, not each column, is a vector.
How should I modify my code so that it behaves the way I want?

Upvotes: 0

Views: 360

Answers (1)

Kaveh

Reputation: 4970

Actually, your input shape is (4, 3), so slicing the tensor with a[0] gives the first row, [1, 2, 3], which then broadcasts across the columns. To get what you want, take the first column instead, then transpose the matrix so the per-row scales broadcast correctly, and transpose back:

def call(self, a):
    scale = a[:, 0]  # first column: one scale factor per row, shape (batch_size,)
    return tf.transpose(self.activation(tf.transpose(a) * scale))
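Since TensorFlow follows NumPy-style broadcasting, the double transpose can also be avoided by keeping the scale as a column of shape (batch_size, 1), e.g. scale = a[:, :1] and returning self.activation(a * scale). A NumPy sketch showing that both forms give the same result (the TF version is assumed to behave identically):

```python
import numpy as np

a = np.array([[ 1.,  2.,  3.],
              [ 4.,  5.,  6.],
              [ 7.,  8.,  9.],
              [10., 11., 12.]])

scale = a[:, 0]                    # shape (4,): one scale factor per row
via_transpose = (a.T * scale).T    # the double-transpose trick from the answer
via_slice = a * a[:, :1]           # same result: a (4, 1) column broadcasts per row

print(via_transpose)
```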

Upvotes: 1
