Reputation: 41
I have original data (X_train, y_train): just images with labels. I am transforming this data into pairs of images for a Siamese network. The number of pairs is very large, around 30 GB in memory, so I can't run the pair-creation function on the whole original dataset. I therefore used Keras' fit_generator, thinking it would load only the current batch.
I ran both model.fit and model.fit_generator on sample pairs, and I observed that both use the same amount of memory. So I think there is some problem with how my code uses fit_generator. The relevant code is below. Can you please help me with this?
Code Below:
import random
import numpy as np

def create_pairs(X_train, y_train):
    """Build positive/negative image pairs for a Siamese network."""
    tr_pairs = []
    tr_y = []
    y_train = np.array(y_train)
    # Indices of the samples belonging to each class.
    digit_indices = [np.where(y_train == c)[0] for c in np.unique(y_train)]
    for i in range(len(digit_indices)):
        n = len(digit_indices[i])
        # Indices of all samples *not* in class i, used for negative pairs.
        negate_indices = list(set(range(len(X_train))) - set(digit_indices[i]))
        for j in range(n):
            anchor_image = X_train[digit_indices[i][j]]
            for k in range(j + 1, n):
                # Positive pair: anchor with another image of the same class.
                support_image = X_train[digit_indices[i][k]]
                tr_pairs += [[anchor_image, support_image]]
                # Negative pair: anchor with a random image of another class.
                negate_image = X_train[random.choice(negate_indices)]
                tr_pairs += [[anchor_image, negate_image]]
                tr_y += [1, 0]
    return np.array(tr_pairs), np.array(tr_y)
def myGenerator():
    tr_pairs, tr_y = create_pairs(X_train, y_train)
    while True:
        for i in range(110):  # 110 batches of 32 pairs each
            if i % 25 == 0:   # occasional progress print
                print("i = " + str(i))
            yield ([tr_pairs[i * 32:(i + 1) * 32, 0], tr_pairs[i * 32:(i + 1) * 32, 1]],
                   tr_y[i * 32:(i + 1) * 32])
model.fit_generator(myGenerator(), steps_per_epoch=110, epochs=2, verbose=1,
                    callbacks=None,
                    validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y),
                    validation_steps=None, class_weight=None,
                    max_queue_size=10, workers=1, use_multiprocessing=False,
                    shuffle=True, initial_epoch=0)
Upvotes: 1
Views: 1091
Reputation: 916
myGenerator does return a generator. However, you should notice that create_pairs loads the full dataset into memory: when you call tr_pairs, tr_y = create_pairs(X_train, y_train), every pair is built up front, so all the memory is consumed at that point. myGenerator then simply traverses a structure that is already in memory.
The solution is to make create_pairs a generator itself, so that each batch of pairs is materialized only when it is requested.
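A minimal sketch of that idea (a random-sampling variant of your exhaustive enumeration, assuming every class has at least two images; the name pair_generator is mine):

import random
import numpy as np

def pair_generator(X_train, y_train, batch_size=32):
    # Only the index bookkeeping is built up front; images are sliced lazily.
    y_train = np.array(y_train)
    classes = list(np.unique(y_train))
    class_indices = {c: np.where(y_train == c)[0] for c in classes}
    all_indices = set(range(len(X_train)))
    while True:
        pairs, labels = [], []
        while len(pairs) < batch_size:
            c = random.choice(classes)
            # Positive pair: two distinct images of the same class.
            a, p = np.random.choice(class_indices[c], 2, replace=False)
            pairs.append([X_train[a], X_train[p]])
            # Negative pair: the anchor with an image of a different class.
            n = random.choice(list(all_indices - set(class_indices[c])))
            pairs.append([X_train[a], X_train[n]])
            labels += [1, 0]
        batch = np.array(pairs[:batch_size])
        yield [batch[:, 0], batch[:, 1]], np.array(labels[:batch_size])

You can pass pair_generator(X_train, y_train) straight to fit_generator in place of myGenerator(); only one batch of images lives in memory at a time.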
If the data is a numpy array, I can suggest using h5 files to read chunks of data from disk:
http://docs.h5py.org/en/latest/high/dataset.html#chunked-storage
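If writing the pairs to disk once is an option, the reading side could look like this minimal sketch, assuming a file pairs.h5 containing datasets tr_pairs and tr_y (file and dataset names are mine). Slicing an h5py Dataset reads only the requested chunk from disk:

import h5py

def h5_pair_generator(path='pairs.h5', batch_size=32):
    f = h5py.File(path, 'r')  # kept open for the generator's lifetime
    pairs, labels = f['tr_pairs'], f['tr_y']  # on-disk datasets, not arrays
    n_batches = len(labels) // batch_size
    while True:
        for i in range(n_batches):
            batch = pairs[i * batch_size:(i + 1) * batch_size]  # one chunk read
            yield ([batch[:, 0], batch[:, 1]],
                   labels[i * batch_size:(i + 1) * batch_size])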
Upvotes: 1