Reputation: 341
Here is an implementation of the AND function with a single neuron using TensorFlow:
import numpy as np
import tensorflow as tf

def tf_sigmoid(x):
    return 1 / (1 + tf.exp(-x))
data = [
    (0, 0),
    (0, 1),
    (1, 0),
    (1, 1),
]
labels = [
    0,
    0,
    0,
    1,
]
n_steps = 1000
learning_rate = .1
x = tf.placeholder(dtype=tf.float32, shape=[2])
y = tf.placeholder(dtype=tf.float32, shape=None)
w = tf.get_variable('W', shape=[2], initializer=tf.random_normal_initializer(), dtype=tf.float32)
b = tf.get_variable('b', shape=[], initializer=tf.random_normal_initializer(), dtype=tf.float32)
h = tf.reduce_sum(x * w) + b
output = tf_sigmoid(h)
error = tf.abs(output - y)
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for step in range(n_steps):
    for i in np.random.permutation(range(len(data))):
        sess.run(optimizer, feed_dict={x: data[i], y: labels[i]})
Sometimes it works perfectly, but with some initial parameters it gets stuck and doesn't learn. For example, with these initial parameters:
w = tf.Variable(initial_value=[-0.31199348, -0.46391705], dtype=tf.float32)
b = tf.Variable(initial_value=-1.94877, dtype=tf.float32)
it hardly makes any improvement in the cost function. What am I doing wrong? Should I somehow adjust the initialization of the parameters?
Upvotes: 0
Views: 191
Reputation: 6367
Aren't you missing a mean(error)?
Your problem is the particular combination of the sigmoid, the cost function, and the optimizer.
Don't feel bad, AFAIK this exact problem stalled the entire field for a few years.
The sigmoid is flat when you're far from the middle, and you're initializing it with relatively large numbers; try dividing them by 1000.
So your abs-error (or square-error) is flat too, and the GradientDescent optimizer takes steps proportional to the slope.
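To see just how flat it gets, here is a quick NumPy sketch of the sigmoid's derivative at a few illustrative points (numbers chosen for illustration, not taken from your run):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# The sigmoid's derivative is s * (1 - s); it collapses as z moves away from 0:
for z in [0.0, -2.5, -5.0, -10.0]:
    s = sigmoid(z)
    print(z, s * (1 - s))   # 0.25, ~0.07, ~0.0066, ~4.5e-05

Once the pre-activation lands a few units from zero, the gradient that GradientDescent multiplies by is already tiny, so the updates barely move the weights.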
Either of these should fix it:
Use cross-entropy for the error - it's convex.
Use a better optimizer, like Adam, whose step size is much less dependent on the slope and more on its consistency. A sketch of both fixes is below.
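As a rough sketch of what both fixes look like against your snippet (it reuses your x, y, w, b names; the learning rate is just a guess):

# Keep h as the raw logit; sigmoid_cross_entropy_with_logits applies the
# sigmoid internally in a numerically stable way.
h = tf.reduce_sum(x * w) + b
error = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=h)

# Adam scales each step by running estimates of the gradient's first and
# second moments, so a small but consistent slope still makes progress.
optimizer = tf.train.AdamOptimizer(learning_rate=0.1).minimize(error)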
Bonus: Don't roll your own sigmoid, use tf.nn.sigmoid; you'll get a lot fewer NaNs that way.
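For example, for reading out the prediction:

output = tf.nn.sigmoid(h)   # built-in sigmoid, more robust than 1 / (1 + tf.exp(-h))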
Have fun!
Upvotes: 2