Brian

Reputation: 14836

Tensorflow pairwise dot product for batches

I want to train a neural network that takes as input three lists of floats for each element of the batch. For example, an element of the batch looks like vec = [vec_a, vec_b, vec_c] = [[1., 2., 3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]], and I would like the first layer of the network to return the sum of the pairwise dot products of its distinct elements. In this case that is vec_a*vec_b + vec_a*vec_c + vec_b*vec_c (here * denotes the dot product between two vectors).
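For reference, here is a plain-Python sketch of the quantity described above, which can be handy for checking a TensorFlow layer against. The helper names `dot` and `pairwise_dot_product_reference` are hypothetical and not part of the question's code.

```python
from itertools import combinations


def dot(u, v):
    """Ordinary dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def pairwise_dot_product_reference(vectors):
    """Sum of dot products over every distinct (unordered) pair of vectors."""
    return sum(dot(u, v) for u, v in combinations(vectors, 2))


vec = [[1., 2., 3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]]
print(pairwise_dot_product_reference(vec))  # approximately 57.48
```

The three terms are 41.9 (a*b), 5.1 (a*c) and 10.48 (b*c), which add up to 57.48.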

To translate this into a TensorFlow model, I can do the following:

import tensorflow as tf

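def pairwise_dot_product(x):
  # Gram matrix: entry (i, j) is the dot product of rows x[i] and x[j]
  matrix_dot_product = tf.tensordot(x, tf.transpose(x), axes=1)
  # Sums over every axis, so a batch axis (if present) is summed away too
  matrix_sum = tf.math.reduce_sum(matrix_dot_product)
  # The diagonal holds the self dot products x[i] . x[i]
  matrix_diag_sum = tf.linalg.trace(matrix_dot_product)
  # Each distinct pair appears twice off the diagonal, hence the division by 2
  return (matrix_sum - matrix_diag_sum)/2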

model = tf.keras.Sequential()
model.add(tf.keras.layers.Lambda(pairwise_dot_product, input_shape=(None, ), name="pairwise_dot_product"))
model.compile(optimizer="sgd", loss="categorical_crossentropy")
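The idea behind `pairwise_dot_product` is that the Gram matrix G (G[i][j] = x_i . x_j) is symmetric, so the sum of all its entries counts each distinct pair twice plus the diagonal once, which gives (sum(G) - trace(G)) / 2. A pure-Python sketch of the same computation for a single batch element (the function names are hypothetical, chosen only for illustration):

```python
def gram_matrix(vectors):
    """G[i][j] = vectors[i] . vectors[j], i.e. x @ x^T for a list of rows."""
    return [[sum(a * b for a, b in zip(u, v)) for v in vectors] for u in vectors]


def pairwise_dot_product_gram(vectors):
    g = gram_matrix(vectors)
    # Every ordered pair (i, j), diagonal included.
    total = sum(sum(row) for row in g)
    # The self dot products x_i . x_i.
    trace = sum(g[i][i] for i in range(len(g)))
    # Each distinct pair was counted twice in the off-diagonal part.
    return (total - trace) / 2


vec = [[1., 2., 3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]]
print(pairwise_dot_product_gram(vec))  # approximately 57.48
```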

If I evaluate the first layer of the network on vec

model.layers[0].apply(vec)

I do get the correct answer (57.480003). The problem is that I would like to train this model on something like training_data = [vec_1, vec_2, vec_3, ...]. For the sake of simplicity, let's say that training_data = [vec, vec, ...], so I expect the first layer of the network to return [57.480003, 57.480003, ...]. How can I modify the network to do that? I think the problem is that the pairwise_dot_product function I defined is applied to the entire training batch, whereas I would like it to be applied to each element of the batch (vec) separately.
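The desired behaviour is to apply the single-element computation independently to each element of the batch, producing one scalar per element. A pure-Python sketch of that expected output, where `batch` is a hypothetical stand-in for `training_data` and the function names are illustrative only:

```python
from itertools import combinations


def pairwise_dot_product_single(vectors):
    """Sum of x_i . x_j over all distinct pairs i < j of one batch element."""
    return sum(
        sum(a * b for a, b in zip(u, v)) for u, v in combinations(vectors, 2)
    )


def pairwise_dot_product_batch(batch):
    """Apply the per-element function to each element: shape (B, n, d) -> (B,)."""
    return [pairwise_dot_product_single(element) for element in batch]


vec = [[1., 2., 3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]]
batch = [vec] * 4
print(pairwise_dot_product_batch(batch))  # four values, each approximately 57.48
```

In TensorFlow the same per-element mapping could be expressed with `tf.map_fn`, but a version that keeps the batch axis throughout and only reduces over the per-element axes is generally preferable.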

Upvotes: 0

Views: 1012

Answers (1)

Marco Cerliani

Reputation: 22031

Try it this way:

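def pairwise_dot_product(x):
    # x has shape (batch, n, d); x @ x^T per element gives Gram matrices of shape (batch, n, n)
    matrix_dot_product = tf.keras.backend.batch_dot(x, tf.transpose(x, [0,2,1]), axes=[2,1])
    # Reduce only over the last two axes so one value per batch element is kept
    matrix_sum = tf.math.reduce_sum(matrix_dot_product, axis=[1,2])
    # tf.linalg.trace acts on the innermost two dimensions, giving shape (batch,)
    matrix_diag_sum = tf.linalg.trace(matrix_dot_product)
    return (matrix_sum - matrix_diag_sum)/2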

model = tf.keras.Sequential()
model.add(tf.keras.layers.Lambda(pairwise_dot_product, input_shape=(None,None), 
                                 name="pairwise_dot_product"))

vec = [[[1, 2.,3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]]]
vec = tf.constant(vec*10) # repeat vec 10 times along the batch axis -> shape (10, 3, 3)

model(vec)

Result:

<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([57.480003, 57.480003, 57.480003, 57.480003, 57.480003, 57.480003,
       57.480003, 57.480003, 57.480003, 57.480003], dtype=float32)>
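A related identity avoids building the n x n Gram matrix at all: the sum of all Gram entries equals the squared norm of x_1 + ... + x_n, so the pairwise sum is (||sum_i x_i||^2 - sum_i ||x_i||^2) / 2. The pure-Python sketch below checks that identity on a batch; it is an alternative formulation, not the answer's code, and the function name is hypothetical.

```python
def pairwise_dot_product_via_sums(batch):
    """Per batch element: (||sum_i x_i||^2 - sum_i ||x_i||^2) / 2, shape (B, n, d) -> (B,)."""
    results = []
    for element in batch:
        d = len(element[0])
        # Component-wise sum of the n vectors in this element.
        s = [sum(v[k] for v in element) for k in range(d)]
        squared_norm_of_sum = sum(c * c for c in s)
        # Sum of the squared norms of the individual vectors (the Gram diagonal).
        sum_of_squared_norms = sum(sum(c * c for c in v) for v in element)
        results.append((squared_norm_of_sum - sum_of_squared_norms) / 2)
    return results


vec = [[1., 2., 3.], [1.5, 6.7, 9.], [3.4, 0.4, 0.3]]
print(pairwise_dot_product_via_sums([vec] * 10))  # ten values, each approximately 57.48
```

For a tensor x of shape (batch, n, d), this should correspond in TensorFlow to `tf.reduce_sum(tf.square(tf.reduce_sum(x, axis=1)), axis=-1)` minus `tf.reduce_sum(tf.square(x), axis=[1, 2])`, divided by 2.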

Upvotes: 1
