siegfried

Reputation: 451

Plot SVM decision boundary

The following code fits an SVM with a polynomial kernel and plots the iris data together with the decision boundary. The input X uses the first two columns of the data, sepal length and width. However, I am having difficulty reproducing the output with the 3rd and 4th columns as X, i.e. petal length and width. How can I change the plot function so that the code works? Thanks in advance.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC

# Load the iris data; y (iris.target) is used below as the class labels for
# fitting the SVC and for colouring the scatter points.
iris= datasets.load_iris()
y= iris.target
#X= iris.data[:, :2]  # sepal length and width
# Columns 2 and 3 (petal length and width): the variant that triggers the error.
X= iris.data[:, 2:]   # I tried a different X but it failed.

# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Plot svm_mod's predicted classes on a grid spanning X, with the samples on top.

    Reads the module-level names X (first two columns are the plotted
    features), y (labels used to colour the points) and svm_mod (the fitted
    classifier). ``title`` is used as the plot title.
    """
    # create a mesh to plot in
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # NOTE(review): this is the bug discussed in the answer below. The step
    # should be a fraction of the range (x_max - x_min), not the ratio
    # x_max / x_min. With the petal columns x_min works out to 0, so h is inf,
    # np.arange yields fewer than 2 points per axis, and contourf raises
    # "Input z must be at least a 2x2 array". The sepal version only works
    # because x_min is non-zero there, which makes the ratio a finite (if
    # meaningless) step.
    h = (x_max / x_min)/100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Predict a class for every mesh point, then fold back into the mesh shape.
    Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    # NOTE(review): labels are hard-coded for the sepal columns; with
    # X = iris.data[:, 2:] they should read petal length / petal width.
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.xlim(xx.min(), xx.max())
    plt.title(title)
    plt.show()

# Degree-2 polynomial-kernel SVC; plotSVC reads the fitted model via the global svm_mod.
svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
svm_mod= svm.fit(X,y)
plotSVC('kernel='+ str('polynomial'))

The error:

  import sys
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-4515f111e34d> in <module>()
      2 svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
      3 svm_mod= svm.fit(X,y)
----> 4 plotSVC('kernel='+ str('polynomial'))

<ipython-input-56-556d4a22026a> in plotSVC(title)
     10     Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
     11     Z = Z.reshape(xx.shape)
---> 12     plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
     13     plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
     14     plt.xlabel('Sepal length')

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in contourf(*args, **kwargs)
   2931                       mplDeprecation)
   2932     try:
-> 2933         ret = ax.contourf(*args, **kwargs)
   2934     finally:
   2935         ax._hold = washold

~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1853                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1854                         RuntimeWarning, stacklevel=2)
-> 1855             return func(ax, *args, **kwargs)
   1856 
   1857         inner.__doc__ = _add_data_doc(inner.__doc__,

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in contourf(self, *args, **kwargs)
   6179             self.cla()
   6180         kwargs['filled'] = True
-> 6181         contours = mcontour.QuadContourSet(self, *args, **kwargs)
   6182         self.autoscale_view()
   6183         return contours

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in __init__(self, ax, *args, **kwargs)
    844         self._transform = kwargs.pop('transform', None)
    845 
--> 846         kwargs = self._process_args(*args, **kwargs)
    847         self._process_levels()
    848 

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _process_args(self, *args, **kwargs)
   1414                 self._corner_mask = mpl.rcParams['contour.corner_mask']
   1415 
-> 1416             x, y, z = self._contour_args(args, kwargs)
   1417 
   1418             _mask = ma.getmask(z)

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _contour_args(self, args, kwargs)
   1472             args = args[1:]
   1473         elif Nargs <= 4:
-> 1474             x, y, z = self._check_xyz(args[:3], kwargs)
   1475             args = args[3:]
   1476         else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _check_xyz(self, args, kwargs)
   1508             raise TypeError("Input z must be a 2D array.")
   1509         elif z.shape[0] < 2 or z.shape[1] < 2:
-> 1510             raise TypeError("Input z must be at least a 2x2 array.")
   1511         else:
   1512             Ny, Nx = z.shape

TypeError: Input z must be at least a 2x2 array.

Upvotes: 0

Views: 2401

Answers (1)

Gulzar

Reputation: 27896

Working code:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC

# Iris data: X keeps columns 2 and 3 (petal length and width); y holds the
# labels used for fitting and for colouring the scatter points.
iris = datasets.load_iris()
X = iris.data[:, 2:]
y = iris.target
# Use iris.data[:, :2] instead for sepal length and width.

# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title, xlabel='Petal length', ylabel='Petal width'):
    """Plot svm_mod's predicted classes on a grid spanning X, with the samples on top.

    Reads the module-level names X (its first two columns are the plotted
    features), y (labels used to colour the points) and svm_mod (the fitted
    classifier). The grid covers the data range padded by 1 on every side.

    Parameters
    ----------
    title : str
        Plot title.
    xlabel, ylabel : str, optional
        Axis labels. The defaults match ``X = iris.data[:, 2:]`` (petal
        length and width); pass 'Sepal length' / 'Sepal width' when plotting
        ``iris.data[:, :2]`` instead.
    """
    # Grid bounds: data range padded by 1 unit on each side.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # Step is 1/100 of the x *range*. Using x_max / x_min instead yields inf
    # when x_min == 0, collapsing the grid below the 2x2 contourf requires.
    h = (x_max - x_min) / 100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Predict a class for every grid point, then fold back into grid shape.
    z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # Clamp both axes to the grid so the filled regions fill the plot.
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)
    plt.show()

# Degree-2 polynomial-kernel SVC. plotSVC looks the fitted model up through
# the module-level name svm_mod, so that name must be kept.
svm = SVC(kernel='poly', degree=2, coef0=1, C=10, max_iter=500000)
svm_mod = svm.fit(X, y)
plotSVC('kernel=polynomial')

Output:

[plot image not included]


Reason:

For the petal columns `x_min` is 0, so the line `h = (x_max / x_min)/100` divides by zero and sets `h` to `inf`.

It needs to use the range instead: `h = (x_max - x_min)/100`


I found this by reading the exception that stated

TypeError: Input z must be at least a 2x2 array.

Then I worked backward. z's shape comes from xx's shape, which depends on h. h was inf, which made no sense, and from there the fix was easy.

I think you should learn to use the debugger more effectively.

Upvotes: 1

Related Questions