linuxpanther

Reputation: 69

Confusion Matrix Detailed Explanation of Dataset

import pandas as pd

df = pd.concat(map(pd.read_csv, ['A.csv', 'B1.csv', 'B2.csv', 'B3.csv', 'C1.csv', 'C2.csv', 'C3.csv']), ignore_index=True)

As shown above, I combined multiple datasets to train and test supervised classification models.

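import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

classes = ['Good', 'Bad']

def plot_confusionmatrix(prediction, real, dom):
    print(f'{dom} Confusion matrix')
    # scikit-learn expects (y_true, y_pred): rows = true class, columns = predicted class
    cf = confusion_matrix(real, prediction)
    sns.heatmap(cf, annot=True, yticklabels=classes,
                xticklabels=classes, cmap='Blues', fmt='g')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.show()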
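scikit-learn's `confusion_matrix` takes the true labels first and the predictions second; rows correspond to true classes and columns to predicted ones. Swapping the two arguments silently transposes the matrix. Without `labels=`, the classes are also ordered by sorting, which for the strings 'Bad'/'Good' would not match `classes = ['Good', 'Bad']`. A small sketch, assuming the labels are the strings 'Good' and 'Bad' (if they are encoded as 0/1, pass those values to `labels` instead); the label lists are made up:

```python
from sklearn.metrics import confusion_matrix

# Made-up labels, only to show the orientation of the matrix
y_true = ['Good', 'Good', 'Bad', 'Bad', 'Good']
y_pred = ['Good', 'Bad', 'Bad', 'Bad', 'Good']

# Rows = true class, columns = predicted class, both in the order given by `labels`
cm = confusion_matrix(y_true, y_pred, labels=['Good', 'Bad'])
print(cm.tolist())  # [[2, 1], [0, 2]]

# Swapping the arguments transposes the matrix: rows become predictions
swapped = confusion_matrix(y_pred, y_true, labels=['Good', 'Bad'])
print(swapped.tolist())  # [[2, 0], [1, 2]]
```

Passing the same list to `labels=` and to the heatmap tick labels keeps the annotations and the axis names in sync.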


Instead of showing the confusion matrix for the entire test set, how can I show it only for the samples that come from dataset A? I am interested in seeing how these samples are classified.

Edit: As suggested, I added a column to each dataset recording its source, and then merged them.

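# A, B1, ..., C3 are the DataFrames read from the corresponding CSV files,
# e.g. A = pd.read_csv('A.csv')
A['Dataset'] = "A"
B1['Dataset'] = "B1"
B2['Dataset'] = "B2"
B3['Dataset'] = "B3"
C1['Dataset'] = "C1"
C2['Dataset'] = "C2"
C3['Dataset'] = "C3"
# Stack them; ignore_index=True gives every row a unique index label
df = pd.concat([A, B1, B2, B3, C1, C2, C3], ignore_index=True)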
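Tagging each frame before stacking is what makes a per-dataset filter possible later. A minimal sketch with two tiny in-memory DataFrames standing in for the CSV files (the `feature` column and all values are made up) also shows why `ignore_index=True` matters: without it, each part keeps its own 0-based index, so index labels repeat across datasets.

```python
import pandas as pd

# Two tiny stand-ins for the CSV files (hypothetical data)
A = pd.DataFrame({'feature': [0.1, 0.4], 'Label': ['Good', 'Bad']})
B1 = pd.DataFrame({'feature': [0.7, 0.2, 0.9], 'Label': ['Bad', 'Good', 'Bad']})

# Record the source of every row before merging
A['Dataset'] = 'A'
B1['Dataset'] = 'B1'

# Without ignore_index each part keeps its own 0..n-1 index, so labels repeat
stacked = pd.concat([A, B1])
print(stacked.index.tolist())  # [0, 1, 0, 1, 2]

# ignore_index=True gives every row a unique label 0..n-1
df = pd.concat([A, B1], ignore_index=True)
print(df.index.tolist())  # [0, 1, 2, 3, 4]

# Rows from one source can now be selected with a simple filter
print(df[df['Dataset'] == 'A'])
```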

Solution:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

df_train, df_test = train_test_split(df, test_size=0.33, random_state=42)
X_train = df_train.drop(['Label', 'Dataset'], axis=1)
y_train = df_train['Label']
X_test = df_test.drop(['Label', 'Dataset'], axis=1)
y_test = df_test['Label']

rfc = RandomForestClassifier(n_estimators=100, random_state=0)
rfc.fit(X_train, y_train)
y_pred = rfc.predict(X_test)  # NumPy array, aligned by position with df_test

for dataset_name in ['A']:
    # df_test.index holds the shuffled original row labels, not positions,
    # so select rows with a boolean mask aligned with y_test and y_pred
    is_source = (df_test['Dataset'] == dataset_name).to_numpy()
    dataset_y_test = y_test[is_source]
    dataset_y_pred = y_pred[is_source]

    plot_confusionmatrix(dataset_y_pred, dataset_y_test, dom=f'Test ({dataset_name})')
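Indexing with `df_test.index` (as labels passed to `.iloc`, or as positions into the NumPy prediction array) can pick the wrong rows, because `train_test_split` keeps the original, shuffled index labels. This is easy to check on toy data. In the sketch below the frame, its `feature` column, and the placeholder `y_pred` (a copy of the true labels standing in for `rfc.predict(X_test)`) are all made up; the point is only that a boolean mask built from `df_test` stays aligned by position with both `y_test` and the prediction array.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for the merged data (hypothetical values)
df = pd.DataFrame({
    'feature': np.arange(10, dtype=float),
    'Label': ['Good', 'Bad'] * 5,
    'Dataset': ['A'] * 4 + ['B1'] * 6,
})

df_train, df_test = train_test_split(df, test_size=0.5, random_state=42)
y_test = df_test['Label']
# Placeholder for rfc.predict(X_test): a NumPy array in df_test's row order
y_pred = y_test.to_numpy().copy()

# The index keeps the original, shuffled row labels; in general it is not
# 0..n-1, so it cannot be used as positions for .iloc or for the y_pred array
print(df_test.index.tolist())

# A boolean mask built from df_test lines up position by position
# with both y_test and y_pred
is_a = (df_test['Dataset'] == 'A').to_numpy()
a_y_test = y_test[is_a]
a_y_pred = y_pred[is_a]
print(a_y_test.index.tolist(), list(a_y_pred))
```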

Upvotes: 0

Views: 117

Answers (1)

Serhii Maksymenko

Reputation: 319

It seems there is no other way than plotting a confusion matrix for each subset separately. That should be easy in your case: iterate over all dataset names, select the corresponding rows from y_test_DT_predicted and y_test for each dataset, and pass the selected values to plot_confusionmatrix. Don't forget to manage multiple subplots in plt and to call plt.show() only once, after the loop.
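The approach described above can be sketched as one function that loops over the source names, filters position-aligned arrays with a boolean mask, and draws each matrix on its own subplot before a single `plt.show()`. This is an illustrative sketch, not code from the answer: the name `plot_confusion_by_dataset` is hypothetical, it uses plain matplotlib (`imshow` plus text annotations) instead of `sns.heatmap` to keep dependencies small, and it assumes `y_true`, `y_pred` and `sources` are aligned by position (for example `y_test`, `rfc.predict(X_test)` and `df_test['Dataset']`). The sample labels in the usage lines are made up.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def plot_confusion_by_dataset(y_true, y_pred, sources, classes):
    """Draw one confusion matrix per source dataset on a shared figure.

    y_true, y_pred, sources: position-aligned 1-D sequences.
    classes: label values in the order they should appear on the axes.
    Returns the figure and a dict mapping source name -> matrix.
    """
    y_true, y_pred, sources = map(np.asarray, (y_true, y_pred, sources))
    names = list(dict.fromkeys(sources.tolist()))  # unique names, first-seen order
    fig, axes = plt.subplots(1, len(names), figsize=(4 * len(names), 4), squeeze=False)
    matrices = {}
    for ax, name in zip(axes[0], names):
        mask = sources == name  # rows that come from this dataset
        cm = confusion_matrix(y_true[mask], y_pred[mask], labels=classes)
        matrices[name] = cm
        ax.imshow(cm, cmap='Blues')
        # Write each count in its cell
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, str(cm[i, j]), ha='center', va='center')
        ax.set_xticks(range(len(classes)))
        ax.set_xticklabels(classes)
        ax.set_yticks(range(len(classes)))
        ax.set_yticklabels(classes)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f'Dataset {name}')
    fig.tight_layout()
    return fig, matrices


# Usage with made-up labels
y_true = ['Good', 'Bad', 'Good', 'Bad', 'Good']
y_pred = ['Good', 'Good', 'Good', 'Bad', 'Bad']
sources = ['A', 'A', 'B1', 'B1', 'B1']
fig, matrices = plot_confusion_by_dataset(y_true, y_pred, sources, ['Good', 'Bad'])
plt.show()  # called once, after all subplots are drawn
```

Returning the matrices as well as the figure makes it easy to inspect the counts for dataset A directly, without reading them off the plot.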

Upvotes: 1
