I am a beginner in TensorFlow and machine learning, and I want to try a simple linear regression example in TensorFlow.
However, the loss stops decreasing after about 3700 epochs, and I don't know what's wrong.
After training I get W = 3.52 and b = 2.8865, so the model is y = 3.52*x + 2.8865. For the test input x = 11 it predicts y = 41.6065. This looks wrong, because the training data already has y = 48.712 at x = 10.
The code and the loss output are posted below.
```python
# Goal: predict the house price in 2017 by linear regression
# Steps: 1. load the original data
#        2. define the placeholders and variables
#        3. define the linear regression model and loss
#        4. launch the graph and train
# (TensorFlow 1.x graph API: tf.placeholder, tf.Session)
from __future__ import print_function
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# 1. load the original data (year index 0..10 -> price)
price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064,
                    16.926, 17.673, 22.271, 26.905, 34.742, 48.712])
year = np.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
n_samples = price.shape[0]

# 2. define the placeholders and variables
x = tf.placeholder("float")
y_ = tf.placeholder("float")
W = tf.Variable(np.random.randn())
b = tf.Variable(np.random.randn())

# 3. linear regression model y = W*x + b
y = tf.add(tf.multiply(x, W), b)
# reduce_mean already averages over the fed samples; the extra
# /(2*n_samples) only rescales the loss (and the effective learning rate)
loss = tf.reduce_mean(tf.square(y - y_)) / (2 * n_samples)
training_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 4. launch the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(10000):
        # one gradient step per training point (per-sample updates)
        for (year_epoch, price_epoch) in zip(year, price):
            sess.run(training_step, feed_dict={x: year_epoch, y_: price_epoch})
        # every 50 epochs, report the loss over the whole data set
        if (epoch + 1) % 50 == 0:
            loss_np = sess.run(loss, feed_dict={x: year, y_: price})
            print("Epoch: ", '%04d' % (epoch + 1), "loss = ", "{:.9f}".format(loss_np),
                  "W = ", sess.run(W), "b = ", sess.run(b))
    # print "Training finish"
    training_loss = sess.run(loss, feed_dict={x: year, y_: price})
    print("Training cost = ", training_loss, "W = ", sess.run(W), "b = ", sess.run(b), '\n')
```
And the loss is:
Epoch: 0050 loss = 1.231071353 W = 3.88227 b = 0.289058
Epoch: 0100 loss = 1.207471132 W = 3.83516 b = 0.630129
Epoch: 0150 loss = 1.189429402 W = 3.79423 b = 0.926415
Epoch: 0200 loss = 1.175611973 W = 3.75868 b = 1.1838
Epoch: 0250 loss = 1.165009260 W = 3.72779 b = 1.40738
Epoch: 0300 loss = 1.156855702 W = 3.70096 b = 1.60161
Epoch: 0350 loss = 1.150570631 W = 3.67766 b = 1.77033
Epoch: 0400 loss = 1.145712137 W = 3.65741 b = 1.9169
Epoch: 0450 loss = 1.141945601 W = 3.63982 b = 2.04422
Epoch: 0500 loss = 1.139016271 W = 3.62455 b = 2.15483
Epoch: 0550 loss = 1.136731029 W = 3.61127 b = 2.25091
Epoch: 0600 loss = 1.134940267 W = 3.59974 b = 2.33437
Epoch: 0650 loss = 1.133531928 W = 3.58973 b = 2.40688
Epoch: 0700 loss = 1.132419944 W = 3.58103 b = 2.46986
Epoch: 0750 loss = 1.131537557 W = 3.57347 b = 2.52458
Epoch: 0800 loss = 1.130834818 W = 3.5669 b = 2.57211
Epoch: 0850 loss = 1.130271792 W = 3.5612 b = 2.6134
Epoch: 0900 loss = 1.129818439 W = 3.55625 b = 2.64927
Epoch: 0950 loss = 1.129452229 W = 3.55194 b = 2.68042
Epoch: 1000 loss = 1.129154325 W = 3.5482 b = 2.70749
Epoch: 1050 loss = 1.128911495 W = 3.54496 b = 2.731
Epoch: 1100 loss = 1.128711581 W = 3.54213 b = 2.75143
Epoch: 1150 loss = 1.128546953 W = 3.53968 b = 2.76917
Epoch: 1200 loss = 1.128411174 W = 3.53755 b = 2.78458
Epoch: 1250 loss = 1.128297567 W = 3.53571 b = 2.79797
Epoch: 1300 loss = 1.128202677 W = 3.5341 b = 2.8096
Epoch: 1350 loss = 1.128123403 W = 3.5327 b = 2.81971
Epoch: 1400 loss = 1.128056765 W = 3.53149 b = 2.82849
Epoch: 1450 loss = 1.128000259 W = 3.53044 b = 2.83611
Epoch: 1500 loss = 1.127952814 W = 3.52952 b = 2.84274
Epoch: 1550 loss = 1.127912283 W = 3.52873 b = 2.84849
Epoch: 1600 loss = 1.127877355 W = 3.52804 b = 2.85349
Epoch: 1650 loss = 1.127847791 W = 3.52744 b = 2.85783
Epoch: 1700 loss = 1.127822518 W = 3.52692 b = 2.8616
Epoch: 1750 loss = 1.127801418 W = 3.52646 b = 2.86488
Epoch: 1800 loss = 1.127782702 W = 3.52607 b = 2.86773
Epoch: 1850 loss = 1.127766728 W = 3.52573 b = 2.8702
Epoch: 1900 loss = 1.127753139 W = 3.52543 b = 2.87234
Epoch: 1950 loss = 1.127740979 W = 3.52517 b = 2.87421
Epoch: 2000 loss = 1.127731323 W = 3.52495 b = 2.87584
Epoch: 2050 loss = 1.127722263 W = 3.52475 b = 2.87725
Epoch: 2100 loss = 1.127714872 W = 3.52459 b = 2.87847
Epoch: 2150 loss = 1.127707958 W = 3.52444 b = 2.87953
Epoch: 2200 loss = 1.127702117 W = 3.52431 b = 2.88045
Epoch: 2250 loss = 1.127697825 W = 3.5242 b = 2.88126
Epoch: 2300 loss = 1.127693415 W = 3.52411 b = 2.88195
Epoch: 2350 loss = 1.127689362 W = 3.52402 b = 2.88255
Epoch: 2400 loss = 1.127686620 W = 3.52395 b = 2.88307
Epoch: 2450 loss = 1.127683759 W = 3.52389 b = 2.88352
Epoch: 2500 loss = 1.127680898 W = 3.52383 b = 2.88391
Epoch: 2550 loss = 1.127679348 W = 3.52379 b = 2.88425
Epoch: 2600 loss = 1.127677798 W = 3.52374 b = 2.88456
Epoch: 2650 loss = 1.127675653 W = 3.52371 b = 2.88483
Epoch: 2700 loss = 1.127674222 W = 3.52368 b = 2.88507
Epoch: 2750 loss = 1.127673268 W = 3.52365 b = 2.88526
Epoch: 2800 loss = 1.127672315 W = 3.52362 b = 2.88543
Epoch: 2850 loss = 1.127671123 W = 3.5236 b = 2.88559
Epoch: 2900 loss = 1.127670288 W = 3.52358 b = 2.88572
Epoch: 2950 loss = 1.127670050 W = 3.52357 b = 2.88583
Epoch: 3000 loss = 1.127669215 W = 3.52356 b = 2.88592
Epoch: 3050 loss = 1.127668500 W = 3.52355 b = 2.88599
Epoch: 3100 loss = 1.127668381 W = 3.52354 b = 2.88606
Epoch: 3150 loss = 1.127667665 W = 3.52353 b = 2.88615
Epoch: 3200 loss = 1.127667546 W = 3.52352 b = 2.88621
Epoch: 3250 loss = 1.127667069 W = 3.52351 b = 2.88626
Epoch: 3300 loss = 1.127666950 W = 3.5235 b = 2.8863
Epoch: 3350 loss = 1.127666354 W = 3.5235 b = 2.88633
Epoch: 3400 loss = 1.127666593 W = 3.5235 b = 2.88637
Epoch: 3450 loss = 1.127666593 W = 3.52349 b = 2.8864
Epoch: 3500 loss = 1.127666235 W = 3.52349 b = 2.88644
Epoch: 3550 loss = 1.127665997 W = 3.52348 b = 2.88646
Epoch: 3600 loss = 1.127665639 W = 3.52348 b = 2.88648
Epoch: 3650 loss = 1.127665639 W = 3.52348 b = 2.88649
Epoch: 3700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 4950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 5950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 6950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 7950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 8950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9000 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9050 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9100 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9150 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9200 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9250 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9300 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9350 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9400 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9450 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9500 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9550 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9600 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9650 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9750 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9800 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9850 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9900 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 9950 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 10000 loss = 1.127665997 W = 3.52348 b = 2.8865
Training cost = 1.12767 W = 3.52348 b = 2.8865
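The plateau in this log can be reproduced outside TensorFlow. The sketch below is a plain-Python re-implementation of the same training loop, written as an illustration under stated assumptions rather than the poster's code. It performs one update per data point in year order, uses loss = mean squared error / (2*n_samples), evaluates the loss at the end of each epoch, and starts the parameters at zero instead of random values. With a constant learning rate, per-sample updates settle into a repeating cycle that sits slightly off the least-squares optimum. A smaller rate shrinks that offset but needs proportionally more epochs; the 40000 used here is a number chosen for illustration. No straight line can get below the loss of the closed-form least-squares fit from `np.polyfit`.

```python
import numpy as np

# Training data from the question (year index 0..10 -> price)
price = [6.757, 12.358, 10.091, 11.618, 14.064,
         16.926, 17.673, 22.271, 26.905, 34.742, 48.712]
year = [float(i) for i in range(11)]
n_samples = len(price)
scale = 2 * n_samples  # the extra /(2*n_samples) divisor used in the question's loss


def full_loss(W, b):
    """mean((W*x + b - y)**2) / (2*n_samples), the quantity printed in the log."""
    return sum((W * x + b - y) ** 2 for x, y in zip(year, price)) / n_samples / scale


def per_sample_sgd(lr, epochs, W=0.0, b=0.0):
    """Plain gradient descent with one update per data point, in year order.
    Starting at W = b = 0 is an assumption; the question uses random values."""
    for _ in range(epochs):
        for x, y in zip(year, price):
            err = W * x + b - y
            # gradient of err**2 / scale with respect to W and b
            W -= lr * 2.0 * err * x / scale
            b -= lr * 2.0 * err / scale
    return W, b


W1, b1 = per_sample_sgd(lr=0.01, epochs=10000)   # settings from the question
W2, b2 = per_sample_sgd(lr=0.001, epochs=40000)  # 10x smaller step, more epochs (illustrative)
W_ols, b_ols = np.polyfit(year, price, 1)        # exact least-squares line

for name, (W, b) in [("lr=0.01", (W1, b1)), ("lr=0.001", (W2, b2)),
                     ("least squares", (W_ols, b_ols))]:
    print("%-14s W=%.4f b=%.4f loss=%.6f" % (name, W, b, full_loss(W, b)))
print("least-squares prediction for x = 11: %.2f" % (W_ols * 11 + b_ols))
```

If it behaves as expected, the lr = 0.01 run should land close to the W ≈ 3.523, b ≈ 2.887, loss ≈ 1.1277 in the log above. The lr = 0.001 run should end slightly lower, and the least-squares line (W ≈ 3.406, b ≈ 3.162) lowest, at about 1.117. Even that optimal line predicts only about 40.6 for x = 11, so the low prediction comes from the straight-line model rather than from the training.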
Your hypothesis that the predicted output lies on a straight line is not correct. Check what the plot of price against year looks like: the prices climb more steeply in the later years, so the points follow an upward curve rather than a line.
So the linear hypothesis you have chosen will do its best to fit a straight line that stays as close as possible to all the input points, which minimizes the cost. When you then test a point outside the training range, the model just extends that best-fit line, and that line cannot follow the curve of the data.
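One way to check the shape of the data without drawing a figure is to look at the signs of the residuals of the best straight line. If the trend really were linear, the residuals would scatter above and below zero with no pattern. Here they are positive at both ends and negative in the middle, which is the signature of upward curvature. This is a minimal NumPy sketch; the closed-form `np.polyfit` line stands in for the TensorFlow training, and the residual-sign check is an illustrative choice, not something from the original answer.

```python
import numpy as np

price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064,
                    16.926, 17.673, 22.271, 26.905, 34.742, 48.712])
year = np.arange(11, dtype=float)

# Best straight line (least squares), then residuals = data - fitted line.
slope, intercept = np.polyfit(year, price, 1)
residuals = price - (slope * year + intercept)

# '+' where the data lies above the line, '-' where it lies below.
signs = "".join("+" if r > 0 else "-" for r in residuals)
print(signs)
```

For this data it prints `+++------++`: the line sits too high in the middle years and too low at both ends, including the last year.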
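You have mentioned two problems:
1. The cost is not going down: try reducing the learning rate. Because the weights are updated one data point at a time, a constant learning rate leaves them oscillating around the minimum. A smaller rate, given enough epochs, lets the cost go down a little further, although it can never drop below the cost of the best straight-line fit.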
2. Your output for year = 11 is wrong: the reason is the one given above. You need to change the hypothesis. Include a square term, for example y = ax^2 + bx + c, and check again. This hypothesis will give you a better fit.
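The suggested quadratic hypothesis can be tried quickly with `np.polyfit` at degree 2. This closed-form least-squares fit is used here in place of TensorFlow training for brevity. In the TensorFlow graph it would mean adding a third variable multiplied by `tf.square(x)`. The sketch compares the residual sum of squares (RSS) of the line and the parabola, and their predictions for x = 11.

```python
import numpy as np

price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064,
                    16.926, 17.673, 22.271, 26.905, 34.742, 48.712])
year = np.arange(11, dtype=float)


def rss(coeffs):
    """Residual sum of squares of a polynomial (highest power first) on the data."""
    return np.sum((np.polyval(coeffs, year) - price) ** 2)


line = np.polyfit(year, price, 1)  # [b, c] for y = bx + c
quad = np.polyfit(year, price, 2)  # [a, b, c] for y = ax^2 + bx + c

print("line:      coeffs=%s  RSS=%.2f  y(11)=%.2f"
      % (np.round(line, 4), rss(line), np.polyval(line, 11)))
print("quadratic: coeffs=%s  RSS=%.2f  y(11)=%.2f"
      % (np.round(quad, 4), rss(quad), np.polyval(quad, 11)))
```

With this data the parabola should cut the RSS from roughly 270 to roughly 65. Its prediction for x = 11 should be about 53, above the last observed price of 48.712, whereas the line gives about 40.6.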