xiaolingxiao

Reputation: 4885

TensorFlow learning XOR with a linear function even though it shouldn't

I am implementing a simple network in TensorFlow and, for pedagogical purposes, I am trying to show that the linear model:

yhat = w(Wx + c) + b

cannot learn XOR. However, with my current implementation it apparently does, which suggests a bug in the code. What am I doing wrong?
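As a sanity check on the claim itself, here is a minimal NumPy sketch (not part of my TensorFlow code; the random parameter values are purely illustrative). It shows that the two stacked linear layers collapse into a single affine map, and that the least-squares affine fit to the XOR targets predicts 0.5 for every input:

```python
import numpy as np

# XOR inputs and targets
X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([0., 1., 1., 0.])

# Illustrative random parameters with the same shapes as W, c, w, b
rng = np.random.default_rng(0)
W = rng.uniform(-1, 1, size=(2, 2))
c = rng.uniform(-1, 1, size=2)
w = rng.uniform(-1, 1, size=(2, 1))
b = rng.uniform(-1, 1, size=1)

# (X W + c) w + b is the same function as X (W w) + (c w + b)
stacked = (X @ W + c) @ w + b
collapsed = X @ (W @ w) + (c @ w + b)
assert np.allclose(stacked, collapsed)

# Least-squares fit of a single affine map a1*x1 + a2*x2 + a0 to XOR
A = np.hstack([X, np.ones((4, 1))])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
pred = A @ coef
print(coef)  # weights close to 0, bias close to 0.5
print(pred)  # close to 0.5 for all four inputs
```

Since the whole network is one affine map, adding the second layer without a nonlinearity gives it no extra expressive power.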

import tensorflow as tf

############################################################
'''
    dummy data
'''
x_data = [[0.,0.],[0.,1.],[1.,0.],[1.,1.]]
y_data = [[0],[1],[1],[0]]

############################################################
'''
    Input and output
'''
X = tf.placeholder(tf.float32, shape = [4,2], name = 'x')
Y = tf.placeholder(tf.float32, shape = [4,1], name = 'y')

'''
    Network parameters
'''
W = tf.Variable(tf.random_uniform([2,2],-1,1), name = 'W')
c = tf.Variable(tf.zeros([2])                , name = 'c')
w = tf.Variable(tf.random_uniform([2,1],-1,1), name = 'w')
b = tf.Variable(tf.zeros([1])                , name = 'b')

############################################################
'''
    Network 1:

    function: Yhat = (w (x'W + c) + b)
    loss    : -\sum_i Y * log Yhat
''' 
H1    = tf.matmul(X,  W) + c
Yhat1 = tf.matmul(H1, w) + b


cross_entropy1 = -tf.reduce_sum(
                Y*tf.log(
                        tf.clip_by_value(Yhat1,1e-10,1.0)
                        )
                )

step1 = tf.train.AdamOptimizer(0.01).minimize(cross_entropy1)

'''
    Train
'''

writer = tf.train.SummaryWriter("./logs/xor_logs.graph_def")
graph1 = tf.initialize_all_variables()
sess1  = tf.Session()
sess1.run(graph1)

for i in range(100):
    sess1.run(step1, feed_dict={X: x_data, Y: y_data})


'''
    Evaluation
''' 
corrects = tf.equal(tf.argmax(Y,1), tf.argmax(Yhat1,1))
accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32))
r        = sess1.run(accuracy, feed_dict={X: x_data, Y: y_data})
print ('accuracy: ' + str(r * 100) + '%')

Right now the reported accuracy is 100%, even though a linear model should reach at most 75% on XOR.

Upvotes: 0


Answers (1)

MMN

Reputation: 676

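Y has shape [4,1], so tf.argmax(Y,1) will return [0,0,0,0], and tf.argmax(Yhat1,1) will return the same, because each row has only one column. tf.equal therefore compares zeros with zeros, and the accuracy comes out as 100% whatever the network outputs. This is not what you want. With a single output column, threshold the predictions and compare them with the labels instead.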
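A minimal NumPy sketch of the problem and one possible fix. The values in Yhat are made up for illustration and are not outputs of the network in the question; the 0.5 threshold is likewise an assumption about how you want to read the single output:

```python
import numpy as np

Y = np.array([[0.], [1.], [1.], [0.]])          # labels, shape (4, 1)
Yhat = np.array([[0.3], [0.6], [0.2], [0.7]])   # hypothetical outputs, shape (4, 1)

# argmax over axis 1 of a single-column array is always 0,
# so this "accuracy" is 1.0 no matter what Yhat contains
label_idx = np.argmax(Y, axis=1)
pred_idx = np.argmax(Yhat, axis=1)
broken_accuracy = np.mean(label_idx == pred_idx)

# Threshold the single output and compare with the labels instead
pred_cls = (Yhat > 0.5).astype(np.float32)
accuracy = np.mean(pred_cls == Y)

print(broken_accuracy)  # 1.0
print(accuracy)         # 0.5 for these made-up outputs
```

In the graph from the question, the analogous comparison would be something like tf.equal(tf.cast(Yhat1 > 0.5, tf.float32), Y) in place of the two tf.argmax calls; treat that as a suggestion to adapt rather than a tested drop-in.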

Upvotes: 1
