VICTOR


TensorFlow - How to batch the dataset

I am building a convolutional neural network for digit recognition. I want to train it on an image dataset, but I don't know how to "batch" the training data.

I have two arrays, train_image and train_label:

print(train_image.shape)
# (73257, 1024)
# i.e. 73257 images, each 32x32 = 1024 pixels, flattened

print(train_label.shape)
# (73257, 10)
# Digit '1' has label 1, '9' has label 9 and '0' has label 10

Now I want to batch the training data with a batch size of 50:

    sess.run(tf.initialize_all_variables())
    train_image_batch, train_label_batch = tf.train.shuffle_batch([train_image,
       train_label], batch_size=50, capacity=50000, min_after_dequeue=10000)

When I print train_image_batch:

print(train_image_batch)
# Tensor("shuffle_batch:0", shape=(50, 73257, 1024), dtype=uint8)

I expected the shape to be (50, 1024).

Am I doing something wrong here?


Answers (1)

jackberry


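By default, shuffle_batch treats each tensor in the input list as a single example. Your whole (73257, 1024) array is therefore one example, and a batch stacks 50 copies of it, giving (50, 73257, 1024). To make it treat the first dimension as the example index, so that each row becomes its own example, pass enqueue_many=True. See the tf.train.shuffle_batch documentation.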
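The shape rule behind this can be sketched in plain Python. batched_shape is a hypothetical helper, not a TensorFlow function; it only mirrors the output-shape behaviour that the tf.train.shuffle_batch documentation describes for the two enqueue_many settings.

```python
def batched_shape(input_shape, batch_size, enqueue_many=False):
    """Hypothetical helper mirroring how shuffle_batch's output shape
    relates to an input tensor's shape (shape bookkeeping only)."""
    if enqueue_many:
        # First dimension indexes examples; each example has shape input_shape[1:]
        return (batch_size,) + tuple(input_shape[1:])
    # The whole tensor is one example; a batch stacks batch_size of them
    return (batch_size,) + tuple(input_shape)


print(batched_shape((73257, 1024), 50))                     # (50, 73257, 1024)
print(batched_shape((73257, 1024), 50, enqueue_many=True))  # (50, 1024)
print(batched_shape((73257, 10), 50, enqueue_many=True))    # (50, 10)
```

The same rule applies to train_label, which is why its batch comes out as (50, 10) once enqueue_many=True is set.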

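import tensorflow as tf  # TensorFlow 1.x queue-based input API

# enqueue_many=True: the first dimension of each tensor indexes separate examples
train_image_batch, train_label_batch = tf.train.shuffle_batch(
    [train_image, train_label], batch_size=50, enqueue_many=True,
    capacity=50000, min_after_dequeue=10000)

print(train_image_batch.shape)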

Output:
(50, 1024)
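Conceptually, with enqueue_many=True each row of train_image is paired with the same row of train_label, the pairs are shuffled, and they are grouped into batches of 50. The sketch below is a simplified stand-in in plain Python: the name shuffled_batches is hypothetical, and it shuffles one epoch at a time rather than using the RandomShuffleQueue (with its capacity and min_after_dequeue buffer) that tf.train.shuffle_batch actually uses.

```python
import random


def shuffled_batches(images, labels, batch_size, seed=None):
    """Hypothetical, simplified epoch-style shuffle-and-batch.
    Not the RandomShuffleQueue used by tf.train.shuffle_batch."""
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same first dimension")
    order = list(range(len(images)))
    # One shared permutation keeps each image aligned with its label
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [images[i] for i in idx], [labels[i] for i in idx]


# Tiny stand-in data: 7 "images" of 4 pixels, each labelled with its row index
images = [[i] * 4 for i in range(7)]
labels = list(range(7))
for img_batch, lbl_batch in shuffled_batches(images, labels, batch_size=3, seed=0):
    # Batch size, and whether every image still matches its label
    print(len(img_batch), [row[0] for row in img_batch] == lbl_batch)
```

With 73257 examples and a batch size of 50, one pass of this sketch yields 1465 full batches plus a final batch of 7. The TensorFlow queue instead keeps refilling from the input, so it keeps producing full batches of 50. In TensorFlow 1.x, evaluating train_image_batch with sess.run also requires the queue runners to be started (tf.train.start_queue_runners); reading .shape does not.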

