Kukai

Reputation: 25

Error in simple TensorFlow linear regression

I am just getting started with TensorFlow and am trying to write a simple linear regression based on an example I found online.

I was able to get a reasonable answer on the same data when I used sklearn.

With the TensorFlow code below, why is my MSE returning NaN?

import pandas as pd
import tensorflow as tf
import numpy as np

# Create some fake data
size = 1000
performance_x = np.stack((np.random.uniform(24, 40, size), np.random.uniform(80, 240, size), np.random.uniform(80, 100, size), np.random.uniform(15, 25, size)), axis=1)
performance_y = np.sum(np.multiply(performance_x, [0.25, 1, 0.5, 0.75]), axis=1)
performance_y = performance_y + np.stack(np.random.uniform(-10, 10, size))
performance_y = np.reshape(performance_y, (size,1))
n_dim = performance_x.shape[1]


# Testing Tensorflow

learning_rate = 0.001
training_epochs = 1000
cost_history = np.empty(shape=[1], dtype=float)


rnd_indices = np.random.rand(len(performance_x)) < 0.80

train_x = performance_x[rnd_indices]
train_y = performance_y[rnd_indices]

test_x = performance_x[~rnd_indices]
test_y = performance_y[~rnd_indices]


X = tf.placeholder(tf.float32, [None, n_dim])
Y = tf.placeholder(tf.float32, [None, 1])

W = tf.Variable(tf.ones([n_dim, 1]))

init = tf.global_variables_initializer()

y_ = tf.matmul(X, W)

cost = tf.reduce_mean(tf.square(y_ - Y))

training_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

sess = tf.Session()
sess.run(init)

for epoch in range(training_epochs):
  sess.run(training_step, feed_dict={X:train_x,Y:train_y})

pred_y = sess.run(y_, feed_dict={X: test_x})

mse = tf.reduce_mean(tf.square(pred_y - test_y))
print("MSE: %.4f" % sess.run(mse)) 

Upvotes: 0

Views: 149

Answers (1)

Richard_wth

Reputation: 702

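Your features are too large for that learning rate. With inputs as large as 240 and a learning rate of 0.001, plain gradient descent overshoots on every step, so the weights grow until the cost overflows to NaN.

I scaled performance_x by 0.01, i.e.

performance_x = np.stack((np.random.uniform(.24, .40, size), np.random.uniform(.80, 2.40, size), np.random.uniform(.80, 1.00, size), np.random.uniform(.15, .25, size)), axis=1)

and got an MSE of around 0.0038 after 100k steps.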
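To check the effect of the scaling without a TensorFlow session, here is a minimal NumPy sketch of the same full-batch gradient descent on the same kind of fake data. It is an illustration rather than the code I actually ran: the seeded generator and the helper names `gradient_descent` and `stability_bound` are my own assumptions. For a mean-squared-error cost, gradient descent is only stable when the learning rate is below 2 divided by the largest eigenvalue of the Hessian (2/n) XᵀX, so the sketch prints that bound next to the resulting MSE.

```python
import numpy as np

# Assumption: a seeded generator so the run is repeatable (the question uses np.random).
rng = np.random.default_rng(0)
size = 1000
X_raw = np.stack((rng.uniform(24, 40, size), rng.uniform(80, 240, size),
                  rng.uniform(80, 100, size), rng.uniform(15, 25, size)), axis=1)
y = X_raw @ np.array([0.25, 1.0, 0.5, 0.75]) + rng.uniform(-10, 10, size)

learning_rate = 0.001  # same value as in the question
steps = 1000           # same as training_epochs in the question


def stability_bound(X):
    """Largest stable learning rate: 2 / (largest eigenvalue of the MSE Hessian)."""
    hessian = 2.0 / len(X) * X.T @ X
    return 2.0 / np.linalg.eigvalsh(hessian).max()


def gradient_descent(X, y, learning_rate, steps):
    """Full-batch gradient descent on mean((X @ w - y) ** 2), starting from ones."""
    w = np.ones(X.shape[1])
    # Silence overflow warnings: the diverging run is expected to overflow.
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(steps):
            grad = 2.0 / len(y) * X.T @ (X @ w - y)
            w = w - learning_rate * grad
        mse = np.mean((X @ w - y) ** 2)
    return w, mse


bound_raw = stability_bound(X_raw)
bound_scaled = stability_bound(X_raw * 0.01)
_, mse_raw = gradient_descent(X_raw, y, learning_rate, steps)
_, mse_scaled = gradient_descent(X_raw * 0.01, y, learning_rate, steps)

print("raw features:    bound = %.2e, MSE = %s" % (bound_raw, mse_raw))
print("scaled features: bound = %.2e, MSE = %s" % (bound_scaled, mse_scaled))
```

With the raw features the learning rate of 0.001 is well above the bound, so the run diverges and the MSE ends up non-finite. With the features scaled by 0.01 the same learning rate is below the bound and the MSE stays finite, although convergence is slow, which fits the large number of steps I needed. Note that this sketch keeps the original targets and only rescales the inputs, so its MSE is not comparable to the 0.0038 above. Lowering the learning rate below the raw bound, or standardizing each column, are other options along the same lines.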

Upvotes: 1
