Jame

Reputation: 3854

How to implement correctly next_batch in custom dataset in tensorflow?

I am looking for the right way to implement `next_batch` in TensorFlow. My training data is `train_X` of shape 10000x50, where 10000 is the number of samples and 50 is the size of the feature vector, and `train_Y` of shape 10000x1. I use a batch size of 128. This is my function to get a training batch during training:

import numpy as np

def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels.
    '''
    idx = np.arange(0, data.shape[0])
    np.random.shuffle(idx)
    idx = idx[:num]
    return data[idx], labels[idx]

n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next_batch(batch_size, train_X, train_Y)

With the above function, the shuffle is performed for every batch, which is not the desired behavior: we should iterate through all elements of one shuffled pass over the training data before shuffling again for the next epoch. Am I right? How can I fix this in TensorFlow? Thanks.
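A quick NumPy-only simulation (with made-up toy sizes, not my real data) illustrates the coverage problem: because each batch draws from a fresh shuffle, some samples are drawn repeatedly within an "epoch" while others are never seen.

```python
import numpy as np

np.random.seed(0)  # fixed seed so the illustration is deterministic

n_samples, batch = 100, 10
data = np.arange(n_samples)

def next_batch(num, data):
    # same logic as my function above: a fresh shuffle for every batch
    idx = np.arange(data.shape[0])
    np.random.shuffle(idx)
    return data[idx[:num]]

seen = set()
for _ in range(n_samples // batch):  # one "epoch" worth of batches
    seen.update(next_batch(batch, data).tolist())

# 10 batches x 10 samples = 100 draws, but many are repeats,
# so fewer than 100 distinct samples are visited in the "epoch"
print(len(seen))
```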

Upvotes: 0

Views: 724

Answers (1)

benjaminplanche

Reputation: 15119

A solution would be to use a generator to yield your batches, so that the sampling state (the list of shuffled indices and your current position in it) persists across calls.

Find below a solution you could build on.

def next_batch(num, data, labels):
    '''
    Generator yielding at most `num` random samples and labels per batch.
    NOTE: when len(data) is not divisible by `num`, the last batch of
    each epoch has size len(data) % num.
    '''
    num_el = data.shape[0]
    while True:  # or whatever condition you may have
        idx = np.arange(0, num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx + num]
            current_idx += num
            yield data[batch_idx], labels[batch_idx]

n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    next_batch_gen = next_batch(batch_size, train_X, train_Y)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next(next_batch_gen)
            print(Y_batch)
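To check, independently of TensorFlow, that the generator visits every sample exactly once per epoch, here is a small self-contained sketch; the toy arrays `toy_X`/`toy_Y` are just placeholders for your data:

```python
import numpy as np

def next_batch(num, data, labels):
    # same generator as above
    num_el = data.shape[0]
    while True:
        idx = np.arange(num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx + num]
            current_idx += num
            yield data[batch_idx], labels[batch_idx]

toy_X = np.arange(10).reshape(10, 1)  # 10 samples, 1 feature
toy_Y = np.arange(10)
gen = next_batch(4, toy_X, toy_Y)

seen = []
for _ in range(3):       # ceil(10 / 4) = 3 batches cover one epoch
    _, y = next(gen)
    seen.extend(y.tolist())

print(sorted(seen))      # prints [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

One caveat about the training loop above: `int(n_samples / batch_size)` floors the batch count, so when `n_samples` is not divisible by `batch_size` the outer loop skips the short final batch and the Python epochs drift out of step with the generator's internal epochs.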

Upvotes: 2
