milky way

Reputation: 11

How to visualize the iris dataset in 2D plots for different combinations of features

I want to visualize the iris dataset in 2D for all six combinations of features: (sepal length, sepal width), (sepal length, petal length), (sepal length, petal width), (sepal width, petal length), (sepal width, petal width) and (petal length, petal width). This is what I have so far:

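import time

import matplotlib
from matplotlib import pyplot
from sklearn.datasets import load_iris

matplotlib.rcParams['figure.figsize'] = (9.0, 7.0)

data = load_iris()

# All six unordered pairs of the four feature indices
pairs = [(i, j) for i in range(4) for j in range(i+1, 4)]

fig, subfigs = pyplot.subplots(2, 3, tight_layout=True)
t1 = time.time()

for (f1, f2), subfig in zip(pairs, subfigs.reshape(-1)):
    # TODO: scatter feature f1 against feature f2 on `subfig`,
    # with one colour per class and a legend
    pass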

According to the instructions, we have to generate a 2D plot for each entry in this pair list, two measures at a time with f1 and f2 as the measures, and use class indicators and legend() to make the graph easier to read. I tried different scatter plots, but none of them seem to work.
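For reference, the `pairs` comprehension in the code above yields exactly the six unordered feature pairs, in the same order as `itertools.combinations(range(4), 2)`. A small sketch that maps each index pair to the feature names provided by `load_iris()`:

```python
from itertools import combinations

from sklearn.datasets import load_iris

iris = load_iris()

# The comprehension from the question: every (i, j) with i < j over the 4 features
pairs = [(i, j) for i in range(4) for j in range(i + 1, 4)]

# itertools.combinations yields the same six unordered pairs, in the same order
assert pairs == list(combinations(range(4), 2))

# Print the human-readable name of each pair
for f1, f2 in pairs:
    print(f"{iris.feature_names[f1]}  vs  {iris.feature_names[f2]}")
```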

Upvotes: 0

Views: 1659

Answers (2)

jumi

Reputation: 25

import matplotlib.pyplot as plt
from sklearn import datasets

whole_data = datasets.load_iris()

print('The full description of the dataset:\n', whole_data['DESCR'])

x_axis = whole_data.data[:, 2]  # Petal Length
y_axis = whole_data.data[:, 3]  # Petal Width

# Plotting: colour each point by its integer class label (0, 1, 2)
plt.scatter(x_axis, y_axis, c=whole_data.target)
plt.xlabel(whole_data.feature_names[2])
plt.ylabel(whole_data.feature_names[3])
plt.title("Violet: Setosa, Green: Versicolor, Yellow: Virginica")
plt.show()
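The colour key in the title above works, but since the question asks for `legend()`, one option is to build real legend entries from the scatter's colour mapping. A sketch, assuming matplotlib 3.1 or later, where `PathCollection.legend_elements()` is available:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()

fig, ax = plt.subplots()
# One scatter call; colour is driven by the integer class label (0, 1, 2)
sc = ax.scatter(iris.data[:, 2], iris.data[:, 3], c=iris.target)

# legend_elements() returns one handle per distinct colour value in the mapping
handles, _ = sc.legend_elements()
ax.legend(handles, list(iris.target_names), title="Species")

ax.set_xlabel(iris.feature_names[2])
ax.set_ylabel(iris.feature_names[3])
plt.show()
```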

I chose petal length and petal width for the visualization because they show high class correlation (see the summary statistics in the printed description).
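The "class correlation" mentioned here is the `Class Correlation` column of the summary table printed from `DESCR`. A rough way to reproduce the ranking, assuming it corresponds to the Pearson correlation between each feature and the integer-coded target (this treats the arbitrary label order 0/1/2 as numeric, so it is only a crude hint of separability):

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Pearson correlation of each feature column with the integer class label
class_corr = {
    name: np.corrcoef(iris.data[:, k], iris.target)[0, 1]
    for k, name in enumerate(iris.feature_names)
}

# Print features ordered by absolute correlation, strongest first
for name, r in sorted(class_corr.items(), key=lambda kv: -abs(kv[1])):
    print(f"{name:20s} {r:+.4f}")
```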

Upvotes: 0

sentence

Reputation: 8923

If I understand your goal correctly, the simplest approach is to write a function that plots one pair of features and call it for every possible combination of features.

from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

iris = load_iris()

def plot_iris(f1, f2):
    # One scatter call per class, so each class gets its own colour and legend entry
    for t, name in enumerate(iris.target_names):
        mask = iris.target == t  # select the samples of class t
        plt.scatter(iris.data[mask, f1],
                    iris.data[mask, f2],
                    color=['red', 'green', 'blue'][t],
                    label=name)
    plt.xlabel(iris.feature_names[f1])
    plt.ylabel(iris.feature_names[f2])
    plt.title('Iris Dataset')
    plt.legend(loc='lower right')  # uses the label= set on each scatter
    plt.show()

n_features = len(iris.feature_names)
pairs = [(i, j) for i in range(n_features) for j in range(i+1, n_features)]

for (f1, f2) in pairs:
    plot_iris(f1, f2)

You then get six plots like these:

[Images: two of the six resulting scatter plots]
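The `plot_iris` function above opens a separate figure for each pair. To fill the 2×3 grid from the question's `pyplot.subplots(2, 3)` skeleton instead, the same per-class scatter can be drawn on each `Axes`. A minimal sketch; `plot_pair` is a hypothetical helper name, not part of any library:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
n_features = iris.data.shape[1]
pairs = [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]

def plot_pair(ax, f1, f2):
    """Hypothetical helper: scatter feature f1 against f2 on ax, one colour per class."""
    for t, name in enumerate(iris.target_names):
        mask = iris.target == t  # boolean mask selecting the samples of class t
        ax.scatter(iris.data[mask, f1], iris.data[mask, f2], label=name, s=15)
    ax.set_xlabel(iris.feature_names[f1])
    ax.set_ylabel(iris.feature_names[f2])

fig, axes = plt.subplots(2, 3, figsize=(12, 7), tight_layout=True)
for (f1, f2), ax in zip(pairs, axes.ravel()):
    plot_pair(ax, f1, f2)

# Each Axes restarts the default colour cycle, so the classes share colours
# across subplots and a single legend is enough
axes[0, 0].legend()
plt.show()
```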

Upvotes: 1
