Reputation: 83
I have a large, non-image CSV file that I want to read in batches and feed to model.fit_generator. I have written a DataGenerator(keras.utils.all_utils.Sequence) class with the following methods:
I instantiate a training and a validation instance of this class and then call model.fit_generator with the correct batch size and other parameters. I noticed that __len__, __getitem__, and __data_generation get called more often than expected and out of sequence, throwing off my model's accuracy. For example, if my CSV has 350 rows in total and I set my batch size to 50, __getitem__ and __data_generation should each be called 7 times per epoch for the training and validation data sets. Instead, I see that the batches are not read in sequential order, and __getitem__ and __data_generation are called more than 7 times. I have set shuffle=False.
I am using the following:
Here's my sample code:
class DataGenerator(keras.utils.all_utils.Sequence):
    def __init__(self, file_name, rows_per_batch=50, shuffle=False):
        self.rows_per_batch = rows_per_batch
        self.shuffle = shuffle
        self.file_name = file_name
        # read the lines before calling on_epoch_end(), which may shuffle them
        reader = csv.reader(open(self.file_name, 'r'))
        self.lines = list(reader)
        self.on_epoch_end()

    def __len__(self):
        return len(self.lines) // self.rows_per_batch

    def __getitem__(self, index):
        skip_rows = index * self.rows_per_batch
        # nrows is the number of rows to read, not the end index
        nrows = self.rows_per_batch
        X, y = self.__data_generation(self.file_name, skip_rows, nrows)
        return X, y

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.lines)

    def __data_generation(self, file_name, skip_rows, nrows):
        df = pd.read_csv(file_name, header=None, skiprows=skip_rows, nrows=nrows)
        < do some data processing here >
        return X, y
# main program
file_name = <some filename>
rows_per_batch = <some number>

a = DataGenerator(file_name, rows_per_batch, shuffle=False)
b = DataGenerator(file_name, rows_per_batch, shuffle=False)

# Train model on dataset
model.fit_generator(generator=a,
                    validation_data=b,
                    use_multiprocessing=False,
                    shuffle=False,
                    epochs=10,
                    workers=6)
Thanks, in advance, for any help or suggestions.
Upvotes: 0
Views: 770
Reputation: 3333
The function __len__ must report the number of batches that the sequence will generate.
The function __getitem__(self, index) returns the index-th batch.
The training/testing framework will call __getitem__ with index in the range [0, len(dg) - 1], where dg is an instance of your DataGenerator.
So __len__ should be implemented as:

    def __len__(self):
        return math.ceil(len(self.lines) / self.rows_per_batch)
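To illustrate the contract, here is a minimal, self-contained sketch (no Keras dependency; a plain list stands in for the CSV rows, and BatchSequence is a hypothetical name). It shows why ceil is needed: with floor division, any leftover rows after the last full batch would never be requested.

```python
import math

class BatchSequence:
    def __init__(self, lines, rows_per_batch):
        self.lines = lines
        self.rows_per_batch = rows_per_batch

    def __len__(self):
        # Number of batches, rounding up so a final partial batch is included.
        return math.ceil(len(self.lines) / self.rows_per_batch)

    def __getitem__(self, index):
        # The index-th batch: rows [index*b, (index+1)*b), clipped at the end.
        start = index * self.rows_per_batch
        return self.lines[start:start + self.rows_per_batch]

seq = BatchSequence(list(range(350)), 50)
print(len(seq))      # 7 batches for 350 rows at 50 rows/batch
print(len(seq[6]))   # the last batch holds 50 rows

seq2 = BatchSequence(list(range(360)), 50)
print(len(seq2))     # ceil(360/50) = 8 batches
print(len(seq2[7]))  # the final batch holds the 10 leftover rows
```

The framework then iterates index over range(len(seq)), so every row is seen exactly once per epoch.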
Upvotes: 1