Shouyu Chen

Reputation: 655

My simple loss function causes NaN

I have written a custom loss function, but after several steps the loss becomes NaN. My code is:

import tensorflow as tf

def my_loss(label_batch, logits_batch, alpha=1.3, beta=0.5):
    # Per-pixel class probabilities, shape (?, H, W, 2).
    softmax_logits_batch = tf.nn.softmax(logits_batch, axis=-1)

    indices_not_0 = tf.where(tf.not_equal(label_batch, 0))  # not-zero indices
    indices_0 = tf.where(tf.equal(label_batch, 0))  # zero indices

    # Probabilities of the pixels belonging to each class, shape (?, 2).
    predict_not_0 = tf.gather_nd(softmax_logits_batch, indices_not_0)
    predict_0 = tf.gather_nd(softmax_logits_batch, indices_0)
    # Class centroids, shape (2,).
    avg_p_not_0 = tf.reduce_mean(predict_not_0, axis=0)
    avg_p_0 = tf.reduce_mean(predict_0, axis=0)

    # Hinge: penalize the centroids when they are closer than alpha.
    euclidean_distance = tf.sqrt(tf.reduce_sum(tf.square(avg_p_0 - avg_p_not_0)))
    max_value = tf.maximum(alpha - euclidean_distance, 0)
    return max_value

The basic ideas behind it are:

  1. My loss is for semantic segmentation with only 2 categories.

  2. The shape of label_batch is (?, H, W), and all of its values are 0 or 1. The shape of logits_batch is (?, H, W, 2); its values are the raw logits of the FCN (before softmax).

  3. I want to use indices_0 and indices_not_0 to gather all the predicted (softmax) values, predict_0 and predict_not_0, whose label is 0 or 1 respectively.

  4. The shape of both predict_not_0 and predict_0 should be (?, 2).

  5. Calculate the average of predict_not_0 and predict_0 respectively (these represent the centroids of category 1 and category 0 in Euclidean space). Both should have shape (2,).

  6. Calculate the Euclidean distance between the two centroids; it should be larger than a certain margin alpha (e.g. alpha = 1.3).
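The six steps above can be mirrored in NumPy to check shapes and values on a tiny batch outside the TensorFlow graph. This is a sketch, not the poster's TensorFlow code: `softmax` and `centroid_margin_loss` are hypothetical helper names, and boolean-mask indexing stands in for `tf.where` + `tf.gather_nd`.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

def centroid_margin_loss(labels, logits, alpha=1.3):
    """NumPy mirror of my_loss: labels (N, H, W) in {0, 1}, logits (N, H, W, 2)."""
    probs = softmax(logits, axis=-1)            # (N, H, W, 2)
    predict_not_0 = probs[labels != 0]          # (K1, 2): pixels labelled 1
    predict_0 = probs[labels == 0]              # (K0, 2): pixels labelled 0
    avg_p_not_0 = predict_not_0.mean(axis=0)    # (2,) centroid of class 1
    avg_p_0 = predict_0.mean(axis=0)            # (2,) centroid of class 0
    distance = np.sqrt(np.sum((avg_p_0 - avg_p_not_0) ** 2))
    return max(alpha - distance, 0.0)           # hinge on the centroid distance

# One image with two pixels: the first labelled 0, the second labelled 1.
labels = np.array([[[0, 1]]])
logits = np.array([[[[5.0, -5.0], [-5.0, 5.0]]]])
print(centroid_margin_loss(labels, logits))
```

With confident, correct predictions the centroids are about sqrt(2) apart, which exceeds alpha = 1.3, so the hinge returns 0.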

The problem is that after several steps, the loss becomes NaN.

The output of the code is as follows (I used a very small learning rate):

Epoch[0],step[1],train batch loss = 2.87282,train acc = 0.486435.
Epoch[0],step[2],train batch loss = 2.87282,train acc = 0.485756.
Epoch[0],step[3],train batch loss = 2.87281,train acc = 0.485614.
Epoch[0],step[4],train batch loss = 2.87282,train acc = 0.485649.
Epoch[0],step[5],train batch loss = 2.87282,train acc = 0.485185.
Epoch[0],step[6],train batch loss = 2.87279,train acc = 0.485292.
Epoch[0],step[7],train batch loss = 2.87281,train acc = 0.485222.
Epoch[0],step[8],train batch loss = 2.87282,train acc = 0.484989.
Epoch[0],step[9],train batch loss = 2.87282,train acc = 0.48406.
Epoch[0],step[10],train batch loss = 2.8728,train acc = 0.483306.
Epoch[0],step[11],train batch loss = 2.87281,train acc = 0.483426.
Epoch[0],step[12],train batch loss = 2.8728,train acc = 0.482954.
Epoch[0],step[13],train batch loss = 2.87281,train acc = 0.482535.
Epoch[0],step[14],train batch loss = 2.87281,train acc = 0.482225.
Epoch[0],step[15],train batch loss = 2.87279,train acc = 0.482005.
Epoch[0],step[16],train batch loss = 2.87281,train acc = 0.48182.
Epoch[0],step[17],train batch loss = 2.87282,train acc = 0.48169.
Epoch[0],step[18],train batch loss = 2.8728,train acc = 0.481279.
Epoch[0],step[19],train batch loss = 2.87281,train acc = 0.480878.
Epoch[0],step[20],train batch loss = 2.87281,train acc = 0.480607.
Epoch[0],step[21],train batch loss = 2.87278,train acc = 0.480186.
Epoch[0],step[22],train batch loss = 2.87281,train acc = 0.479925.
Epoch[0],step[23],train batch loss = 2.87282,train acc = 0.479617.
Epoch[0],step[24],train batch loss = 2.87282,train acc = 0.479378.
Epoch[0],step[25],train batch loss = 2.87281,train acc = 0.479496.
Epoch[0],step[26],train batch loss = 2.87281,train acc = 0.479354.
Epoch[0],step[27],train batch loss = 2.87282,train acc = 0.479262.
Epoch[0],step[28],train batch loss = 2.87282,train acc = 0.479308.
Epoch[0],step[29],train batch loss = 2.87282,train acc = 0.479182.
Epoch[0],step[30],train batch loss = 2.22282,train acc = 0.478985.
Epoch[0],step[31],train batch loss = nan,train acc = 0.494112.
Epoch[0],step[32],train batch loss = nan,train acc = 0.508811.
Epoch[0],step[33],train batch loss = nan,train acc = 0.523289.
Epoch[0],step[34],train batch loss = nan,train acc = 0.536233.
Epoch[0],step[35],train batch loss = nan,train acc = 0.548851.
Epoch[0],step[36],train batch loss = nan,train acc = 0.561351.
Epoch[0],step[37],train batch loss = nan,train acc = 0.573149.
Epoch[0],step[38],train batch loss = nan,train acc = 0.584382.
Epoch[0],step[39],train batch loss = nan,train acc = 0.595006.
Epoch[0],step[40],train batch loss = nan,train acc = 0.605065.
Epoch[0],step[41],train batch loss = nan,train acc = 0.614475.
Epoch[0],step[42],train batch loss = nan,train acc = 0.623371.
Epoch[0],step[43],train batch loss = nan,train acc = 0.632092.
Epoch[0],step[44],train batch loss = nan,train acc = 0.640199.
Epoch[0],step[45],train batch loss = nan,train acc = 0.647391.

I used exactly the same code before, except that the loss function was tf.nn.sparse_softmax_cross_entropy_with_logits(), and everything worked, so I suspect something is wrong with my new loss function.

My guess is that some batches contain labels of only one category (only 0s or only 1s), so one of predict_not_0 and predict_0 will be empty. However, I don't know how to check in code whether predict_not_0 and predict_0 contain any data.

Can somebody help me find the problem and tell me how to improve my loss function to avoid NaN?

Upvotes: 2

Views: 287

Answers (2)

P-Gn

Reputation: 24591

This is probably due to the use of tf.sqrt, whose gradient, 1/(2·sqrt(x)), explodes near 0. Therefore, you are progressively hitting more numerical instabilities as you converge.

The solution is to get rid of tf.sqrt. You could, for example, use the squared Euclidean distance instead.

Another potential source of error is tf.reduce_mean, which returns NaN when applied to an empty tensor. You need to decide what you want your loss to be when that happens.
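To see how the NaN arises, note that backpropagation through sqrt(s), with s = sum(d_i**2), multiplies the factor 1/(2·sqrt(s)) by the upstream factor 2·d_i. When the two centroids coincide, s = 0, so that is inf * 0 = NaN. The plain-Python sketch below (hypothetical helper `grad_of_norm`, ordinary floats, no TensorFlow) evaluates exactly that chain rule.

```python
import math

def grad_of_norm(d):
    """Chain-rule gradient of sqrt(sum(d_i**2)) w.r.t. d, computed as backprop would."""
    s = sum(x * x for x in d)
    # d sqrt(s)/ds = 1 / (2 sqrt(s)). Python raises on 1/0.0, so emulate IEEE's
    # 1/0 = inf explicitly for the s == 0 case.
    dsqrt = math.inf if s == 0.0 else 1.0 / (2.0 * math.sqrt(s))
    # ds/dd_i = 2 * d_i; when d is all zeros each component becomes inf * 0 = nan.
    return [dsqrt * 2.0 * x for x in d]

print(grad_of_norm([0.3, -0.4]))  # finite: d / ||d||
print(grad_of_norm([0.0, 0.0]))   # both entries are nan
```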
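A sketch of that suggestion, assuming the hinge form is kept and the margin is squared as well (so the loss is still zero exactly when the distance is at least alpha). It is written in NumPy so it can be checked in isolation; in the TensorFlow version the change would be dropping `tf.sqrt` and comparing against `alpha ** 2`. `squared_hinge_loss_and_grad` is a hypothetical name.

```python
import numpy as np

def squared_hinge_loss_and_grad(avg_p_0, avg_p_not_0, alpha=1.3):
    """Hinge on the squared centroid distance; returns (loss, dloss/d avg_p_0)."""
    diff = avg_p_0 - avg_p_not_0
    sq_dist = np.sum(diff ** 2)            # no sqrt, so d(sq_dist)/d(diff) = 2 * diff
    loss = max(alpha ** 2 - sq_dist, 0.0)
    # Inside the margin the gradient is -2 * diff; outside it the hinge is flat.
    grad = -2.0 * diff if loss > 0.0 else np.zeros_like(diff)
    return loss, grad

same = np.array([0.5, 0.5])
loss, grad = squared_hinge_loss_and_grad(same, same)
print(loss, grad)  # finite loss and a finite (zero) gradient
```

Note that the gradient is exactly zero when the centroids coincide: this removes the NaN, but that point is still a flat spot for the optimizer.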
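One option, sketched in NumPy under the assumption that a batch containing only one class should contribute zero loss: replace the mean with a sum divided by max(count, 1), and multiply the final loss by a 0/1 flag indicating whether both classes are present. Because no NaN is ever created, none can leak into the gradients (selecting with a `where` after a NaN has already been computed can still propagate it through backprop). The sketch also uses the squared distance in place of tf.sqrt. `safe_mean` and `guarded_loss` are hypothetical names; in TensorFlow the same idea can be assembled from `tf.reduce_sum`, `tf.shape`, `tf.cast` and `tf.maximum`.

```python
import numpy as np

def safe_mean(x):
    """Mean over axis 0 that returns zeros instead of NaN when x has no rows."""
    count = max(x.shape[0], 1)             # never divide by zero
    return np.sum(x, axis=0) / count

def guarded_loss(predict_0, predict_not_0, alpha=1.3):
    # 1.0 only if both classes appear in the batch, else 0.0.
    has_both = float(predict_0.shape[0] > 0 and predict_not_0.shape[0] > 0)
    diff = safe_mean(predict_0) - safe_mean(predict_not_0)
    sq_dist = np.sum(diff ** 2)
    return has_both * max(alpha ** 2 - sq_dist, 0.0)

empty = np.empty((0, 2))                   # e.g. a batch with no pixels labelled 0
some = np.array([[0.9, 0.1], [0.7, 0.3]])
print(guarded_loss(empty, some))           # 0.0 instead of nan
```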

Upvotes: 2

Kaihong Zhang

Reputation: 419

In floating-point arithmetic (in TensorFlow as in most programming languages), NaN is produced by operations such as 0.0/0.0, inf - inf or 0 * inf. log(0.0) gives -inf, which can then turn into NaN in later computations. This usually happens with very large or very small values that get rounded to infinity or zero because of limited precision.

tf.nn.softmax is not numerically safe enough for training; try more stable functions instead, such as tf.nn.log_softmax or tf.nn.softmax_cross_entropy_with_logits.
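A quick way to see which operations actually yield NaN, using NumPy float32 scalars to mimic float32 tensors (this shows IEEE-754 behavior in NumPy, assumed to match what a float32 tensor would do). Note that log(0.0) alone gives -inf; it becomes NaN only once combined with 0 or another infinity.

```python
import numpy as np

# Silence the RuntimeWarnings NumPy would otherwise emit for these operations.
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
    zero = np.float32(0.0)
    big = np.float32(3e38)                 # close to the float32 maximum (~3.4e38)
    zero_div = zero / zero                 # 0/0 -> nan
    log_zero = np.log(zero)                # log(0) -> -inf (not NaN on its own)
    inf_times_zero = log_zero * zero       # -inf * 0 -> nan
    overflow = big * np.float32(10.0)      # overflows to inf
    inf_minus_inf = overflow - overflow    # inf - inf -> nan

print(zero_div, log_zero, inf_times_zero, overflow, inf_minus_inf)
```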
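Even when the softmax itself is computed in the usual max-shifted way, taking its log afterwards can fail: in float32 a tiny probability underflows to exactly 0 and its log is -inf. The NumPy sketch below (hypothetical helpers `log_of_softmax` and `log_softmax`, float32 to mirror typical training) contrasts that with computing the log-softmax directly, which is what tf.nn.log_softmax is for.

```python
import numpy as np

def log_of_softmax(logits):
    """log(softmax(x)): the softmax is max-shifted, but tiny probabilities underflow to 0."""
    shifted = logits - np.max(logits)
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    with np.errstate(divide="ignore"):
        return np.log(probs)

def log_softmax(logits):
    """Stable log-softmax: (x - max) - log(sum(exp(x - max)))."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

logits = np.array([0.0, 200.0], dtype=np.float32)
print(log_of_softmax(logits))  # first entry is -inf
print(log_softmax(logits))     # first entry is about -200
```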

Upvotes: 0
