Faruk

Reputation: 2449

Keras - On train end callback does not get any log

I am working on a project that involves an LSTM network. I did not have a problem like this before, and strangely I did not change anything about this part of the code.

The problem is this: I have a callback called Logger that writes the progress of model training to a file. In the on_train_end method I call another custom function to save the plots of loss, accuracy and perplexity. However, the logs parameter of on_train_end is given an empty dictionary, whereas there is no such problem with on_epoch_end.

def on_train_end(self, logs={}):
    # calculate total time
    self.train_dur = time.time() - self.train_start
    # write time to a file
    self.log.write("\n\nTotal duration: " + str(self.train_dur) + " seconds\n")
    self.log.write("*" * 100 + "\n")
    self.log.close()
    print("train end logs:", logs)
    self.__save_plots(logs)
    return

def on_epoch_end(self, epoch, logs={}):
    # calculate epoch time
    self.epoch_dur = time.time() - self.epoch_start
    # write epoch logs to a file
    print("epoch end logs:", logs)
    epoch_loss_info = "\nloss: {loss} -- val_loss: {val_loss}".format(loss=logs["loss"], val_loss=logs["val_loss"])
    epoch_acc_info = "\nacc: {acc} -- val_acc: {val_acc}".format(acc=logs["acc"], val_acc=logs["val_acc"])
    epoch_ppl_info = "\nppl: {ppl} -- val_ppl: {val_ppl}\n".format(ppl=logs["ppl"], val_ppl=logs["val_ppl"])
    self.log.write("-" * 100 + "\n")
    self.log.write("\n\nEpoch: {epoch} took {dur} seconds \n".format(epoch=epoch + 1, dur=self.epoch_dur))
    self.log.write(epoch_loss_info + epoch_acc_info + epoch_ppl_info)
    # write generations to a file
    generator = model_generator(gen_seq_len=self.gen_seq_len, by_dir=self.model_dir)
    generated_text = generator.generate()
    self.log.write("\nInput text:\t" + generated_text[:self.gen_seq_len] + "\n")
    self.log.write("\nGenerated text:\t" + generated_text + "\n")
    self.log.write("-" * 100 + "\n")
    return

As you can see above, I have a print call in each method; print("epoch end logs:", logs) prints out a dict filled with proper values, but print("train end logs:", logs) prints out an empty dict.

I also tried to get the History object returned by fit_generator and print it out. That also contains proper values.
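For illustration, this is roughly how I check the returned History object (a minimal sketch; model, train_gen, val_gen and the step counts here are placeholders, not my actual code):

history = model.fit_generator(train_gen,
                              steps_per_epoch=100,
                              epochs=10,
                              validation_data=val_gen,
                              validation_steps=10,
                              callbacks=[Logger()])

# history.history is a dict of per-epoch lists, e.g. {'loss': [...], 'val_loss': [...], ...}
print(history.history)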

I have searched GitHub and Stack Overflow but did not find anything like this.

Thanks in advance.

Upvotes: 4

Views: 1394

Answers (1)

Faruk

Reputation: 2449

I have temporarily worked around this problem by creating an empty dictionary with the same keys as the logs variable and appending each epoch's values to it. In on_train_end I then use this self.logs attribute, which I fill after every epoch, instead of the logs parameter.
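The idea looks roughly like this (a simplified sketch, not my full Logger class; the keys are just the ones my model reports, adjust them to your own metrics):

import keras

class Logger(keras.callbacks.Callback):
    def __init__(self):
        super(Logger, self).__init__()
        # accumulate per-epoch logs ourselves, since on_train_end gets an empty dict
        self.logs = {"loss": [], "val_loss": [], "acc": [], "val_acc": []}

    def on_epoch_end(self, epoch, logs={}):
        # copy this epoch's values into our own history
        for key in self.logs:
            self.logs[key].append(logs[key])

    def on_train_end(self, logs={}):
        # use the accumulated self.logs instead of the (empty) logs argument
        print("accumulated logs:", self.logs)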

But feel free to give an answer to this weird problem.

Upvotes: 1
