Simplicity

Reputation: 48926

How can I plot the ROC curve from this data?

I have trained a Convolutional Neural Network (CNN) using Keras, and I am doing the following to find the accuracy on a test dataset:

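import os

import cv2
import numpy as np

# `model` (the trained Keras model) and `test_directory` (one subfolder per
# class: 'nevus' and 'melanoma') are assumed to be defined earlier.
correct_classification = 0
number_of_test_images = 0

for root, dirs, files in os.walk(test_directory):
    class_name = os.path.basename(root)
    if class_name == 'nevus':
        label = 1
    elif class_name == 'melanoma':
        label = 0
    else:
        continue  # skip folders that are not a class, e.g. the top-level one
    for file in files:
        img = cv2.imread(os.path.join(root, file))
        if img is None:
            continue  # not a readable image file
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_AREA)
        img = np.expand_dims(img, axis=0) / 255.0  # add batch axis, scale to [0, 1]
        # predict_classes exists on Sequential models in older Keras releases
        prediction = int(np.ravel(model.predict_classes(img))[0])
        if prediction == label:
            correct_classification += 1
        print('This is the prediction: {}'.format(prediction))
        number_of_test_images += 1

print('Correct results: {}'.format(correct_classification))
print('Number of test images: {}'.format(number_of_test_images))
# Accuracy = correct / total, as a percentage (100.0 avoids integer division).
print('Accuracy: {:.2f}%'.format(100.0 * correct_classification / number_of_test_images))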

Is there a way I can find the ROC curve from testing the model on the test data set?

Thanks.

Upvotes: 0

Views: 519

Answers (1)

HakunaMaData

Reputation: 1321

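The ROC curve plots the True Positive Rate (TPR = TP / (TP + FN), also called sensitivity or recall) against the False Positive Rate (FPR = FP / (FP + TN)) at different probability thresholds. So, if this is a binary classification problem, you can vary the probability threshold applied to the predictions on your test dataset and compute the TPR and FPR at each threshold. Essentially, create a table with three columns, [Prob Threshold, TPR, FPR], and plot FPR on the x-axis against TPR on the y-axis. Because predict_classes only returns hard 0/1 labels, which give a single point on the curve, you'll need to use model.predict_proba(...) (or model.predict(...), which returns the same probabilities for a sigmoid or softmax output layer) to get the class probabilities and then use those to build the ROC curve.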
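A minimal NumPy sketch of the threshold sweep described above, assuming `y_true` holds the 0/1 test labels and `y_score` holds the predicted probability of the positive class (for example, collected by calling `model.predict(img)` inside the test loop instead of `predict_classes`, if the network ends in a single sigmoid unit). The helper name `roc_table` and the four-image toy data are hypothetical, for illustration only.

```python
import numpy as np


def roc_table(y_true, y_score):
    """Return rows of [threshold, TPR, FPR] (hypothetical helper).

    A sample counts as predicted positive when its score is >= threshold.
    """
    y_true = np.asarray(y_true, dtype=int)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = np.sum(y_true == 1)
    n_neg = np.sum(y_true == 0)

    # Sweep thresholds from high to low; +inf makes the curve start at (0, 0).
    thresholds = np.concatenate(([np.inf], np.unique(y_score)[::-1]))
    rows = []
    for t in thresholds:
        predicted_pos = y_score >= t
        tp = np.sum(predicted_pos & (y_true == 1))
        fp = np.sum(predicted_pos & (y_true == 0))
        rows.append([t, tp / n_pos, fp / n_neg])
    return np.array(rows)


# Toy example: 4 test images, label 1 = positive class.
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
table = roc_table(y_true, y_score)
for threshold, tpr, fpr in table:
    print('threshold={:.2f}  TPR={:.2f}  FPR={:.2f}'.format(threshold, tpr, fpr))

# Area under the curve with the trapezoidal rule (FPR on the x-axis).
fpr_col, tpr_col = table[:, 2], table[:, 1]
roc_auc = np.sum(np.diff(fpr_col) * (tpr_col[1:] + tpr_col[:-1]) / 2.0)
print('AUC = {:.2f}'.format(roc_auc))
```

To plot it, pass the FPR column as x and the TPR column as y, e.g. `plt.plot(table[:, 2], table[:, 1])` with matplotlib. In practice, scikit-learn's `sklearn.metrics.roc_curve(y_true, y_score)` returns equivalent `fpr`, `tpr` and `thresholds` arrays (by default it drops some collinear points), so the manual loop is mainly useful for seeing what the curve is made of.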

For multi-class it gets a little trickier, but you have a few options. You can plot a ROC curve for each class (the one-vs-rest case), essentially binarizing that class against all the other classes. Alternatively, you could do what scikit-learn does and create a micro-average or a macro-average ROC curve.
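A sketch of the one-vs-rest, micro-average and macro-average options using scikit-learn's `label_binarize`, `roc_curve` and `auc`, along the lines of the multiclass ROC example in the scikit-learn documentation. The three-class labels and probability rows are made up for illustration; with a Keras softmax model, each row would come from `model.predict` on a test image.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize

# Hypothetical 3-class example: true labels and per-class probabilities.
classes = [0, 1, 2]
y_true = np.array([0, 1, 2, 2, 1, 0])
y_score = np.array([
    [0.7, 0.2, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.3, 0.6],
    [0.3, 0.3, 0.4],
    [0.4, 0.4, 0.2],
    [0.5, 0.1, 0.4],
])

# One-vs-rest: column k is 1 where the true class is k, else 0.
y_bin = label_binarize(y_true, classes=classes)

fpr, tpr, roc_auc = {}, {}, {}
for k in classes:
    fpr[k], tpr[k], _ = roc_curve(y_bin[:, k], y_score[:, k])
    roc_auc[k] = auc(fpr[k], tpr[k])

# Micro-average: pool every (sample, class) decision into one binary problem.
fpr['micro'], tpr['micro'], _ = roc_curve(y_bin.ravel(), y_score.ravel())
roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

# Macro-average: interpolate each class's TPR on a shared FPR grid, then average.
all_fpr = np.unique(np.concatenate([fpr[k] for k in classes]))
mean_tpr = np.zeros_like(all_fpr)
for k in classes:
    mean_tpr += np.interp(all_fpr, fpr[k], tpr[k])
mean_tpr /= len(classes)
fpr['macro'], tpr['macro'] = all_fpr, mean_tpr
roc_auc['macro'] = auc(fpr['macro'], tpr['macro'])

for key in classes + ['micro', 'macro']:
    print('ROC AUC ({}): {:.3f}'.format(key, roc_auc[key]))
```

The micro-average weights every sample-class decision equally, so frequent classes dominate it; the macro-average gives each class the same weight, which is often more informative when the classes are imbalanced.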

Upvotes: 2
