user9238790

Pass single rows from a dataframe to predict with a loop

I currently pass a single row to predict by selecting it with iloc at position n. How can I modify the code to loop over the rows in class_zero instead and print the prediction for each one?

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

# Create a DataFrame from the feature matrix and the labels
df = pd.DataFrame({'Feature 1': X[:, 0],
                   'Feature 2': X[:, 1],
                   'Feature 3': X[:, 2],
                   'Feature 4': X[:, 3],
                   'Feature 5': X[:, 4],
                   'Feature 6': X[:, 5],
                   'Class': y})

y_train = df['Class']
X_train = df.drop('Class', axis=1)
class_zero = df.loc[df['Class'] == 0]

n = 5  # hard-coded position of one row whose Class is 0; instead, I want to
# loop over every row in class_zero and print the prediction for each one

rf = RandomForestClassifier()
rf.fit(X_train, y_train)
instances = X_train.iloc[n].values.reshape(1, -1)

predictValue = rf.predict(instances)
actualValue = y_train.iloc[n]

print('##')
print(n)
print(predictValue)
print(actualValue)
print('##')

Upvotes: 1

Views: 2636

Answers (1)

Vivek Kumar

Reputation: 36609

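You can collect the index labels of the rows where Class == 0 into a list and pass that list to iloc. This works here because df has the default integer index, so labels and positions coincide.

Change class_zero like this: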

class_zero = df.index[df['Class'] == 0].tolist()

The reshape is not needed when you select several rows at once, because the result is already 2-D. Keep it like this:

instances = X_train.iloc[class_zero].values
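
As a rough, self-contained sketch of the batch approach above: predict for every class-0 row in one call and line the predictions up with the actual labels. The smaller n_samples, the fixed random_state on the classifier, and the names mask and results are assumptions made for illustration, not part of the original code. A boolean mask with loc is used here instead of iloc so the selection does not depend on the index being the default one.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Smaller dataset than in the question, just to keep the sketch quick.
X, y = make_classification(n_samples=200, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
feature_names = [f"Feature {i + 1}" for i in range(X.shape[1])]
df = pd.DataFrame(X, columns=feature_names)
df["Class"] = y

X_train = df.drop("Class", axis=1)
y_train = df["Class"]

rf = RandomForestClassifier(random_state=0)
rf.fit(X_train, y_train)

# Boolean mask selects the class-0 rows by label, so it works for any index.
mask = y_train == 0

# One predict call for all selected rows; the input is already 2-D.
batch_predictions = rf.predict(X_train.loc[mask])

# Hypothetical summary table: one line per class-0 row.
results = pd.DataFrame({"predicted": batch_predictions,
                        "actual": y_train[mask].to_numpy()},
                       index=X_train.index[mask])
print(results.head())
```

A single predict call on the whole selection is usually much faster than calling predict once per row, and the results table can still be printed row by row if needed.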

Edit (in response to a comment): to predict and print one row at a time, loop over the list:

for n in class_zero:
    instances = X_train.iloc[n].values.reshape(1,-1)

    predictValue = rf.predict(instances)
    actualValue = y_train.iloc[n]

    print('##')
    print(n)
    print(predictValue)
    print(actualValue)
    print('##')
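
A possible variant of the loop above, sketched under a few assumptions: selecting with double brackets, X_train.iloc[[pos]], returns a one-row DataFrame, so no reshape is needed and the column names stay attached. Recent scikit-learn versions may warn about missing feature names when a model fitted on a DataFrame is given a bare array. The names positions and records, the smaller dataset, and the limit of five rows are illustrative choices only.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
feature_names = [f"Feature {i + 1}" for i in range(X.shape[1])]
df = pd.DataFrame(X, columns=feature_names)
df["Class"] = y

X_train = df.drop("Class", axis=1)
y_train = df["Class"]

rf = RandomForestClassifier(random_state=0)
rf.fit(X_train, y_train)

# Integer positions (not labels) of the class-0 rows, safe to use with iloc.
positions = np.flatnonzero(y_train.to_numpy() == 0)

records = []
for pos in positions[:5]:  # first five rows only, to keep the output short
    row = X_train.iloc[[pos]]        # double brackets keep a 1-row DataFrame
    predicted = rf.predict(row)[0]   # predict returns an array of length 1
    actual = y_train.iloc[pos]
    records.append((int(pos), int(predicted), int(actual)))
    print('##')
    print(pos)
    print(predicted)
    print(actual)
    print('##')
```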

Upvotes: 1
