kzs

Reputation: 1111

Batch structure in TensorFlow

I was following a neural network tutorial that uses TensorFlow and the MNIST dataset, and I came across the following piece of code:

for _ in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y: batch[1]})

I have trouble visualizing the structure of the batch, in particular how it is indexed. Does batch[0] mean all 50 images in the batch, and batch[1] all 50 labels for those images? It would be great if someone could show the structure of the batch visually. I searched but could not find a good tutorial on this.

Upvotes: 0

Views: 75

Answers (1)

Welcome_back

Reputation: 1445

Here is my basic code for displaying images batch-wise:

import math

import matplotlib.pyplot as plt

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    # Draw one image in the current grid cell; the title is red when flagged.
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2),
                  color='red' if red else 'black', fontdict={'verticalalignment':'center'},
                  pad=int(titlesize/1.5))
    # Return the subplot tuple advanced to the next grid cell.
    return (subplot[0], subplot[1], subplot[2]+1)

def display_batch_of_images(images, labels, predictions=None):
    """This will work with:
    display_batch_of_images(images, labels)
    display_batch_of_images(images, labels, predictions)
    """

    # auto-squaring: this will drop data that does not fit into a square
    # or near-square rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # display
    tempo = ""  # title used when no predictions are given
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else tempo
        correct = True
        if predictions is not None:
            # title_from_label_and_target is not defined in this snippet; it must be supplied separately
            title, correct = title_from_label_and_target(predictions[i], label)
        # magic formula tested to work from 1x1 to 10x10 images
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
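
The display loop calls title_from_label_and_target, which is not defined anywhere in this answer. A minimal stand-in might look like the sketch below. The CLASSES list, the title wording, and the assumption that predictions and labels are integer class indices are all illustrative assumptions, not part of the original code.

```python
# Hypothetical helper: not part of the original answer.
# Assumes predictions and labels are integer class indices (for MNIST, digits 0-9).
CLASSES = [str(d) for d in range(10)]  # assumed class names

def title_from_label_and_target(prediction, label):
    """Return (title, correct) for one predicted/true label pair."""
    prediction, label = int(prediction), int(label)
    correct = (prediction == label)
    if correct:
        title = "{} [OK]".format(CLASSES[prediction])
    else:
        title = "{} [NO, should be {}]".format(CLASSES[prediction], CLASSES[label])
    return title, correct
```

With a helper of this shape, wrong predictions come back with correct set to False, which is what makes display_one_flower draw their titles in red.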

Now we can display the images batch-wise.

Here images contains all 50 images in the batch and labels contains all 50 labels (in the question's code these are batch[0] and batch[1]):

  display_batch_of_images(images,labels)
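
To tie this back to the question: in the TensorFlow 1.x MNIST tutorial, next_batch(50) returns a pair of arrays, so batch[0] is the images and batch[1] is the labels. The sketch below uses random NumPy arrays as a stand-in for the real call (fake_next_batch is a hypothetical name), assuming the tutorial defaults of flattened 784-pixel images and one-hot labels. Since plt.imshow expects a 2-D array, the flattened images are reshaped to 28x28 before plotting.

```python
import numpy as np

def fake_next_batch(batch_size):
    """Hypothetical stand-in for mnist.train.next_batch: returns (images, labels)."""
    rng = np.random.default_rng(0)
    images = rng.random((batch_size, 784), dtype=np.float32)  # one flattened 28x28 image per row
    digits = rng.integers(0, 10, size=batch_size)             # random digit classes
    labels = np.eye(10, dtype=np.float32)[digits]             # one-hot labels, one row per image
    return images, labels

batch = fake_next_batch(50)
images, labels = batch      # same as images = batch[0]; labels = batch[1]

print(images.shape)         # (50, 784)
print(labels.shape)         # (50, 10)

# Reshape for plotting and convert one-hot rows to integer class labels.
images_2d = images.reshape(-1, 28, 28)
digit_labels = labels.argmax(axis=1)
```

With the real data, display_batch_of_images(images_2d, digit_labels) would then draw a 7x7 grid, showing 49 of the 50 images, because the auto-squaring step drops whatever does not fit.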

Upvotes: 1
