Reputation: 31
I'm training a model on the MNIST dataset. During training, I want to show the accuracy of each class at every epoch, not the accuracy over the whole dataset. What should I do? Change the callback? Thanks in advance!
Upvotes: 2
Views: 897
Reputation: 31
Finally figured it out myself xD. A custom callback solves this. Taking the MNIST dataset as an example, to show the accuracy of the digit 5 class, do the following:
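import numpy as np
from tensorflow.keras.callbacks import Callback  # or: from keras.callbacks import Callback

class TestCallback(Callback):
    def __init__(self, test_data):
        super().__init__()
        self.test_data = test_data  # tuple (x, y) with one-hot encoded y

    def on_epoch_end(self, epoch, logs=None):
        x, y = self.test_data
        pred = self.model.predict(x)
        # Convert softmax outputs and one-hot labels to class indices
        prediction = np.argmax(pred, axis=1)
        label = np.argmax(y, axis=1)
        # Keep only the samples whose true label is 5
        is_5 = label == 5
        size_of_5 = int(is_5.sum())
        print("there are %d of 5" % size_of_5)
        # Digit 5 accuracy = correctly predicted 5s / all true 5s.
        # Counting first avoids float round-off from summing many 1/size_of_5 terms.
        correct = int((prediction[is_5] == 5).sum())
        acc = correct / size_of_5 if size_of_5 else 0.0
        print('\n digit 5 accuracy: {}\n'.format(acc))

# Usage, assuming x_train/y_train/x_test/y_test hold MNIST with one-hot labels:
# model.fit(x_train, y_train, epochs=5, callbacks=[TestCallback((x_test, y_test))])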
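The question asks for every class, not only digit 5. Here is a minimal sketch that computes the accuracy of all classes in one pass with NumPy. It assumes integer class indices, like the `label` and `prediction` arrays produced by the two `np.argmax` calls in `on_epoch_end`. `per_class_accuracy` is a hypothetical helper name, not part of Keras, and the toy arrays at the bottom are made-up data for illustration.

```python
import numpy as np


def per_class_accuracy(label, prediction, num_classes):
    """Return an array whose entry c is the fraction of samples with true
    class c that were predicted as c (this is the same as per-class recall).

    Classes with no samples get NaN so they are not mistaken for 0% accuracy.
    """
    label = np.asarray(label)
    prediction = np.asarray(prediction)
    # Number of samples of each true class
    totals = np.bincount(label, minlength=num_classes)
    # Number of correctly classified samples of each true class
    correct = np.bincount(label[label == prediction], minlength=num_classes)
    acc = np.full(num_classes, np.nan)
    has_samples = totals > 0
    acc[has_samples] = correct[has_samples] / totals[has_samples]
    return acc


# Toy example (made-up labels and predictions, 4 classes)
toy_label = np.array([0, 0, 1, 1, 1, 2])
toy_prediction = np.array([0, 1, 1, 1, 0, 2])
for c, a in enumerate(per_class_accuracy(toy_label, toy_prediction, num_classes=4)):
    print("class %d accuracy: %s" % (c, a))
```

Inside `on_epoch_end`, calling `per_class_accuracy(label, prediction, 10)` after the two `np.argmax` lines would give all ten MNIST digit accuracies for each epoch instead of only digit 5. Using `np.bincount` avoids a Python-level loop over every test sample.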
Upvotes: 1