user2709619
user2709619

Reputation: 61

Tensorflow GAN only works when batch size equals one

I am training a CGAN to reconstruct images from corrupted images. All of my code is written for variable batch sizes, so I can train with any batch size without getting an error. When I use batch size one, after 2 minutes the reconstructed images no longer have any strange artifacts. However, here is my problem: for any other batch size I get very strange checkerboard artifacts, even when I try different learning rates or when I train for multiple hours.

This is a reconstructed image at batch size 2 after training for a while. (These weird artifacts were not in the corrupted data.)

This is the adversarial component to the generator loss at batch size 2.

This is the generator loss at batch size 2.

This is the discriminator loss at batch size 2.

For comparison, at batch size 1:

advloss gloss dloss

Orange is train, Blue is validation

It seems like my code does something completely different as soon as the batch size is larger than one. I'm certain that the batches are being loaded correctly. Am I going crazy?

My model:

    # Placeholders for a batch of ground-truth images and their corrupted
    # counterparts; the leading None leaves the batch size unspecified.
    image_shape = (None, conf.fig_size, conf.fig_size, conf.fig_channel)
    self.original = tf.placeholder(tf.float32, shape=image_shape, name="original")
    self.corrupted = tf.placeholder(tf.float32, shape=image_shape, name="corrupted")

    self.reconstructed = self.generator(self.corrupted)

    # The discriminator scores (image, corrupted) pairs. The positive pair uses
    # the ground truth; the negative pair must use the generator output.
    # Feeding the ground truth to both passes would make the adversarial term
    # independent of the generator and would ask the discriminator to label
    # the very same input as both real and fake.
    pos = self.discriminator(self.original, self.corrupted, False)
    neg = self.discriminator(self.reconstructed, self.corrupted, True)

    # Discriminator objective: real pairs -> 1, generated pairs -> 0.
    pos_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pos, labels=tf.ones_like(pos)))
    neg_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.zeros_like(neg)))

    self.d_loss = pos_loss + neg_loss

    # Generator objective: make generated pairs score as real, plus an L1
    # reconstruction term weighted by conf.l1_lambda.
    adv_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.ones_like(neg)))
    self.g_loss =  adv_loss + conf.l1_lambda * tf.reduce_mean(tf.abs(self.original - self.reconstructed))

    # Split the trainable variables by the scope names used in
    # discriminator() ("disc") and generator() ("gen").
    t_vars = tf.trainable_variables()
    self.d_vars = [var for var in t_vars if 'disc' in var.name]
    self.g_vars = [var for var in t_vars if 'gen' in var.name]

    self.merged = tf.summary.merge_all()

def generator(self, corrupted):
    """Build the generator: an 8-level encoder-decoder with skip connections.

    The encoder (e1..e8) applies eight stride-2 convolutions. The decoder
    (d1..d8) mirrors it with transposed convolutions and concatenates each
    decoder level with the encoder output of the matching spatial size.
    All variables are created under the "gen" variable scope.

    Args:
        corrupted: Batch of corrupted images to reconstruct.

    Returns:
        The tanh of the last transposed convolution: a tensor with
        ``conf.fig_channel`` channels and values in (-1, 1).
    """
    conf = self.config
    with tf.variable_scope("gen"):
        feature = conf.conv_channel_base
        decay = conf.batch_norm_decay

        # Encoder: every conv2d call halves the spatial size (stride 2, SAME).
        e1 = conv2d(corrupted, feature, name="e1")
        e2 = batch_norm(conv2d(lrelu(e1), feature * 2, name="e2"), "e2", decay)
        e3 = batch_norm(conv2d(lrelu(e2), feature * 4, name="e3"), "e3", decay)
        e4 = batch_norm(conv2d(lrelu(e3), feature * 8, name="e4"), "e4", decay)
        e5 = batch_norm(conv2d(lrelu(e4), feature * 8, name="e5"), "e5", decay)
        e6 = batch_norm(conv2d(lrelu(e5), feature * 8, name="e6"), "e6", decay)
        e7 = batch_norm(conv2d(lrelu(e6), feature * 8, name="e7"), "e7", decay)
        e8 = batch_norm(conv2d(lrelu(e7), feature * 8, name="e8"), "e8", decay)

        # num[k] is the spatial size decoder level k must produce: num[8] is
        # the full image size and each lower level is the rounded-up half of
        # the one above. Floor division keeps the sizes integral on both
        # Python 2 and Python 3 ("/" would produce floats on Python 3).
        num = [0] * 9
        size = conf.fig_size
        for level in range(8, 0, -1):
            num[level] = size
            size = (size + 1) // 2

        # Decoder: the three deepest levels apply dropout (0.5) after batch
        # norm; every level except the last is concatenated with its encoder
        # counterpart along the channel axis (3).
        d1 = deconv2d(tf.nn.relu(e8), [num[1], num[1], feature * 8], name="d1")
        d1 = tf.concat([tf.nn.dropout(batch_norm(d1, "d1", decay), 0.5), e7], 3)
        d2 = deconv2d(tf.nn.relu(d1), [num[2], num[2], feature * 8], name="d2")
        d2 = tf.concat([tf.nn.dropout(batch_norm(d2, "d2", decay), 0.5), e6], 3)
        d3 = deconv2d(tf.nn.relu(d2), [num[3], num[3], feature * 8], name="d3")
        d3 = tf.concat([tf.nn.dropout(batch_norm(d3, "d3", decay), 0.5), e5], 3)
        d4 = deconv2d(tf.nn.relu(d3), [num[4], num[4], feature * 8], name="d4")
        d4 = tf.concat([batch_norm(d4, "d4", decay), e4], 3)
        d5 = deconv2d(tf.nn.relu(d4), [num[5], num[5], feature * 4], name="d5")
        d5 = tf.concat([batch_norm(d5, "d5", decay), e3], 3)
        d6 = deconv2d(tf.nn.relu(d5), [num[6], num[6], feature * 2], name="d6")
        d6 = tf.concat([batch_norm(d6, "d6", decay), e2], 3)
        d7 = deconv2d(tf.nn.relu(d6), [num[7], num[7], feature], name="d7")
        d7 = tf.concat([batch_norm(d7, "d7", decay), e1], 3)

        # Final level maps back to the image channel count; no batch norm.
        d8 = deconv2d(tf.nn.relu(d7), [num[8], num[8], conf.fig_channel], name="d8")

        return tf.nn.tanh(d8)

def discriminator(self, original, corrupted, reuse):
    """Score an (image, corrupted image) pair with the discriminator.

    The two inputs are concatenated along their last axis, passed through
    four stride-2 convolutions (batch norm on all but the first, leaky ReLU
    after each) and a final linear layer with a single output.

    Args:
        original: Image tensor to judge.
        corrupted: Corrupted image the judgement is conditioned on.
        reuse: Passed to the "disc" variable scope; True shares the
            variables created by an earlier call.

    Returns:
        The output of the final linear layer (one unnormalised logit per
        example).
    """
    conf = self.config
    last_axis = len(original.get_shape()) - 1
    base = conf.conv_channel_base
    decay = conf.batch_norm_decay
    with tf.variable_scope("disc", reuse=reuse):
        pair = tf.concat([original, corrupted], last_axis)
        h0 = lrelu(conv2d(pair, base, name="h0"))
        h1 = lrelu(batch_norm(conv2d(h0, base * 2, name="h1"), "h1", decay))
        h2 = lrelu(batch_norm(conv2d(h1, base * 4, name="h2"), "h2", decay))
        h3 = lrelu(batch_norm(conv2d(h2, base * 8, name="h3"), "h3", decay))
        # Flatten everything but the batch axis before the linear layer.
        flat_size = h3.shape[1] * h3.shape[2] * h3.shape[3]
        flattened = tf.reshape(h3, [-1, flat_size])
        logits = linear(flattened, 1, "linear")
    return logits

def batch_norm(x, scope, decay, is_training=True):
    """Apply `tf.contrib.layers.batch_norm` with this model's fixed settings.

    Args:
        x: Input tensor.
        scope: Variable scope name for the layer's variables.
        decay: Decay of the moving averages.
        is_training: Forwarded to the underlying layer. True (the default,
            which matches the previous behaviour of this helper) normalises
            with the statistics of the current batch; False uses the
            accumulated moving averages instead. Exposing it lets callers
            build an inference graph that does not depend on batch size.

    Returns:
        The normalised tensor.
    """
    # updates_collections=None makes the layer update its moving averages in
    # place rather than deferring them to a separate update-op collection.
    return tf.contrib.layers.batch_norm(
        x,
        decay=decay,
        updates_collections=None,
        epsilon=1e-5,
        scale=True,
        is_training=is_training,
        scope=scope)

def conv2d(input, output_dim, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    """Strided 2-D convolution with bias, using SAME padding.

    Args:
        input: Input tensor; its last dimension is the input channel count.
        output_dim: Number of output channels.
        k_h, k_w: Kernel height and width.
        d_h, d_w: Strides along height and width.
        stddev: Standard deviation of the truncated-normal weight initialiser.
        name: Variable scope holding the 'weight' and 'bias' variables.

    Returns:
        The convolution output with the bias added.
    """
    with tf.variable_scope(name):
        in_channels = input.get_shape()[-1]
        kernel = tf.get_variable(
            'weight',
            [k_h, k_w, in_channels, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        offset = tf.get_variable(
            'bias', [output_dim], initializer=tf.constant_initializer(0.0))
        convolved = tf.nn.conv2d(
            input, kernel, strides=[1, d_h, d_w, 1], padding='SAME')
        return tf.nn.bias_add(convolved, offset)

def deconv2d(input, output_shape, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="deconv2d"):
    """Strided 2-D transposed convolution with bias.

    Args:
        input: Input tensor; its last dimension is the input channel count.
        output_shape: [height, width, channels] of the result. The batch
            dimension is taken from `input` at run time.
        k_h, k_w: Kernel height and width.
        d_h, d_w: Strides along height and width.
        stddev: Standard deviation of the normal weight initialiser.
        name: Variable scope holding the 'weight' and 'bias' variables.

    Returns:
        The transposed-convolution output with the bias added.
    """
    with tf.variable_scope(name):
        # Read the batch size dynamically so the layer works for any batch.
        batch = tf.shape(input)[0]
        out_channels = output_shape[-1]
        in_channels = input.get_shape()[-1]
        kernel = tf.get_variable(
            'weight',
            [k_h, k_w, out_channels, in_channels],
            initializer=tf.random_normal_initializer(stddev=stddev))
        offset = tf.get_variable(
            'bias', [out_channels], initializer=tf.constant_initializer(0.0))
        full_shape = tf.stack(
            [batch, output_shape[0], output_shape[1], output_shape[2]])
        transposed = tf.nn.conv2d_transpose(
            input, kernel, output_shape=full_shape, strides=[1, d_h, d_w, 1])
        return tf.nn.bias_add(transposed, offset)

def lrelu(x, leak=0.2):
    """Leaky ReLU: the element-wise maximum of `x` and `leak * x`."""
    scaled = leak * x
    return tf.maximum(x, scaled)

def linear(input, output_size, name="Linear", stddev=0.02, bias_start=0.0):
    """Fully connected layer computing `input @ weight + bias`.

    Args:
        input: 2-D tensor; its second dimension must be statically known.
        output_size: Number of output units.
        name: Variable scope holding the "weight" and "bias" variables.
        stddev: Standard deviation of the normal weight initialiser.
        bias_start: Constant initial value of the bias.

    Returns:
        The result of the matrix product with the bias added.
    """
    in_features = input.get_shape().as_list()[1]
    with tf.variable_scope(name):
        weight = tf.get_variable(
            "weight",
            [in_features, output_size],
            tf.float32,
            tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable(
            "bias",
            [output_size],
            initializer=tf.constant_initializer(bias_start))
        product = tf.matmul(input, weight)
        return product + bias

My training:

# One Adam optimizer per network, each restricted to that network's variables.
d_opt = tf.train.AdamOptimizer(
    learning_rate=conf.learning_rate).minimize(model.d_loss, var_list=model.d_vars)
g_opt = tf.train.AdamOptimizer(
    learning_rate=conf.learning_rate).minimize(model.g_loss, var_list=model.g_vars)

# NOTE(review): variable initialisation is not part of this excerpt.
with tf.Session(config=configProto) as sess:
    for epoch in xrange(conf.max_epoch):
        batch_index = 0
        for original, corrupted in data.iterate_batches_train():
            feed_dict = {
                model.original: preprocess(original),
                model.corrupted: preprocess(corrupted),
            }
            # Two discriminator updates, then one generator update, all on
            # the same batch.
            for _ in xrange(2):
                sess.run([d_opt], feed_dict=feed_dict)
            sess.run([g_opt], feed_dict=feed_dict)

The default configuration for batch size one:

    # Height and width of the (square) input images, and their channel count;
    # both define the placeholder shapes of the model.
    self.fig_size = 424
    self.fig_channel = 1

    # Channel count of the first convolution in generator and discriminator;
    # deeper layers use multiples of it.
    self.conv_channel_base = 64
    # Weight of the L1 reconstruction term in the generator loss.
    self.l1_lambda = 100
    # Decay passed to every batch-norm layer.
    self.batch_norm_decay = 0.9

    # NOTE(review): batch_size is not read by any code shown here; presumably
    # it is consumed by the data loader - confirm.
    self.batch_size = 1
    # Number of epochs run by the training loop.
    self.max_epoch = 20
    # Learning rate of both Adam optimizers.
    self.learning_rate = 0.0002

I appreciate any insight you might have...

Upvotes: 3

Views: 1348

Answers (1)

Karl
Karl

Reputation: 121

I think this is because of your use of batch normalization.

With batch size = 1, BN isn't really a meaningful operation.

With a small batch size > 1 you're working with statistics that aren't really reflective of your population and thus things can get wonky.

Can you try training with batch size = 2 and no BN?

Upvotes: 1

Related Questions