Reputation: 3854
I am looking for the right way to implement next_batch
in TensorFlow. My training data is train_X = 10000x50,
where 10000 is the number of samples and 50 is the length of the feature vector, and train_Y = 10000x1.
I use a batch size of 128. This is the function I use to get a training batch during training:
def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels.
    '''
    idx = np.arange(0, data.shape[0])
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[i, :] for i in idx]
    labels_shuffle = [labels[i] for i in idx]
    return np.asarray(data_shuffle), np.asarray(labels_shuffle)
n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next_batch(batch_size, train_X, train_Y)
With the above function, the shuffle is performed on every batch, which is not the desired behavior: we should visit every shuffled element of the training data once before reshuffling for the next epoch. Am I right? How can I fix this in TensorFlow? Thanks.
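To illustrate the problem, here is a quick NumPy check (with made-up small arrays, not my real data): because the indices are reshuffled on every call, one nominal "epoch" of n_batches calls can draw some samples twice and miss others entirely.

```python
import numpy as np

np.random.seed(0)  # fixed seed so the illustration is repeatable

def next_batch(num, data, labels):
    '''Return `num` random samples and labels (reshuffles on every call).'''
    idx = np.arange(0, data.shape[0])
    np.random.shuffle(idx)
    idx = idx[:num]
    return data[idx], labels[idx]

# Toy data: 1000 samples, 5 features (hypothetical sizes for the demo).
data = np.arange(1000 * 5).reshape(1000, 5)
labels = np.arange(1000)

seen = []
for _ in range(1000 // 128):  # one nominal "epoch" of 7 batches
    _, y = next_batch(128, data, labels)
    seen.extend(y.tolist())

print(len(seen))       # 896 labels drawn in total (7 * 128)
print(len(set(seen)))  # fewer unique labels: some samples repeat, others are missed
```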
Upvotes: 0
Views: 724
Reputation: 15119
A solution would be to use a generator to yield your batches, in order to keep track of the sampling state (the list of shuffled indices and your current position in this list).
Below is a solution you could build on.
def next_batch(num, data, labels):
    '''
    Return a total of at most `num` random samples and labels.
    NOTE: if len(data) is not a multiple of `num`, the last batch
    of each epoch will be of size len(data) % num.
    '''
    num_el = data.shape[0]
    while True:  # or whatever condition you may have
        idx = np.arange(0, num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx + num]
            current_idx += num
            data_shuffle = [data[i, :] for i in batch_idx]
            labels_shuffle = [labels[i] for i in batch_idx]
            yield np.asarray(data_shuffle), np.asarray(labels_shuffle)
n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    next_batch_gen = next_batch(batch_size, train_X, train_Y)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next(next_batch_gen)
            print(Y_batch)
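As a sanity check (pure NumPy, no TensorFlow session needed, with hypothetical toy arrays), you can confirm that one epoch of batches from this generator visits every sample exactly once before reshuffling:

```python
import numpy as np

def next_batch(num, data, labels):
    '''Yield batches of at most `num` samples, reshuffling once per epoch.'''
    num_el = data.shape[0]
    while True:
        idx = np.arange(0, num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx + num]
            current_idx += num
            yield data[batch_idx], labels[batch_idx]

# Toy data: 10 samples, 3 features (made-up sizes for the check).
data = np.arange(30).reshape(10, 3)
labels = np.arange(10)

gen = next_batch(4, data, labels)
epoch_labels = []
for _ in range(3):  # ceil(10 / 4) = 3 batches per epoch (sizes 4, 4, 2)
    _, y = next(gen)
    epoch_labels.extend(y.tolist())

print(sorted(epoch_labels))  # [0, 1, ..., 9]: every sample seen exactly once
```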
Upvotes: 2