mrgloom

Reputation: 21632

How to use model in batch generator?

I want to use model.predict inside a batch generator. What are possible ways to achieve this?

It seems one option is to load the model in __init__ and in on_epoch_end:

class DataGenerator(keras.utils.Sequence):
    def __init__(self, model_name):
        # Load model

    # ...

    def on_epoch_end(self):
        # Load model

Upvotes: 0

Views: 513

Answers (1)

Daniel Möller

Reputation: 86600

In my experience, calling predict on another model while training brings errors.

You should probably simply chain your training model after your generator model.

Suppose you have:

generator_model (the one you want to use inside the generator)    
training_model (the one you want to train)

Then

from keras.layers import Input
from keras.models import Model

generatorInput = Input(shapeOfTheGeneratorInput)
generatorOutput = generator_model(generatorInput)
trainingOutput = training_model(generatorOutput)

entireModel = Model(generatorInput, trainingOutput)

Make sure that all layers of the generator model are untrainable before compiling:

genModel = entireModel.layers[1]  # the generator_model inside entireModel
for l in genModel.layers:
    l.trainable = False

entireModel.compile(optimizer=optimizer, loss=loss)

Now use the generator regularly.


Predicting inside the generator:

from keras.models import load_model

class DataGenerator(keras.utils.Sequence):

    def __init__(self, model_name, modelInputs, batchSize):
        self.genModel = load_model(model_name)
        self.inputs = modelInputs
        self.batchSize = batchSize

    def __len__(self):
        # number of batches, rounding up for a partial last batch
        l, rem = divmod(len(self.inputs), self.batchSize)
        return l + (1 if rem > 0 else 0)

    def __getitem__(self, i):
        items = self.inputs[i*self.batchSize:(i+1)*self.batchSize]
        items = doThingsWithItems(items)

        predItems = self.genModel.predict_on_batch(items)

        # the following is the only reason not to chain models
        predItems = doMoreThingsWithItems(predItems)

        # do something to get y_train_items as well

        return predItems, y_train_items
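The batching arithmetic in __len__ and __getitem__ is independent of Keras, so it can be sketched and checked on its own (the data here is a hypothetical stand-in):

```python
# Sketch of the batch-count and slicing logic from the Sequence above,
# using plain Python lists instead of model inputs.
def num_batches(n_items, batch_size):
    l, rem = divmod(n_items, batch_size)
    return l + (1 if rem > 0 else 0)

inputs = list(range(10))
batch_size = 3

n = num_batches(len(inputs), batch_size)       # 4: three full batches + one partial
batches = [inputs[i*batch_size:(i+1)*batch_size] for i in range(n)]
# batches -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```

The rounding-up in num_batches is what guarantees the last partial batch is not silently dropped.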

If you do hit the error I mentioned, you can sacrifice the parallel generation capabilities and run a manual loop:

for e in range(epochs):
    for i in range(len(generator)):
        x, y = generator[i]
        model.train_on_batch(x, y)
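One detail the manual loop drops is that Keras calls on_epoch_end between epochs (e.g. for shuffling), so it should be invoked by hand. A runnable sketch with a hypothetical stub generator standing in for the Sequence and model:

```python
# Stub Sequence-like generator (hypothetical), to show where
# on_epoch_end fits in the manual loop.
class StubGenerator:
    def __init__(self, data, batch_size):
        self.data, self.batch_size = data, batch_size
        self.epochs_seen = 0

    def __len__(self):
        l, rem = divmod(len(self.data), self.batch_size)
        return l + (1 if rem > 0 else 0)

    def __getitem__(self, i):
        batch = self.data[i*self.batch_size:(i+1)*self.batch_size]
        return batch, [v * 2 for v in batch]   # fake (x, y) pair

    def on_epoch_end(self):
        self.epochs_seen += 1                  # e.g. reshuffle indices here

gen = StubGenerator(list(range(10)), 3)
seen = []
for e in range(2):                   # epochs
    for i in range(len(gen)):        # batches
        x, y = gen[i]
        seen.append((x, y))          # stand-in for model.train_on_batch(x, y)
    gen.on_epoch_end()               # fit would call this between epochs
```

After two epochs the stub has seen 4 batches per epoch (10 items, batch size 3) and on_epoch_end has run twice.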

Upvotes: 1
