luw

Reputation: 217

How to implement coordinate descent using tensorflow?

For example, take a simple linear model y = wx + b, where x and y are the input and output respectively, and w and b are the training parameters. I am wondering how, in every epoch, I can update b first and then update w?

Upvotes: 0

Views: 443

Answers (2)

Yaoshiang

Reputation: 1941

Not really possible. TF's backprop calculates gradients across all variables based on the values of the other variables at the time of the forward pass. If you want to alternate between training w and b, you would unfreeze w and freeze b (set it to trainable=False), run a forward and backward pass, then freeze w and unfreeze b, and run another forward and backward pass. I don't think that would run very fast, since TF isn't really designed to switch the trainable flag on every mini-batch.

Upvotes: 1

TensorFlow might not be the best tool for this. You can do it in plain Python.

And if you need to do the regression with a more complex function, scikit-learn might be a more appropriate library.

Regardless of the tool, you can do Batch Gradient Descent or Stochastic Gradient Descent.

But first you need to define a "cost function". This function basically tells you how far away from the true values your predictions are; a common example is least mean squares (LMS). It takes the prediction from your model and the true value, and its gradient drives the adjustment of the training parameters.

This is the function that is optimized by BGD or SGD in the training process.
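For a model this simple, true coordinate descent can even be done in closed form: minimizing the LMS cost over one parameter while the other is held fixed has an exact solution, and you just alternate between the two. A NumPy sketch of that idea (the synthetic data is an assumption for illustration):

```python
import numpy as np

# synthetic data for y = 3x + 2, assumed for illustration
x = np.linspace(0.0, 1.0, 20)
y = 3.0 * x + 2.0

w, b = 0.0, 0.0
for _ in range(300):
    # exact minimizer of the LMS cost over b, with w fixed
    b = np.mean(y - w * x)
    # exact minimizer of the LMS cost over w, with b fixed
    w = np.sum(x * (y - b)) / np.sum(x * x)

print(w, b)  # converges to 3.0 and 2.0
```

This answers the "update b first, then w" part of the question directly; each sweep strictly decreases the cost, so the alternation converges to the least-squares fit.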

Here is an example I wrote to understand what is happening. It's not an optimal solution, but it will give you an idea of the mechanics.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

tips = sns.load_dataset("tips")

alpha = 0.00005               # learning rate
thetas = np.array([1., 1.])   # [intercept, slope]

def h(thetas, x):
    # hypothesis: h(x) = theta_0 + theta_1 * x
    return thetas[0] + thetas[1] * x[1]

# one pass of stochastic gradient descent over the dataset;
# each parameter is updated in turn, using the latest value of the other
for total_bill, tip in zip(tips.total_bill, tips.tip):
    x = np.array([1, total_bill])
    y = tip
    for index in range(len(thetas)):
        thetas[index] = thetas[index] + alpha * (y - h(thetas, x)) * x[index]

print(thetas)

xplot = np.linspace(min(tips.total_bill), max(tips.total_bill), 100, endpoint=True)
yplot = [h(thetas, [1, xi]) for xi in xplot]

plt.scatter(tips.total_bill, tips.tip)
plt.plot(xplot, yplot, 'o', color='orange')
plt.show()

Upvotes: 1
