Athirooban

Reputation: 11

Batch size in input shape of a Chainer CNN

I have a training set of 9957 images. As shown below, X_train has shape (9957, 60, 80, 3), and each example is transposed to channel-first (3, 60, 80) in get_example. Is the batch size required when passing the training set to the model? If it is, can the original shape be considered correct for feeding a Convolution2D layer, or do I need to add the batch size to the input shape?

X_train.shape
# (9957, 60, 80, 3)

from chainer.datasets import split_dataset_random
from chainer.dataset import DatasetMixin
from chainer import iterators  # needed for SerialIterator below
import numpy as np


class MyDataset(DatasetMixin):
    def __init__(self, X, labels):
        super(MyDataset, self).__init__()
        self.X_ = X
        self.labels_ = labels
        self.size_ = X.shape[0]

    def __len__(self):
        return self.size_

    def get_example(self, i):
        # transpose each example from HWC (60, 80, 3) to CHW (3, 60, 80)
        return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i]


batch_size = 3

label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False, shuffle=False)

Upvotes: 1

Views: 144

Answers (1)

Yuki Hashimoto

Reputation: 1073

The code below shows that you do not have to handle the batch size yourself. Just use DatasetMixin and SerialIterator as instructed in the Chainer tutorial; the iterator groups individual examples into batches for you.

from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np

NUM_IMAGES = 9957
NUM_CHANNELS = 3  # RGB
IMAGE_HEIGHT = 60  # Chainer uses NCHW, so height comes before width
IMAGE_WIDTH = 80

NUM_CLASSES = 10

BATCH_SIZE = 32

TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))

images = np.random.rand(
    NUM_IMAGES, NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH
).astype(np.float32)  # Chainer expects float32 image arrays
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,)).astype(np.int32)  # and int32 labels


class MyDataset(DatasetMixin):
    def __init__(self, images_, labels_):
        # note: the trailing underscore on the arguments is just to avoid shadowing
        super(MyDataset, self).__init__()
        self.images_ = images_
        self.labels_ = labels_
        self.size_ = len(labels_)

    def __len__(self):
        return self.size_

    def get_example(self, i):
        return self.images_[i, ...], self.labels_[i]


dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)

###############################################################################
"""This block is just for the confirmation.

.. note: NOT recommended to call :func:`concat_examples` in your code.
    Use :class:`chainer.updaters.StandardUpdater` instead. 
"""
from chainer.dataset import concat_examples

batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))

Output

batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)
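
To connect this back to the question: a Chainer convolution link is parameterized by channel counts only; the batch size never appears in the model definition. Below is a minimal sketch of such a chain (the class name MyCNN and the layer sizes are illustrative assumptions, not part of the original answer) that would accept the (32, 3, 60, 80) batches above:

import chainer
import chainer.functions as F
import chainer.links as L


class MyCNN(chainer.Chain):
    def __init__(self, n_classes=10):
        super(MyCNN, self).__init__()
        with self.init_scope():
            # in_channels=3, out_channels=16: no batch size anywhere
            self.conv1 = L.Convolution2D(3, 16, ksize=3, pad=1)
            self.fc = L.Linear(None, n_classes)  # None infers the input size

    def __call__(self, x):
        # x arrives as (batch_size, 3, 60, 80); any batch size works
        h = F.relu(self.conv1(x))
        h = F.max_pooling_2d(h, 2)
        return self.fc(h)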

It should be noted that chainer.dataset.concat_examples is the slightly tricky part. Usually users do not need to pay attention to this function: if you use StandardUpdater, it conceals the call to chainer.dataset.concat_examples for you.
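
For intuition, this is roughly what concat_examples does for the simple tuple examples above (an illustrative sketch only; the real function also handles device transfer and padding):

def naive_concat(batch):
    # batch is a list of (image, label) tuples taken from the iterator
    images = np.stack([image for image, _ in batch])    # (B, 3, 60, 80)
    labels = np.asarray([label for _, label in batch])  # (B,)
    return images, labels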

Chainer is designed around the scheme of Trainer, (Standard)Updater, an Optimizer, (Serial)Iterator, and Dataset(Mixin); if you do not follow this scheme, you will have to dive into the sea of the Chainer source code.
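
For completeness, here is a minimal sketch of that scheme, wiring the iterators above into a Trainer (MyCNN is the hypothetical chain sketched earlier; the Adam optimizer, epoch count, and output directory are assumptions):

from chainer import optimizers, training
from chainer.training import extensions

model = L.Classifier(MyCNN())  # Classifier wraps the chain with a softmax cross-entropy loss
optimizer = optimizers.Adam()
optimizer.setup(model)

# StandardUpdater draws batches from the iterator and calls concat_examples internally.
updater = training.StandardUpdater(train_iter, optimizer, device=-1)  # device=-1: CPU
trainer = training.Trainer(updater, (10, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(valid_iter, model, device=-1))
trainer.run()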

Upvotes: 1
