Mpizos Dimitris

Reputation: 4991

How to: fit_generator in keras

I am a little confused about how to use fit_generator in Keras.

For example, using fit we simply do:

x, y = load_data()
model.fit(x=x, y=y, batch_size=512, epochs=10)

where load_data loads all the data.

Now, how do I do the same with fit_generator?

It's not clear to me how the data is processed when using fit_generator. Suppose I have the following generator:

def data_generator():
    for x, y in load_data_per_line():
        yield x, y

The generator above yields one data point at a time. And:

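def data_generator_2(batch_size=512):
    x_output = []
    y_output = []
    i = 0
    for x, y in load_data_per_line():
        # append grows the list; x_output[i] = x would raise IndexError on an empty list
        x_output.append(x)
        y_output.append(y)
        i = i + 1
        if i == batch_size:
            yield x_output, y_output
            # start collecting the next batch
            i = 0
            x_output = []
            y_output = []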

The generator above yields batch_size data points at a time (512 in this case).

To achieve the same as fit but using fit_generator:

model.fit_generator(data_generator(), steps_per_epoch=10000 / 512, epochs=10)

or

model.fit_generator(data_generator_2(), steps_per_epoch=10000 / 512, epochs=10)

Or are both wrong (the fit_generator calls and the data generators)? If either of them is correct, does it guarantee that all data points will be processed, and that they will be processed sequentially?

Any insight would be useful.

Upvotes: 2

Views: 951

Answers (1)

Daniel Möller

Reputation: 86600

Generator 2 is almost OK, but it should return numpy arrays:

yield np.asarray(x_output),np.asarray(y_output)

Also, it should be infinite:

while True: 

    #the code inside to loop infinitely

The first generator yields single data points rather than batches, so it will fail.
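Putting these points together, a minimal sketch of a batched, infinite generator could look like the following. The load_data_per_line stub and the batch_generator name are hypothetical stand-ins for the question's loader and generator, and emitting the leftover samples as a smaller final batch at the end of each pass is an assumption of this sketch, not something Keras requires.

```python
import numpy as np


def load_data_per_line():
    # Hypothetical stand-in for the question's loader:
    # yields one (x, y) sample at a time.
    for i in range(10):
        yield [float(i), float(i) * 2.0], i % 2


def batch_generator(batch_size):
    # Loop forever: Keras keeps pulling batches across epochs,
    # so the generator must never run out.
    while True:
        x_batch, y_batch = [], []
        for x, y in load_data_per_line():
            x_batch.append(x)
            y_batch.append(y)
            if len(x_batch) == batch_size:
                # Yield numpy arrays, not Python lists.
                yield np.asarray(x_batch), np.asarray(y_batch)
                x_batch, y_batch = [], []
        if x_batch:
            # Leftover samples form a smaller final batch for this pass.
            yield np.asarray(x_batch), np.asarray(y_batch)


gen = batch_generator(batch_size=4)
first_x, first_y = next(gen)
print(first_x.shape, first_y.shape)
```

With 10 samples and a batch size of 4, one pass over the data takes three batches (4, 4, and 2 samples), after which the generator starts over from the first sample.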

You will probably have a problem with steps_per_epoch, because 10000 is not a multiple of 512 and you need an integer number of steps. Inside the generator you can check whether all 10000 data points have been read and yield a smaller batch as the last batch.

Then you've got (10000 // 512) + 1 = 20 steps or batches: 19 full batches plus one last batch with the remaining 10000 % 512 = 272 data points.
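The step count can be checked with a few lines of arithmetic. This is only a sketch using the numbers from the question; the variable names are illustrative.

```python
import math

num_samples = 10000
batch_size = 512

full_batches = num_samples // batch_size     # batches with exactly 512 samples
last_batch_size = num_samples % batch_size   # samples left for the final batch

# One extra step is needed whenever there is a remainder,
# which is the same as rounding the division up.
steps_per_epoch = math.ceil(num_samples / batch_size)

print(full_batches, last_batch_size, steps_per_epoch)
```

The resulting steps_per_epoch is an integer, so it can be passed to fit_generator directly instead of the float 10000 / 512.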

All batches will be read in sequence. Keras shuffles by default, so pass shuffle=False to be safe (in fit_generator this argument affects the order of the batches and is only used with keras Sequence inputs). If you use multithreading (not the case here), then you need to create thread-safe generators or use a keras Sequence.
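For the multithreaded case, one common way to make a generator thread-safe is to wrap it in an iterator that takes a lock around every next() call. The sketch below uses only the standard library; ThreadSafeIterator and count_up are hypothetical names, and this is one possible approach, not part of the Keras API.

```python
import threading


class ThreadSafeIterator:
    """Wraps any iterable so that only one thread advances it at a time."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # The lock serializes access; without it, two threads calling
        # next() on the same generator at once raise ValueError.
        with self._lock:
            return next(self._it)


def count_up(n):
    # Hypothetical stand-in for a data generator.
    for i in range(n):
        yield i


safe = ThreadSafeIterator(count_up(1000))
results = []


def worker():
    for item in safe:
        results.append(item)


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(results))
```

Each item is handed to exactly one worker, but the order in which workers receive items is no longer deterministic, which is worth keeping in mind if sequential processing matters.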

Upvotes: 2
