Jiageng Zhu

How to show per-class accuracy for every epoch in Keras

I am training a model on the MNIST dataset. During training, I want to show the accuracy of each class at every epoch, not just the accuracy over the whole dataset. How can I do this? Should I write a custom callback? Thanks in advance!

Answers (1)

Jiageng Zhu

Finally figured it out myself xD. A custom callback solves this. Taking the MNIST dataset as an example, suppose I want to show the accuracy on the digit 5 class after every epoch. Do the following:

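import numpy as np
from keras.callbacks import Callback

class TestCallback(Callback):
    """Prints the accuracy on digit 5 of the given data after every epoch."""

    def __init__(self, test_data):
        super().__init__()
        self.test_data = test_data  # tuple (x, y) with one-hot encoded labels y

    def on_epoch_end(self, epoch, logs=None):
        x, y = self.test_data
        pred = self.model.predict(x)
        prediction = np.argmax(pred, axis=1)  # predicted class of each sample
        label = np.argmax(y, axis=1)          # true class of each sample
        is_5 = label == 5
        size_of_5 = int(is_5.sum())
        print("there are %d of 5" % size_of_5)
        if size_of_5 > 0:
            # Fraction of true 5s that the model also predicted as 5
            acc = np.mean(prediction[is_5] == 5)
            print('\n digit 5 accuracy: {}\n'.format(acc))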
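The question asks for every class, not only digit 5. One way to generalize is to factor the computation into a plain NumPy function that returns the accuracy of all classes at once. This is a sketch, not part of the original answer: `per_class_accuracy` is a hypothetical helper name, and it assumes integer class ids (e.g. the output of `np.argmax` over one-hot labels and predictions). As in the digit 5 code, "accuracy" of a class here means the fraction of samples of that class that were predicted correctly.

```python
import numpy as np


def per_class_accuracy(labels, predictions, num_classes):
    """Hypothetical helper: entry c is the fraction of samples whose true
    class is c that were predicted as c (NaN if class c has no samples)."""
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    # Correct predictions per true class, and total samples per true class.
    correct = np.bincount(labels[labels == predictions], minlength=num_classes)
    total = np.bincount(labels, minlength=num_classes)
    acc = np.full(num_classes, np.nan)
    has_samples = total > 0
    acc[has_samples] = correct[has_samples] / total[has_samples]
    return acc


# Small example with three classes and integer class ids.
labels = np.array([0, 0, 1, 1, 1, 2])
predictions = np.array([0, 1, 1, 1, 0, 2])
for c, a in enumerate(per_class_accuracy(labels, predictions, num_classes=3)):
    print("class %d accuracy: %.3f" % (c, a))
# class 0 accuracy: 0.500
# class 1 accuracy: 0.667
# class 2 accuracy: 1.000
```

Inside `on_epoch_end`, after `prediction` and `label` are computed, calling `per_class_accuracy(label, prediction, 10)` would report all ten MNIST digits at once. The callback is then registered through the `callbacks` argument of `fit`, for example `model.fit(x_train, y_train, epochs=5, callbacks=[TestCallback((x_test, y_test))])`.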
