Reputation: 293
I'm relatively new to Machine Learning and I decided to delve into some theory and then practice with some code. In the process I got many error messages that I managed to fix but I'm lost with this one. I'm also relatively new to Python so I'm sure this is some syntax-related problem but I couldn't pin it down this time (Python 2.7.15). Here's the complete code:
import numpy as np
from matplotlib import pyplot as plt
# Next we input our data of the form [X, Y, Bias] in a matrix using the NumPy array method:
X = np.array([
    [-2, 4, -1],
    [2, -2, -1],
    [2, 4, -1],
    [8, -4, -1],
    [9, 4, -1],
])
# Let's make another variable Y that contains the output labels for each element in the matrix:
Y = np.array([-1,-1,1,1,1])
#Now let's plot our data. We're going to use a For Loop for this:
for index,element in enumerate(X):
    if index<2:
        plt.scatter(element[0],element[1], marker="_", s=120, color="r")
    else:
        plt.scatter(element[0],element[1], marker="+", s=120, color="b")
plt.plot([-2,8], [8,0.5])
plt.show()
def svm_sgd_plot(X, Y):
    #Initialize our SVM's weight vector with zeros (3 values)
    w = np.zeros(len(X[0]))
    #The learning rate
    eta = 1
    #how many iterations to train for
    epochs = 100000
    #store misclassifications so we can plot how they change over time
    errors = []
    #training part & gradient descent part
    for epoch in range(1,epochs):
        error = 0
        for i, x in enumerate(X):
            #misclassification
            if (Y[i]*np.dot(X[i], w)) < 1:
                #misclassified, update our weights
                w = w + eta * ( (X[i] * Y[i]) + (-2 * (1/epoch) * w) )
                error = 1
            else:
                #correct classification, update our weights
                w = w + eta * (-2 * (1/epoch) * w)
        errors.append(error)
    # let's plot the rate of classification errors during training for our SVM
    plt.plot(errors, '|')
    plt.ylim(0.5,1.5)
    plt.axes().set_yticklabels([])
    plt.xlabel('Epoch')
    plt.ylabel('Misclassified')
    plt.show()
    return w
for d, sample in enumerate(X):
    # Plot the negative samples
    if d < 2:
        plt.scatter(sample[0], sample[1], s=120, marker='_', linewidths=2)
    # Plot the positive samples
    else:
        plt.scatter(sample[0], sample[1], s=120, marker='+', linewidths=2)
# Add our test samples
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()
# Print the hyperplane calculated by svm_sgd()
x2=[ w[0],w[1],-w[1],w[0] ]
x3=[ w[0],w[1],w[1],-w[0] ]
x2x3 = np.array([x2,x3])
X,Y,U,V = zip(*x2x3)
ax = plt.gca()
ax.quiver(X,Y,U,V,scale=1, color='blue')
w = svm_sgd_plot(X,Y)
But I keep getting the following error:
Traceback (most recent call last):
  File "C:\Users...\Support Vector Machine (from scratch).py", line 134, in <module>
    x2=[ w[0],w[1],-w[1],w[0] ]
NameError: name 'w' is not defined
I hope someone more knowledgeable can help. Thanks.
Upvotes: 1
Views: 907
Reputation: 1383
First, you define w inside the function svm_sgd_plot, but a function does nothing until you explicitly call it. The name w therefore does not exist yet when the script reaches the line that uses it. You can call the function by adding the line w = svm_sgd_plot(X,Y) before that point, for example right after plotting your testing data, so your code becomes:
#PLOT TRAINING DATA
for d, sample in enumerate(X):
    # Plot the negative samples
    if d < 2:
        plt.scatter(sample[0], sample[1], s=120, marker='_', linewidths=2)
    # Plot the positive samples
    else:
        plt.scatter(sample[0], sample[1], s=120, marker='+', linewidths=2)
#PLOT TESTING DATA
# Add our test samples
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()
#CALL YOUR METHOD
w = svm_sgd_plot(X,Y)
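The scoping behaviour behind the NameError can be reproduced in a few lines. This is only a minimal sketch, not your script: train is a hypothetical stand-in for svm_sgd_plot, and the numbers are made up for illustration.

```python
def train():
    # w is a local name: it exists only while train() is running
    w = [1.0, 2.0, -1.0]  # made-up values, for illustration only
    return w


def use_w():
    # Mimics the module-level line that reads w[0]: w is not assigned
    # in this function, so Python looks it up as a global name
    try:
        return w[0]
    except NameError as exc:
        return str(exc)


# No module-level w exists yet, so the lookup raises NameError
message = use_w()

# Binding the return value creates the module-level name w
w = train()

# The same lookup now succeeds
first = use_w()
```

Defining the function only creates the name svm_sgd_plot. The name w becomes available at module level once the return value is assigned, which is why the call has to come before the first line that reads w.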
Then you just need to visualize the classification produced by your function. I plotted your two test observations again so that you can see that your SVM classifies them correctly. Notice that the yellow point and the blue point are separated by the line generated by svm_sgd_plot.
# Print the hyperplane calculated by svm_sgd()
x2=[ w[0],w[1],-w[1],w[0] ]
x3=[ w[0],w[1],w[1],-w[0] ]
x2x3 = np.array([x2,x3])
X,Y,U,V = zip(*x2x3)
ax = plt.gca()
ax.quiver(X,Y,U,V,scale=1, color='blue')
#I ADDED THE FOLLOWING THREE LINES SO THAT YOU CAN SEE HOW YOUR TESTING DATA IS BEING CLASSIFIED BY YOUR SVM METHOD
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()
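One detail in the plotting code above is worth noting: the line X,Y,U,V = zip(*x2x3) rebinds X and Y, so the training data and labels are no longer reachable under those names afterwards. That is harmless here only because svm_sgd_plot(X,Y) has already been called by then. A minimal sketch that unpacks into separate names instead, assuming a made-up weight vector (the names qx, qy, qu, qv and the values in w are hypothetical, and plain lists stand in for the NumPy array):

```python
# Hypothetical weight vector standing in for the result of svm_sgd_plot(X, Y)
w = [1.5, 3.0, 11.0]

# Same two arrows as in the answer, one (x, y, u, v) row per arrow
x2 = [w[0], w[1], -w[1], w[0]]
x3 = [w[0], w[1], w[1], -w[0]]

# zip(*rows) transposes the rows into columns; unpacking into fresh names
# leaves any existing X and Y untouched
qx, qy, qu, qv = zip(*[x2, x3])

# qx and qy hold the arrow origins, qu and qv the arrow directions.
# With matplotlib available, the arrows could then be drawn with:
# ax.quiver(qx, qy, qu, qv, scale=1, color='blue')
```

With separate names, the order of the training call and the plotting code no longer depends on X and Y still holding the original data.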
Upvotes: 2