tinyMind

Reputation: 155

How to batch train a CNN with Keras fit_generator?

Apologies if this is the wrong place to raise my issue (please help me out with where best to raise it if that's the case). I'm a novice with Keras and Python so hope responses have that in mind.

I'm trying to train a CNN steering model that takes images as input. It's a fairly large dataset, so I created a data generator to use with fit_generator(). It's not clear to me how to make this method train on batches, so I assumed the generator has to return batches to fit_generator(). The generator looks like this:

import csv
import cv2
import numpy as np

def gen(file_name, batchsz = 64):
    csvfile = open(file_name)
    reader = csv.reader(csvfile)
    batchCount = 0
    while True:
        for line in reader:
            inputs = []
            targets = []
            temp_image = cv2.imread(line[1]) # line[1] is path to image
            measurement = line[3] # steering angle
            inputs.append(temp_image)
            targets.append(measurement)
            batchCount += 1
            if batchCount >= batchsz:
                batchCount = 0
                X = np.array(inputs)
                y = np.array(targets)
                yield X, y
        csvfile.seek(0)

It reads a csv file containing telemetry data (steering angle etc.) and paths to image samples, and yields arrays of size batchsz. The call to fit_generator() looks like this:

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator
model.fit_generator(
    tgen,
    samples_per_epoch=113526,
    nb_epoch=6,
    validation_data=vgen,
    nb_val_samples=20001
)

The dataset contains 113526 sample points yet the model training update output reads like this (for example):

  1020/113526 [..............................] - ETA: 27737s - loss: 0.0080
  1021/113526 [..............................] - ETA: 27723s - loss: 0.0080
  1022/113526 [..............................] - ETA: 27709s - loss: 0.0080
  1023/113526 [..............................] - ETA: 27696s - loss: 0.0080

This appears to be training sample by sample (stochastically?). The resultant model is useless. I previously trained on a much smaller dataset using .fit() with the whole dataset loaded into memory, and that produced a model that at least works, even if poorly. Clearly something is wrong with my fit_generator() approach. I'd be very grateful for some help with this.

Upvotes: 1

Views: 2934

Answers (1)

Daniel Möller

Reputation: 86600

This:

for line in reader:
    inputs = []
    targets = []

... is resetting your batch for every line in the csv file. You're not training with your entire data, but with just one sample out of every 128.

Suggestion:

for line in reader:

    if batchCount == 0:
        inputs = []
        targets = []  
    ....
    ....
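Putting the suggestion together, a complete corrected generator might look like the sketch below. The load_image hook is an assumption added here so the batching logic can be exercised without OpenCV installed; in the original script it would simply be cv2.imread.

```python
import csv
import numpy as np

def gen(file_name, batchsz=64, load_image=None):
    # load_image is a hook added for this sketch; in the real script
    # it would be cv2.imread.
    if load_image is None:
        import cv2
        load_image = cv2.imread
    csvfile = open(file_name)
    reader = csv.reader(csvfile)
    inputs, targets = [], []  # created once, NOT once per csv line
    while True:
        for line in reader:
            inputs.append(load_image(line[1]))  # line[1] is path to image
            targets.append(float(line[3]))      # line[3] is steering angle
            if len(inputs) >= batchsz:
                yield np.array(inputs), np.array(targets)
                inputs, targets = [], []        # reset only after a full batch
        csvfile.seek(0)  # rewind so the generator can run forever
```

Tracking the batch with len(inputs) also removes the need for the separate batchCount variable.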

As someone commented, in fit_generator, samples_per_epoch should be equal to total_samples / batchsz.
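For the counts in the question, the division the comment suggests works out like this (a sketch of the arithmetic; note that in Keras 2 the argument is renamed steps_per_epoch and explicitly counts batches):

```python
# Counts taken from the question; batchsz = 128 as in the calls shown.
train_samples = 113526
val_samples = 20001
batchsz = 128

steps_per_epoch = train_samples // batchsz  # 886 full batches per epoch
val_steps = val_samples // batchsz          # 156 full validation batches
print(steps_per_epoch, val_steps)           # 886 156
```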

Even so, I think your loss should be going down anyway. If it isn't, there might still be another problem in the code, perhaps in the way you load the data, or in the model's initialization or structure.

Try plotting your images and printing the data in the generator:

for X, y in tgen: # careful, this is an infinite loop; make it stop

    print(X.shape[0]) # is this really the batch size you expect?

    for image in X:
        ... # some method to plot the image, or just print it

    print(y)

Check if the yielded values are ok with what you expect them to be.
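Since the generator never terminates, one way to make the inspection loop stop by itself is to slice it with itertools.islice. A sketch, where peek is a hypothetical helper name:

```python
from itertools import islice

def peek(generator, n_batches=3):
    # Pull a fixed number of batches from an otherwise infinite generator,
    # printing a summary of each and returning the shapes for inspection.
    seen = []
    for X, y in islice(generator, n_batches):
        print("batch size:", X.shape[0], "first targets:", y[:5])
        seen.append((X.shape, y.shape))
    return seen
```

Calling peek(tgen) would then print summaries of the first three batches and return their shapes, without any risk of looping forever.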

Upvotes: 2

Related Questions