Rajat

Reputation: 687

How to change batch-size in keras retinanet training

I am training a keras_retinanet model with the code below, and training itself works fine. The data generator I pass to fit_generator is a CSVGenerator, which inherits from the "Generator" base class. That base class has a "batch_size" parameter that defaults to 1.

I would like to change the value of "batch_size", but I can't figure out how to do that. Any help is much appreciated.

import keras
import keras_retinanet.losses
from keras_retinanet.models import load_model
from keras_retinanet.preprocessing.csv_generator import CSVGenerator

# Load a previously saved training snapshot
model = load_model('./snapshots/resnet50_csv_01.h5', backbone_name='resnet50')

# Training data generator (batch_size is left at its default of 1)
generator = CSVGenerator(
    csv_data_file='./data_set_retina/train.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)

# Validation data generator
generator_val = CSVGenerator(
    csv_data_file='./data_set_retina/val.csv',
    csv_class_file='./data_set_retina/class_id_mapping'
)

model.compile(
    loss={
        'regression'    : keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal()
    },
    optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)
)

# Write per-epoch metrics to a CSV file
csv_logger = keras.callbacks.CSVLogger('./logs/training_log.csv', separator=',', append=False)

model.fit_generator(
    generator,
    steps_per_epoch=80000,
    epochs=50,
    verbose=1,
    callbacks=[csv_logger],
    validation_data=generator_val,
    validation_steps=20000,
    class_weight=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    shuffle=True,
    initial_epoch=0
)

Upvotes: 0

Views: 783

Answers (1)

jgorostegui

Reputation: 1310

I assume you're referring to the keras-retinanet repository.

In the repository's training script, the batch size is defined here:

https://github.com/fizyr/keras-retinanet/blob/b28c358c71026d7a5bcb1f4d928241a693d95230/keras_retinanet/bin/train.py#L395

That value is then passed to the generators through the common_args dictionary.
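To illustrate the mechanism, here is a minimal sketch of how a keyword dictionary like common_args can reach the base class. ToyGenerator and ToyCSVGenerator are hypothetical stand-ins written for this example, not the library's actual classes. The only details taken from this thread are that the base class has a batch_size parameter defaulting to 1 and that the CSV generator inherits from it.

```python
class ToyGenerator:
    """Hypothetical stand-in for the base generator class."""

    def __init__(self, batch_size=1):
        # Default of 1 mirrors what the question describes
        self.batch_size = batch_size


class ToyCSVGenerator(ToyGenerator):
    """Hypothetical stand-in for a CSV-backed generator."""

    def __init__(self, csv_data_file, csv_class_file, **kwargs):
        self.csv_data_file = csv_data_file
        self.csv_class_file = csv_class_file
        # Any extra keyword arguments are forwarded to the base class
        super().__init__(**kwargs)


# Shared settings collected in one dictionary, then unpacked into each generator
common_args = {'batch_size': 16}

train_generator = ToyCSVGenerator('train.csv', 'class_id_mapping', **common_args)
default_generator = ToyCSVGenerator('train.csv', 'class_id_mapping')

print(train_generator.batch_size)    # 16
print(default_generator.batch_size)  # 1
```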

You can also pass batch_size directly when you instantiate your CSVGenerator. Based on your code snippet:

generator = CSVGenerator(
    csv_data_file='./data_set_retina/train.csv',
    csv_class_file='./data_set_retina/class_id_mapping',
    batch_size=16
)
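One thing to keep in mind after raising the batch size: steps_per_epoch and validation_steps in fit_generator count batches, not images, so the values from the question will probably need to shrink. Below is a rough sketch, assuming the question's 80000 and 20000 were chosen as the number of training and validation images at batch size 1. That is my guess, since the question doesn't say. steps_for is a hypothetical helper name used only for this example.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


batch_size = 16

# Assumed dataset sizes, inferred from the step counts in the question
train_steps = steps_for(80000, batch_size)
val_steps = steps_for(20000, batch_size)

print(train_steps)  # 5000
print(val_steps)    # 1250
```

You would then pass these as steps_per_epoch and validation_steps so that one epoch still corresponds to roughly one pass over the data.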

Upvotes: 1
