Y.Z.

Reputation: 19

Custom TensorFlow loss function with batch size > 1?

I have a neural network with the following code snippets; note that batch_size == 1 and input_dim == output_dim:

net_in = tf.Variable(tf.zeros(shape=[batch_size, input_dim]), dtype=tf.float32)
input_placeholder = tf.compat.v1.placeholder(shape=[batch_size, input_dim], dtype=tf.float32)
assign_input = net_in.assign(input_placeholder)
# Some matmuls, activations, dropouts, normalizations...
net_out = tf.tanh(output_before_activation)


def loss_fn(output, input):
    # input.shape == output.shape == (batch_size, input_dim)
    output = tf.reshape(output, [input_dim])  # flatten to a 1-D vector; only valid because batch_size == 1
    input = tf.reshape(input, [input_dim])
    return my_fn_that_only_takes_in_vectors(output, input)

# Create session, preprocess data ...

train_op = optimizer.minimize(loss_fn(net_out, net_in))  # build the op once, outside the loop

for epoch in range(epoch_num):
    for batch in range(total_example_num // batch_size):
        sess.run(assign_input, feed_dict={input_placeholder: some_appropriate_numpy_array})
        sess.run(train_op)

Currently the neural network above works fine, but it is very slow because it updates the gradients after every single sample (batch size = 1). I would like to set the batch size > 1, but my_fn_that_only_takes_in_vectors cannot accommodate matrices whose first dimension is larger than 1. Due to the nature of my custom loss, flattening the batch input into a single vector of length batch_size * input_dim does not seem to work either.

How would I write my new custom loss_fn now that the input and output are N x input_dim with N > 1? In Keras this would not have been an issue, because Keras averages the gradients over the examples in the batch. For my TensorFlow function, should I take each row as a vector individually, pass it to my_fn_that_only_takes_in_vectors, and then average the results? (A sketch of that idea follows.)
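For reference, the row-by-row idea could look something like this with tf.map_fn; batched_loss_fn is a hypothetical name, and this assumes my_fn_that_only_takes_in_vectors returns a scalar loss for a pair of 1-D tensors:

# Hypothetical sketch of the row-wise approach, not tested code:
# map the per-example function over the batch dimension, then average.
def batched_loss_fn(output, input):
    # output.shape == input.shape == (batch_size, input_dim)
    per_example = tf.map_fn(
        lambda pair: my_fn_that_only_takes_in_vectors(pair[0], pair[1]),
        (output, input),
        dtype=tf.float32)  # shape: (batch_size,)
    return tf.reduce_mean(per_example)  # average over the batch, like Keras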

Upvotes: 0

Views: 464

Answers (1)

hammockman

Reputation: 37

You can use a function that computes the loss on the whole batch and works independently of the batch size. The operations are simply applied along the whole first dimension of the input (the first dimension indexes the elements of the batch). Here is an example; I hope it helps to see how the operations are carried out:

    def my_loss(y_true, y_pred):
        dx2 = tf.math.squared_difference(y_true[:, 0], y_true[:, 2])  # shape: (BatchSize,)
        dy2 = tf.math.squared_difference(y_true[:, 1], y_true[:, 3])  # shape: (BatchSize,)
        denominator = dx2 + dy2  # shape: (BatchSize,)

        dst_vec = tf.math.squared_difference(y_true, y_pred)  # shape: (BatchSize, n_labels)
        numerator = tf.reduce_sum(dst_vec, axis=-1)  # shape: (BatchSize,)

        # vector containing the loss of each element of the batch
        loss_vector = tf.cast(numerator / denominator, dtype="float32")  # shape: (BatchSize,)

        loss = tf.reduce_sum(loss_vector)  # if you want to sum the losses

        return loss

I am not sure whether you need to return the sum or the average of the losses for the batch. If you sum, make sure to use a validation dataset with the same batch size; otherwise the losses are not comparable.
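If you average instead, the value stays comparable across batch sizes; it is a one-line change to my_loss above:

        loss = tf.reduce_mean(loss_vector)  # average over the batch instead of summing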

Upvotes: 1
