Sri Charan Kattamuru

Reputation: 69

What is steps_per_epoch in model.fit_generator actually doing?

After reading the Keras documentation for the steps_per_epoch argument of the model.fit_generator method, my understanding is as follows:

Suppose a dataset contains N samples and the generator function passed to Keras returns B = batch_size samples per call (here, a "call" means a single yield from the generator function). With steps_per_epoch = ceil(N/B), the generator is called steps_per_epoch times per epoch, so the full dataset passes through the model once per epoch. The same process repeats for every epoch until training completes.
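To make the arithmetic concrete, here is a minimal sketch in plain Python, with no Keras involved. The helper names batches and steps_per_epoch are hypothetical and exist only for this illustration. It shows that drawing ceil(N/B) batches from a cycling generator covers every sample once when B divides N.

```python
import math

def batches(samples, batch_size):
    """Yield consecutive batches forever, wrapping around at the end of the data."""
    i = 0
    while True:
        yield [samples[(i + k) % len(samples)] for k in range(batch_size)]
        i = (i + batch_size) % len(samples)

def steps_per_epoch(n_samples, batch_size):
    """Number of generator calls needed to cover n_samples once."""
    return math.ceil(n_samples / batch_size)

samples = list(range(10))                 # N = 10
steps = steps_per_epoch(len(samples), 2)  # B = 2

gen = batches(samples, 2)
seen = []
for _ in range(steps):                    # one "epoch" worth of generator calls
    seen.extend(next(gen))

print(steps, sorted(seen))
```

When B does not divide N, the last batch in this sketch wraps around and repeats a few samples from the start of the data.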

To test whether my understanding was correct, I implemented the following:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

index = 0

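def get_values(inputs, targets):
    # Yield (input, target) pairs one at a time, cycling over the data forever.
    i = 0
    while True:
        yield inputs[i], targets[i]
        i += 1
        if i >= len(inputs):
            i = 0  # wrap around to the start of the data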


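def get_batch(inputs, targets, batch_size=2):
    # Group samples into batches. `index` is incremented each time execution
    # resumes after a yield, i.e. each time a further batch is requested.
    global index
    batch_X = []
    batch_Y = []
    for inp, targ in get_values(inputs, targets):
        batch_X.append(inp)
        batch_Y.append(targ)

        if len(batch_X) >= batch_size:
            yield np.array(batch_X), np.array(batch_Y)
            index += 1
            batch_X = []
            batch_Y = []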


data = list(range(10))
labels = [2*val for val in range(10)]

model = Sequential([
    Dense(16, activation='relu', input_shape=(1, )),
    Dense(1)
])

model.compile(optimizer='rmsprop', loss='mean_squared_error')
model.fit_generator(get_batch(data, labels, batch_size=2), steps_per_epoch=5, epochs=1, verbose=False)

print(index)  # expected 5, but it prints 15

The program should be straightforward to follow.

By my interpretation it should print 5, but it prints 15. Is my interpretation of steps_per_epoch wrong? If so, please explain what steps_per_epoch actually does.

P.S. I'm new to TensorFlow and Keras. Thanks in advance.

Upvotes: 1

Views: 562

Answers (1)

Gerry P

Reputation: 8092

I did not go through your code in detail, but your original interpretation is correct. In fact, per the documentation, you can omit steps_per_epoch and model.fit will divide the length of your dataset (N) by the batch size to determine the number of steps. I did copy and run your code, and it printed the index as 5. The only thing I can think of that might differ is the imports.
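One possible reason the two runs disagree, offered as an assumption rather than something confirmed in this thread: fit_generator has a max_queue_size argument (default 10), and a training loop that buffers batches ahead of time asks the generator for more batches than it actually trains on. The sketch below is plain Python, not Keras code. The function names and the refill strategy are hypothetical, so the exact count it produces is not meant to match Keras, only to show that a counter inside the generator can exceed steps_per_epoch.

```python
from collections import deque

def batch_generator(counter):
    """Endless stand-in for a batch generator; counts every batch requested."""
    while True:
        counter["requested"] += 1
        yield counter["requested"]

def train(generator, steps, prefetch):
    """Consume `steps` batches while keeping up to `prefetch` batches buffered.

    Assumption: this read-ahead buffer only mimics the idea of a prefetch
    queue; it is not how Keras implements max_queue_size.
    """
    buffer = deque()
    consumed = 0
    while consumed < steps:
        # Top up the buffer before taking the next batch (the read-ahead).
        while len(buffer) < prefetch:
            buffer.append(next(generator))
        buffer.popleft()  # the batch that is actually "trained on"
        consumed += 1
    return consumed

counter = {"requested": 0}
consumed = train(batch_generator(counter), steps=5, prefetch=10)
print(consumed, counter["requested"])
```

The model still trains on exactly 5 batches here; only the generator's own bookkeeping runs ahead. If that is what happens in the question's setup, counting steps in a Keras callback instead of inside the generator would give a more reliable number.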

Upvotes: 2
