I've been trying to implement logistic regression in TensorFlow following the MNIST example, but with data loaded from a CSV file. Each row is one sample and has 12 features. My code is the following:
```python
import tensorflow as tf  # TensorFlow 1.x graph API

batch_size = 5
learning_rate = .001

# 12 input features, 2 one-hot classes
x = tf.placeholder(tf.float32, [None, 12])
y = tf.placeholder(tf.float32, [None, 2])
W = tf.Variable(tf.zeros([12, 2]))
b = tf.Variable(tf.zeros([2]))

mult = tf.matmul(x, W)
pred = tf.nn.softmax(mult + b)
# cross-entropy computed by hand from the softmax output
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

avg_cost = 0
# Xtrain and ytrain are NumPy arrays loaded from the CSV (loading code not shown)
total_batch = int(len(Xtrain) / batch_size)
for i in range(total_batch):
    batch_xs = Xtrain[i*batch_size:batch_size*i+batch_size]
    batch_ys = ytrain[i*batch_size:batch_size*i+batch_size]
    _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
    print(c)
```
Xtrain is a 252x12 NumPy array, and ytrain is a 252x2 one-hot NumPy array.
The Problem: the cost c is calculated for the first iteration (value 0.6931...), but every iteration after that returns nan.
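As a sanity check, 0.6931 is ln 2, which is exactly what zero-initialized W and b should give for two classes: every logit is 0, softmax assigns 1/2 to each class, and the cross-entropy is -ln(1/2). A minimal NumPy sketch of that arithmetic (the labels below are hypothetical; any one-hot labels give the same result):

```python
import math
import numpy as np

# With W and b all zeros, every logit is 0 for every sample.
logits = np.zeros((5, 2))  # one batch of 5 samples, 2 classes
pred = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # all 0.5

# Hypothetical one-hot labels for the batch
y = np.array([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]], dtype=float)

# Same cross-entropy formula as the TensorFlow graph above
cost = np.mean(-np.sum(y * np.log(pred), axis=1))
print(cost, math.log(2))  # both are ln 2 ≈ 0.6931
```

So the first reported cost is correct; the trouble only starts once the weights have been updated.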
Things I've Tried: I made sure every component of the model was working; the issue only appears after the first iteration. I've tried several learning rates, but that doesn't change anything. I've also tried initializing the weights with tf.truncated_normal (which I shouldn't need to do for logistic regression anyway), but that doesn't help either.
So, any thoughts? I've spent around 3 hours trying to fix this and have run out of ideas. Something seems to go wrong when TensorFlow optimizes the cost function.
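The issue you are having is that log(pred) is not defined for pred = 0. Once the softmax output for one class underflows to exactly 0 in float32, tf.log returns -inf, and multiplying that by a 0 in the one-hot label gives 0 * -inf = nan. Once a nan reaches the gradients, W and b become nan and every later cost is nan as well. The "hacky" way around this is to clamp the probabilities before taking the log, with tf.maximum(pred, 1e-15) or tf.clip_by_value(pred, 1e-15, 1.0).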
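The mechanism can be reproduced outside TensorFlow. The sketch below uses NumPy as a stand-in (np.log for tf.log, np.clip for tf.clip_by_value) and a hypothetical pair of saturated logits; it is an illustration of the failure mode, not the asker's actual data:

```python
import numpy as np

def softmax(z):
    # Row-wise softmax; the output itself is not clipped.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical saturated logits: class 1 is 200 units below class 0.
logits = np.array([[0.0, -200.0]], dtype=np.float32)
y = np.array([[1.0, 0.0]], dtype=np.float32)  # one-hot label for class 0

# exp(-200) underflows to exactly 0 in float32, so pred is [[1.0, 0.0]]
pred = softmax(logits)

with np.errstate(divide="ignore", invalid="ignore"):
    # log(0) = -inf and 0 * -inf = nan, so the summed cost is nan
    naive = np.mean(-np.sum(y * np.log(pred), axis=1))

# Same formula, with probabilities clamped away from 0 before the log
clipped = np.mean(-np.sum(y * np.log(np.clip(pred, 1e-15, 1.0)), axis=1))

print(naive, clipped)  # nan for the naive cost, a finite number after clipping
```

Clamping hides the symptom, but it also zeroes the gradient for the clamped entries, which is part of why it is only a workaround.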
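An even better solution, however, is to compute the loss from the unnormalized logits with tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=mult + b), wrapped in tf.reduce_mean, instead of applying softmax and cross-entropy separately. Note that it expects the logits, not pred: passing the softmax output would apply softmax twice. Because it never takes the log of a probability that has already underflowed to 0, it handles this edge case, and with it your nan problem, automatically.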
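To show why working from logits is stable, here is a NumPy sketch of the standard log-sum-exp form of softmax cross-entropy. It illustrates the idea only; it is not TensorFlow's internal implementation, and the logits and labels are made-up values:

```python
import numpy as np

def softmax_xent_from_logits(logits, labels):
    # Cross-entropy computed directly from unnormalized logits:
    #   loss = logsumexp(z) - sum_k y_k * z_k
    # Subtracting the row max makes the largest exp term exactly 1, so log()
    # is only applied to a sum >= 1, never to an underflowed probability.
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return lse - (labels * logits).sum(axis=1)

# Made-up logits: two saturated rows and one ordinary row
logits = np.array([[0.0, -200.0], [0.0, -200.0], [1.0, 2.0]], dtype=np.float32)
labels = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32)

loss = softmax_xent_from_logits(logits, labels)
print(loss)  # finite for every row, including the saturated ones

# Rough TF 1.x equivalent in the question's graph (not executed here):
#   logits = mult + b
#   cost = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
```

For the saturated row whose label points at the near-zero class, the naive formula would give an infinite loss, while this form returns a large but finite value (200).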
For further reading, I'd recommend this great answer: https://stackoverflow.com/a/34243720/5829427