Reputation: 69
After reading the Keras documentation on the steps_per_epoch required argument of the model.fit_generator method, my understanding of it is this:
Suppose a dataset contains N samples and the generator function passed to Keras returns B = batch_size samples on every call (here, I think of a call as a single yield from the generator function). Since steps_per_epoch = ceil(N/B), the generator is called steps_per_epoch times per epoch, so the full dataset passes through the model once per epoch. The same process is repeated for every epoch until training is completed.
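To make the arithmetic concrete, here is a minimal sketch of the steps_per_epoch = ceil(N/B) calculation. The helper name compute_steps_per_epoch is hypothetical and is not part of Keras; it only illustrates the formula above.

```python
import math

def compute_steps_per_epoch(num_samples, batch_size):
    """Hypothetical helper (not a Keras function): batches needed to see every sample once."""
    return math.ceil(num_samples / batch_size)

# 10 samples in batches of 2 need 5 steps.
steps_even = compute_steps_per_epoch(10, 2)
# 11 samples in batches of 2 need 6 steps; the last batch holds the 1 leftover sample.
steps_uneven = compute_steps_per_epoch(11, 2)
print(steps_even, steps_uneven)  # prints: 5 6
```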
To test whether my understanding was correct, I implemented the following:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

index = 0

def get_values(inputs, targets):
    i = 0
    while True:
        yield inputs[i], targets[i]
        i += 1
        if i >= len(inputs):
            i = 0

def get_batch(inputs, targets, batch_size=2):
    global index
    batch_X = []
    batch_Y = []
    for inp, targ in get_values(inputs, targets):
        batch_X.append(inp)
        batch_Y.append(targ)
        if len(batch_X) >= batch_size:
            yield np.array(batch_X), np.array(batch_Y)
            index += 1
            batch_X = []
            batch_Y = []

data = list(range(10))
labels = [2*val for val in range(10)]

model = Sequential([
    Dense(16, activation='relu', input_shape=(1, )),
    Dense(1)
])
model.compile(optimizer='rmsprop', loss='mean_squared_error')
model.fit_generator(get_batch(data, labels, batch_size=2), steps_per_epoch=5, epochs=1, verbose=False)
print(index)  # Should print 5 but it prints 15
The program isn't tough to understand...
But as per my interpretation, it should print 5, yet it prints 15. Am I wrong in my interpretation of steps_per_epoch?
If so, please give me the correct interpretation of steps_per_epoch.
PS. I'm new to Tensorflow and Keras, thanks in advance.
Upvotes: 1
Views: 562
Reputation: 8092
I did not go through your code at first, but your original interpretation is correct. In fact, per the documentation located here, you can omit steps_per_epoch, and model.fit will divide the length of your dataset (N) by the batch size to determine the number of steps. I then copied and ran your code, and it printed the index as 5. The only thing I can think of that might differ is the imports.
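One possible reason for the difference between 5 and 15, offered as an assumption rather than something verified against either setup: fit_generator accepts a max_queue_size argument (default 10) and may pull batches from the generator ahead of training. A counter inside the generator then counts batches produced, not batches trained on. The sketch below is a simplified, single-threaded stand-in for such a queue, not Keras's actual implementation. The names counting_batches and prefetch are hypothetical.

```python
from collections import deque
from itertools import islice

def counting_batches(counter):
    """Endless generator that records how many batches it has produced."""
    while True:
        counter["produced"] += 1
        yield counter["produced"]

def prefetch(generator, buffer_size):
    """Hypothetical stand-in for a prefetch queue (not Keras code).

    Keeps buffer_size items buffered ahead of whatever the consumer has taken.
    """
    buffer = deque()
    for item in generator:
        buffer.append(item)
        # Hand an item to the consumer only once the buffer is over capacity.
        if len(buffer) > buffer_size:
            yield buffer.popleft()

counter = {"produced": 0}
batches = prefetch(counting_batches(counter), buffer_size=10)

# The "training loop" consumes exactly 5 batches, like steps_per_epoch=5.
consumed = list(islice(batches, 5))
produced_count = counter["produced"]
print(len(consumed), produced_count)  # prints: 5 15
```

In real Keras the queue is filled by a background worker, so the exact count can vary with timing and version. That could also explain why the same script prints 5 under different imports.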
Upvotes: 2