Reputation: 687
I am trying to train a keras_retinanet model with the code below, and training works fine. For fit_generator I created a CSVGenerator data generator, which inherits from the "Generator" base class; that class has a "batch_size" parameter that defaults to 1.
I would like to change the value of "batch_size", but I cannot figure out how to do that. Any help is much appreciated.
```python
import keras
import keras_retinanet.losses
from keras_retinanet.models import load_model
from keras_retinanet.preprocessing.csv_generator import CSVGenerator

# Load a snapshot saved by an earlier training run.
model = load_model('./snapshots/resnet50_csv_01.h5',
                   backbone_name='resnet50')

# Training and validation generators (batch_size is left at its default of 1).
generator = CSVGenerator(
    csv_data_file='./data_set_retina/train.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)
generator_val = CSVGenerator(
    csv_data_file='./data_set_retina/val.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)

model.compile(
    loss={
        'regression': keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal()
    },
    optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)
)

csv_logger = keras.callbacks.CSVLogger('./logs/training_log.csv',
                                       separator=',', append=False)

model.fit_generator(generator, steps_per_epoch=80000, epochs=50,
                    verbose=1, callbacks=[csv_logger],
                    validation_data=generator_val, validation_steps=20000,
                    class_weight=None, max_queue_size=10, workers=1,
                    use_multiprocessing=False, shuffle=True, initial_epoch=0)
```
Upvotes: 0
Views: 783
Reputation: 1310
I suppose you're talking about the keras-retinanet repository.
You can find the batch size here:
This variable is then passed to the generators in the common_args dictionary.
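The pattern described above (a single batch size collected into a shared keyword-argument dictionary and unpacked into each generator's constructor) can be sketched roughly as follows. DummyGenerator, create_generators and parse_args are hypothetical stand-ins written for illustration, not keras-retinanet's actual code, and the --batch-size flag name is likewise an assumption here; the sketch runs without the library installed.

```python
import argparse


class DummyGenerator:
    """Hypothetical stand-in for a data generator with a batch_size option."""

    def __init__(self, data_file, batch_size=1):
        self.data_file = data_file
        self.batch_size = batch_size


def create_generators(args):
    # Options shared by the training and validation generators,
    # collected once and unpacked into each constructor.
    common_args = {
        'batch_size': args.batch_size,
    }
    train_generator = DummyGenerator(args.train_file, **common_args)
    val_generator = DummyGenerator(args.val_file, **common_args)
    return train_generator, val_generator


def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('train_file')
    parser.add_argument('val_file')
    # Hypothetical flag; argparse stores '--batch-size' as args.batch_size.
    parser.add_argument('--batch-size', type=int, default=1)
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args(['train.csv', 'val.csv', '--batch-size', '8'])
    train_gen, val_gen = create_generators(args)
    print(train_gen.batch_size, val_gen.batch_size)  # 8 8
```

Collecting the shared options in one dictionary keeps the training and validation generators consistent: changing the batch size in one place updates both.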
In fact, you can also pass a batch_size argument directly when instantiating your CSVGenerator. Following your code snippet:
```python
generator = CSVGenerator(
    csv_data_file='./data_set_retina/train.csv',
    csv_class_file='./data_set_retina/class_id_mapping',
    batch_size=16
)
```
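This presumably works because CSVGenerator passes the keyword arguments it does not handle itself on to its Generator base class. The sketch below imitates that with two hypothetical classes (BaseGenerator and CSVLikeGenerator are illustrative stand-ins, not the library's code). It also shows that the number of batches needed to cover the data becomes ceil(number of images / batch_size).

```python
import math


class BaseGenerator:
    """Hypothetical stand-in for a base generator (batch_size defaults to 1)."""

    def __init__(self, batch_size=1, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle

    def size(self):
        raise NotImplementedError

    def __len__(self):
        # Number of batches needed to cover every sample once.
        return math.ceil(self.size() / self.batch_size)


class CSVLikeGenerator(BaseGenerator):
    """Hypothetical subclass: handles its own arguments, forwards the rest."""

    def __init__(self, rows, **kwargs):
        self.rows = list(rows)
        # Keyword arguments such as batch_size are passed to the base class.
        super().__init__(**kwargs)

    def size(self):
        return len(self.rows)


rows = ['img_%d.jpg' % i for i in range(100)]
default_gen = CSVLikeGenerator(rows)
batched_gen = CSVLikeGenerator(rows, batch_size=16)
print(default_gen.batch_size, len(default_gen))  # 1 100
print(batched_gen.batch_size, len(batched_gen))  # 16 7
```

Note that steps_per_epoch in fit_generator counts batches, not images. If you keep steps_per_epoch=80000 after switching to batch_size=16, each epoch will process 16 times as many images as before.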
Upvotes: 1