Lidor shimoni

Reputation: 93

How to make a custom callback in Keras to generate sample images during VAE training?

I'm training a simple VAE model on 64×64 images, and I would like to see the generated images after every epoch (or every couple of batches) to track progress.

Right now I have to wait until training is done before I can look at the results.

I tried to write a custom Keras callback that generates an image and saves it, but couldn't get it to work. Is this even possible? I couldn't find anything like it.

It would be great if you could point me to a source that explains how to do this, or show me an example.

Note: I'm interested in a clean keras.callbacks solution, not in a manual loop that trains one epoch at a time and then generates a sample.

Upvotes: 1

Views: 495

Answers (2)

Heisenberg666

Reputation: 38

If you still need it: you can define a custom callback in Keras by subclassing keras.callbacks.Callback and overriding the hook you need, such as on_epoch_end:

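import matplotlib.pyplot as plt
from tensorflow import keras
class CustomCallback(keras.callbacks.Callback):
    def __init__(self, save_path, VAE, image):
        super().__init__()
        self.save_path = save_path  # format string, e.g. 'samples/epoch_{:03d}.png'
        self.VAE = VAE  # assumed to expose .encoder and .decoder sub-models
        self.image = image  # fixed batch of input images, e.g. shape (1, 64, 64, 3)

    def on_epoch_end(self, epoch, logs=None):
        latent_space = self.VAE.encoder.predict(self.image, verbose=0)
        if isinstance(latent_space, (list, tuple)):  # e.g. encoder returns [z_mean, z_log_var, z]
            latent_space = latent_space[-1]
        reconstructed = self.VAE.decoder.predict(latent_space, verbose=0)
        plt.imshow(reconstructed[0].squeeze())  # squeeze drops a size-1 channel for grayscale
        plt.savefig(self.save_path.format(epoch + 1))
        plt.close()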

Then create an instance, image_callback = CustomCallback(...), and pass it to training via model.fit(..., callbacks=[image_callback]).
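The callback above reconstructs a fixed input image. To *generate* new samples, one option is to decode a fixed batch of latent vectors drawn once from the standard normal prior, so every saved grid shows how the same latent points evolve during training. The sketch below also covers the "every couple of batches" case by hooking on_train_batch_end. It assumes a decoder model that maps (n, latent_dim) to (n, 64, 64, channels) images in [0, 1]; the class name SampleImageCallback, its parameters, and the toy model in the usage part (a plain autoencoder standing in for a real VAE) are illustrative, not part of Keras.

```python
import os

import matplotlib
matplotlib.use("Agg")  # render to files without a display
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras


class SampleImageCallback(keras.callbacks.Callback):
    """Hypothetical callback: decodes fixed prior samples and saves them as one image row."""

    def __init__(self, decoder, latent_dim, out_dir, n_samples=8, every_n_batches=None, seed=0):
        super().__init__()
        self.decoder = decoder
        self.out_dir = out_dir
        self.every_n_batches = every_n_batches
        # Draw z ~ N(0, I) once so every saved grid decodes the same latent points.
        rng = np.random.default_rng(seed)
        self.z = rng.standard_normal((n_samples, latent_dim)).astype("float32")
        self.epoch = 0
        os.makedirs(out_dir, exist_ok=True)

    def _save(self, tag):
        # Call the decoder directly (eager, inference mode) instead of predict().
        images = np.asarray(self.decoder(self.z, training=False))
        fig, axes = plt.subplots(1, len(images), figsize=(2 * len(images), 2), squeeze=False)
        for ax, img in zip(axes[0], images):
            ax.imshow(img.squeeze(), cmap="gray")  # cmap is ignored for RGB images
            ax.axis("off")
        fig.savefig(os.path.join(self.out_dir, f"{tag}.png"))
        plt.close(fig)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch  # remembered so batch-level files can include the epoch

    def on_train_batch_end(self, batch, logs=None):
        # batch is the 0-based index within the current epoch
        if self.every_n_batches and (batch + 1) % self.every_n_batches == 0:
            self._save(f"epoch{self.epoch + 1:03d}_batch{batch + 1:05d}")

    def on_epoch_end(self, epoch, logs=None):
        self._save(f"epoch{epoch + 1:03d}")


# Usage with a toy encoder/decoder on random 64x64 grayscale data (illustration only).
latent_dim = 2
encoder = keras.Sequential([keras.Input((64, 64, 1)), keras.layers.Flatten(),
                            keras.layers.Dense(latent_dim)])
decoder = keras.Sequential([keras.Input((latent_dim,)),
                            keras.layers.Dense(64 * 64, activation="sigmoid"),
                            keras.layers.Reshape((64, 64, 1))])
autoencoder = keras.Sequential([encoder, decoder])
autoencoder.compile(optimizer="adam", loss="mse")

x = np.random.rand(32, 64, 64, 1).astype("float32")
cb = SampleImageCallback(decoder, latent_dim, "samples", n_samples=4, every_n_batches=2)
autoencoder.fit(x, x, batch_size=8, epochs=2, callbacks=[cb], verbose=0)
```

With 32 samples and batch_size=8 there are 4 batches per epoch, so this writes samples/epoch001_batch00002.png, samples/epoch001_batch00004.png and samples/epoch001.png, and the same three for epoch 2. Keeping z fixed is a design choice: random z per call would also work, but makes successive images harder to compare.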

Upvotes: 1

Paul Higazi

Reputation: 197

Yes, it's possible, but I always use matplotlib and my own training loop for that, for example something like this:

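import matplotlib.pyplot as plt

for step in range(epochs):
    # YourDataGenerator, epochs, test_image and test are placeholders from your own code
    inputs, targets = YourDataGenerator()  # load your images for one loop (for a VAE, targets = inputs)
    model.fit(inputs, targets, batch_size=..., epochs=1)  # train for one epoch per loop

    result = model.predict(test_image)
    # Keras returns [batch, height, width, channels]; squeeze() drops a size-1 channel
    plt.imshow(result[0].squeeze())

    filename1 = '/content/runde2/%s_generated_plot_%06d.png' % (test, step + 1)
    plt.savefig(filename1)
    plt.close()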

I think there is also a clean keras.callbacks version, but I have always used this approach because it lets you use other libraries for easier per-loop data augmentation. That's just my opinion, though; I hope this helps at least a bit.
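The plotting step inside that loop can be factored into a small reusable helper that saves a whole batch as one grid, instead of only the first image. A possible sketch, assuming float images in [0, 1] shaped (batch, height, width, channels); the name save_image_grid and its n_cols parameter are illustrative, not from the answer.

```python
import matplotlib
matplotlib.use("Agg")  # render to files without a display
import matplotlib.pyplot as plt
import numpy as np


def save_image_grid(images, filename, n_cols=4):
    """Save a batch shaped (n, height, width, channels) as a single grid image."""
    images = np.asarray(images)
    n = images.shape[0]
    n_rows = int(np.ceil(n / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide ticks; unused cells stay blank
        if i < n:
            # squeeze() drops a size-1 channel; clip guards against values slightly outside [0, 1]
            ax.imshow(np.clip(images[i].squeeze(), 0.0, 1.0), cmap="gray")
    fig.savefig(filename)
    plt.close(fig)


# Usage: random data stands in for result = model.predict(test_image)
batch = np.random.rand(6, 64, 64, 1)
save_image_grid(batch, "grid_demo.png", n_cols=3)
```

Inside the loop, the imshow/savefig/close lines would then become a single call such as save_image_grid(result, filename1).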

Upvotes: 1
