I want to use model.predict inside a batch generator. What are the possible ways to achieve this?
One option seems to be loading the model in __init__ and again in on_epoch_end:
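from tensorflow import keras
class DataGenerator(keras.utils.Sequence):
    def __init__(self, model_name):
        # Load model
        self.model_name = model_name
        self.model = keras.models.load_model(model_name)
        # ...
    def on_epoch_end(self):
        # Load model again (e.g. to pick up weights saved during the last epoch)
        self.model = keras.models.load_model(self.model_name)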
In my experience, calling predict on another model while training can cause errors.
Instead, you should probably chain the training model after the generator model so that the two form a single model.
Suppose you have:
generator_model (the one you want to use inside the generator)
training_model (the one you want to train)
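Then build the combined model:
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
generatorInput = Input(shapeOfTheGeneratorInput)  # shape of one sample, without the batch dimension
generatorOutput = generator_model(generatorInput)
trainingOutput = training_model(generatorOutput)
entireModel = Model(generatorInput, trainingOutput)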
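Make sure that all layers of the generator model are untrainable (frozen) before compiling, so that only the training model's weights are updated:
genModel = entireModel.layers[1]  # index 0 is the InputLayer, index 1 is generator_model
for l in genModel.layers:
    l.trainable = False
entireModel.compile(optimizer=optimizer, loss=loss)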
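To sanity-check the chain-and-freeze idea, here is a minimal sketch in which two tiny Dense models stand in for generator_model and training_model. The shapes, the random data, and the "adam"/"mse" settings are assumptions for illustration only. It fetches the nested model by name with get_layer instead of by index, then checks that fitting the combined model leaves the generator's weights untouched.

```python
import numpy as np
from tensorflow import keras

# Hypothetical stand-ins for the two models described above.
generator_model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(3, activation="relu")],
    name="generator_model",
)
training_model = keras.Sequential(
    [keras.Input(shape=(3,)), keras.layers.Dense(1)],
    name="training_model",
)

# Chain: raw input -> generator_model -> training_model, as one Model.
generatorInput = keras.Input(shape=(4,))
generatorOutput = generator_model(generatorInput)
trainingOutput = training_model(generatorOutput)
entireModel = keras.Model(generatorInput, trainingOutput)

# Freeze every layer of the nested generator model before compiling.
genModel = entireModel.get_layer("generator_model")
for l in genModel.layers:
    l.trainable = False
entireModel.compile(optimizer="adam", loss="mse")

# Train briefly on random data and check that the frozen weights did not move.
before = [w.copy() for w in generator_model.get_weights()]
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
entireModel.fit(x, y, batch_size=8, epochs=2, verbose=0)
after = generator_model.get_weights()
frozen_unchanged = all(np.array_equal(b, a) for b, a in zip(before, after))
print(frozen_unchanged)                    # True
print(len(entireModel.trainable_weights))  # 2: kernel + bias of training_model's Dense
```

Setting generator_model.trainable = False also freezes the whole nested model in one line; the per-layer loop simply mirrors the answer.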
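With the chained model, a regular data generator that yields the raw inputs and targets is enough. If you cannot chain the models, for example because the predictions must be processed before they reach the training model, you can call the generator model inside a Sequence instead: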
from tensorflow import keras
from tensorflow.keras.models import load_model
class DataGenerator(keras.utils.Sequence):
    def __init__(self, model_name, modelInputs, batchSize):
        super().__init__()
        self.genModel = load_model(model_name)
        self.inputs = modelInputs
        self.batchSize = batchSize
    def __len__(self):
        # number of batches, including a final partial batch
        l, rem = divmod(len(self.inputs), self.batchSize)
        return l + (1 if rem > 0 else 0)
    def __getitem__(self, i):
        items = self.inputs[i*self.batchSize:(i+1)*self.batchSize]
        items = doThingsWithItems(items)  # your own preprocessing
        predItems = self.genModel.predict_on_batch(items)
        # the following is the only reason not to chain models
        predItems = doMoreThingsWithItems(predItems)
        # do something to get y_train_items as well
        return predItems, y_train_items
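The __len__ and __getitem__ pair in this generator relies on ceiling division, so a final, smaller batch is not dropped. The framework-free check below illustrates that arithmetic. The helper names num_batches and batch_slices are hypothetical and are not Keras APIs.

```python
def num_batches(n_items, batch_size):
    # Same arithmetic as __len__: full batches, plus one partial batch if items remain.
    l, rem = divmod(n_items, batch_size)
    return l + (1 if rem > 0 else 0)

def batch_slices(items, batch_size):
    # Same slicing as __getitem__, applied to every index i in range(len(generator)).
    return [items[i * batch_size:(i + 1) * batch_size]
            for i in range(num_batches(len(items), batch_size))]

items = list(range(10))
batches = batch_slices(items, 4)
print(num_batches(10, 4))  # 3
print(batches)             # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The divmod expression is equivalent to math.ceil(n / batch_size), or to -(-n // batch_size) in pure integer arithmetic.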
If you do run into the error I mentioned, you can give up parallel batch generation and write a manual training loop instead:
for e in range(epochs):
    for i in range(len(generator)):
        x, y = generator[i]
        model.train_on_batch(x, y)
    generator.on_epoch_end()  # fit() would normally call this for you
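A self-contained sketch of this manual loop follows, assuming small in-memory models instead of a model loaded from a file. The DataGenerator variant here is simplified and hypothetical: it receives the generator model directly and takes a `targets` array to supply the labels, and the shapes and random data are placeholders.

```python
import numpy as np
from tensorflow import keras

# Hypothetical small models standing in for generator_model and training_model.
generator_model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(3)])
training_model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
training_model.compile(optimizer="adam", loss="mse")

class DataGenerator(keras.utils.Sequence):
    # Simplified variant: takes an in-memory model and a hypothetical
    # `targets` array instead of loading the model from a file.
    def __init__(self, genModel, modelInputs, targets, batchSize):
        super().__init__()
        self.genModel = genModel
        self.inputs = modelInputs
        self.targets = targets
        self.batchSize = batchSize

    def __len__(self):
        l, rem = divmod(len(self.inputs), self.batchSize)
        return l + (1 if rem > 0 else 0)

    def __getitem__(self, i):
        sl = slice(i * self.batchSize, (i + 1) * self.batchSize)
        predItems = self.genModel.predict_on_batch(self.inputs[sl])
        return np.asarray(predItems), self.targets[sl]

x_raw = np.random.rand(20, 4).astype("float32")
y_raw = np.random.rand(20, 1).astype("float32")
generator = DataGenerator(generator_model, x_raw, y_raw, batchSize=8)

epochs = 2
steps = 0
for e in range(epochs):
    for i in range(len(generator)):  # 3 batches per epoch: 8 + 8 + 4 samples
        x, y = generator[i]
        training_model.train_on_batch(x, y)
        steps += 1
    generator.on_epoch_end()  # fit() would call this for you
```

Batches are now produced in the main thread, one at a time, so this is the parallel generation the answer says you give up. Whether it avoids the specific error depends on your Keras/TensorFlow version.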