Reputation: 61
I am training a conditional GAN (CGAN) to reconstruct images from corrupted images. All of my code supports variable batch sizes, so I can train with any batch size (no errors are raised). With a batch size of one, the reconstructed images are free of strange artifacts after about two minutes of training. However, here is my problem: with any larger batch size I get very strange checkerboard artifacts, even when I try different learning rates or train for several hours.
This is a reconstructed image at batch size 2 after training for a while. (These weird artifacts were not in the corrupted data.)
This is the adversarial component to the generator loss at batch size 2.
This is the generator loss at batch size 2.
This is the discriminator loss at batch size 2.
For comparison, at batch size 1:
Orange is training, blue is validation.
It seems like my code does something completely different as soon as the batch size is larger than one. I'm certain that the batches are being loaded correctly. Am I going crazy?
My model:
# Graph construction for the conditional GAN.
# NOTE(review): `conf` is not bound in this excerpt; presumably `conf = self.config`
# earlier in the enclosing method -- confirm.

# Inputs: the clean target image and the corrupted image it must be reconstructed from.
# The batch dimension is left as None so any batch size can be fed.
self.original = tf.placeholder(
    tf.float32,
    shape=(None, conf.fig_size, conf.fig_size, conf.fig_channel),
    name="original")
self.corrupted = tf.placeholder(
    tf.float32,
    shape=(None, conf.fig_size, conf.fig_size, conf.fig_channel),
    name="corrupted")
self.reconstructed = self.generator(self.corrupted)

# The discriminator scores (candidate image, corrupted image) pairs.
# Real pair: the ground-truth original. The first call creates the "disc" variables.
pos = self.discriminator(self.original, self.corrupted, False)
# Fake pair: the generator's reconstruction. This call reuses the same weights.
# BUG FIX: this previously passed self.original, so D received the identical
# input for both labels and adv_loss had no dependence on the generator.
neg = self.discriminator(self.reconstructed, self.corrupted, True)

# D is trained to output 1 on real pairs and 0 on fake pairs.
pos_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=pos, labels=tf.ones_like(pos)))
neg_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.zeros_like(neg)))
self.d_loss = pos_loss + neg_loss

# G is trained to make D output 1 on fake pairs, plus a weighted L1 reconstruction term.
adv_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.ones_like(neg)))
self.g_loss = adv_loss + conf.l1_lambda * tf.reduce_mean(
    tf.abs(self.original - self.reconstructed))

# Split trainable variables by the variable-scope names used in
# discriminator() ("disc") and generator() ("gen").
t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'disc' in var.name]
self.g_vars = [var for var in t_vars if 'gen' in var.name]
self.merged = tf.summary.merge_all()
def generator(self, corrupted):
    """Build the U-Net generator that maps ``corrupted`` to a reconstruction.

    Eight stride-2 convolutions (scopes e1..e8) downsample the input. Eight
    stride-2 transposed convolutions (scopes d1..d8) upsample it again. The
    outputs of d1..d7 are concatenated on the channel axis with the encoder
    output of matching spatial size (skip connections). Variable scope names
    ("gen/e1".."gen/e8", "gen/d1".."gen/d8") are the same as before, so
    existing checkpoints still load.

    NOTE(review): every batch_norm call here uses tf.contrib's default
    is_training=True, so normalization always uses the statistics of the
    current batch. With batch size 1 that amounts to per-image normalization;
    with larger batches outputs depend on the other images in the batch.

    Args:
        corrupted: image batch; assumed NHWC with spatial size
            ``conf.fig_size`` (as the model placeholders are declared).

    Returns:
        Tensor with ``conf.fig_channel`` channels, squashed by tanh into [-1, 1].
    """
    conf = self.config
    decay = conf.batch_norm_decay
    with tf.variable_scope("gen"):
        feature = conf.conv_channel_base

        # Encoder. e1 has no batch norm. e2..e8 are lrelu -> conv -> batch norm.
        encoded = [conv2d(corrupted, feature, name="e1")]
        for idx, mult in enumerate((2, 4, 8, 8, 8, 8, 8), start=2):
            scope = "e%d" % idx
            conv = conv2d(lrelu(encoded[-1]), feature * mult, name=scope)
            encoded.append(batch_norm(conv, scope, decay))

        # num[i] is the spatial size decoder layer d<i> must output:
        # num[8] is the full image, and each earlier entry is the ceil-halved
        # size produced by a SAME-padded stride-2 conv. `//` keeps these as
        # ints on Python 3 (plain `/` produced floats and invalid shapes
        # there); on Python 2 it is identical to the old integer `/`.
        num = [0] * 9
        size = conf.fig_size
        for i in range(8, 0, -1):
            num[i] = size
            size = (size + 1) // 2

        # Decoder d1..d7: relu -> deconv -> batch norm (-> dropout for d1..d3),
        # then concat with the encoder output of the same size (e7 .. e1).
        net = encoded[-1]
        for idx, mult in enumerate((8, 8, 8, 8, 4, 2, 1), start=1):
            scope = "d%d" % idx
            up = deconv2d(tf.nn.relu(net), [num[idx], num[idx], feature * mult], name=scope)
            up = batch_norm(up, scope, decay)
            if idx <= 3:
                # Second positional arg is keep_prob. Dropout is always active here.
                up = tf.nn.dropout(up, 0.5)
            net = tf.concat([up, encoded[7 - idx]], 3)

        # Final upsampling back to full resolution with the image's channel count.
        d8 = deconv2d(tf.nn.relu(net), [num[8], num[8], conf.fig_channel], name="d8")
        return tf.nn.tanh(d8)
def discriminator(self, original, corrupted, reuse):
    """Score an (image, corrupted image) pair with one logit per example.

    The two inputs are concatenated on their last axis and passed through
    four stride-2 conv layers (h0 without batch norm; h1..h3 with it), all
    followed by leaky ReLU, then flattened into a single linear output unit.

    Args:
        original: candidate image batch (real or generated).
        corrupted: the conditioning corrupted image batch.
        reuse: passed to the "disc" variable scope; False on the first call
            to create the variables, True afterwards to share them.

    Returns:
        Unnormalized logits of shape [batch, 1].
    """
    conf = self.config
    last_axis = len(original.get_shape()) - 1
    with tf.variable_scope("disc", reuse=reuse):
        net = tf.concat([original, corrupted], last_axis)
        base = conf.conv_channel_base
        net = lrelu(conv2d(net, base, name="h0"))
        for layer, mult in (("h1", 2), ("h2", 4), ("h3", 8)):
            conv = conv2d(net, base * mult, name=layer)
            net = lrelu(batch_norm(conv, layer, conf.batch_norm_decay))
        flat_dim = net.shape[1] * net.shape[2] * net.shape[3]
        return linear(tf.reshape(net, [-1, flat_dim]), 1, "linear")
def batch_norm(x, scope, decay, is_training=True):
    """Apply tf.contrib batch normalization to ``x`` under variable scope ``scope``.

    Args:
        x: input tensor.
        scope: variable-scope name for this layer's BN parameters.
        decay: decay of the moving mean / variance.
        is_training: True (the previous, hard-wired behaviour) normalizes with
            the statistics of the current batch and updates the moving
            averages. False normalizes with the stored moving averages, so
            outputs no longer depend on the size or contents of the batch.
            A boolean tensor (e.g. a placeholder) is also accepted.

    Returns:
        The normalized tensor, with a learned scale (scale=True) and offset.
    """
    # updates_collections=None makes the moving-average update run in place
    # with this op, so the train op needs no UPDATE_OPS control dependency.
    return tf.contrib.layers.batch_norm(
        x,
        decay=decay,
        updates_collections=None,
        epsilon=1e-5,
        scale=True,
        is_training=is_training,
        scope=scope)
def conv2d(input, output_dim, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    """2-D convolution with SAME padding plus a bias, in variable scope ``name``.

    Creates a ``weight`` kernel of shape [k_h, k_w, in_channels, output_dim]
    (truncated-normal init) and a zero-initialized ``bias``. With the default
    stride of 2 the spatial size is ceil-halved.
    """
    with tf.variable_scope(name):
        kernel_shape = [k_h, k_w, input.get_shape()[-1], output_dim]
        kernel = tf.get_variable(
            'weight', kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        offset = tf.get_variable(
            'bias', [output_dim], initializer=tf.constant_initializer(0.0))
        step = [1, d_h, d_w, 1]
        convolved = tf.nn.conv2d(input, kernel, strides=step, padding='SAME')
        return tf.nn.bias_add(convolved, offset)
def deconv2d(input, output_shape, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="deconv2d"):
    """Transposed convolution plus bias, in variable scope ``name``.

    Args:
        input: input tensor; its batch size is read dynamically, so any
            batch size can be fed.
        output_shape: [height, width, channels] of the result, excluding the
            batch dimension.

    Returns:
        Tensor of shape [batch, height, width, channels].
    """
    with tf.variable_scope(name):
        batch = tf.shape(input)[0]
        out_channels = output_shape[-1]
        # conv2d_transpose expects the filter as [h, w, out_channels, in_channels].
        kernel = tf.get_variable(
            'weight', [k_h, k_w, out_channels, input.get_shape()[-1]],
            initializer=tf.random_normal_initializer(stddev=stddev))
        offset = tf.get_variable(
            'bias', [out_channels], initializer=tf.constant_initializer(0.0))
        full_shape = tf.stack([batch, output_shape[0], output_shape[1], output_shape[2]])
        upsampled = tf.nn.conv2d_transpose(
            input, kernel, output_shape=full_shape, strides=[1, d_h, d_w, 1])
        return tf.nn.bias_add(upsampled, offset)
def lrelu(x, leak=0.2):
    """Leaky ReLU computed as the element-wise ``max(x, leak * x)``.

    For 0 <= leak <= 1 this is ``x`` where x >= 0 and ``leak * x`` elsewhere.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
def linear(input, output_size, name="Linear", stddev=0.02, bias_start=0.0):
    """Fully connected layer ``input @ weight + bias`` in variable scope ``name``.

    ``input`` must be rank 2 with a statically known second dimension; that
    dimension sets the number of rows of ``weight``.
    """
    in_features = input.get_shape().as_list()[1]
    with tf.variable_scope(name):
        w = tf.get_variable("weight", [in_features, output_size], tf.float32,
                            tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable("bias", [output_size],
                            initializer=tf.constant_initializer(bias_start))
        product = tf.matmul(input, w)
        return product + b
My training:
# Separate Adam optimizers; each only updates its own network's variables.
lr = conf.learning_rate
d_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(model.d_loss, var_list=model.d_vars)
g_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(model.g_loss, var_list=model.g_vars)
with tf.Session(config=configProto) as sess:
    # NOTE(review): no variable initialization / checkpoint restore is visible
    # in this excerpt; presumably it happens in omitted code -- confirm.
    for epoch in xrange(0, conf.max_epoch):
        batch_index = 0  # not used by the visible loop body
        for original, corrupted in data.iterate_batches_train():
            feed_dict = {
                model.original: preprocess(original),
                model.corrupted: preprocess(corrupted),
            }
            # Two discriminator updates per generator update, all on the same batch.
            for _ in range(2):
                sess.run([d_opt], feed_dict=feed_dict)
            sess.run([g_opt], feed_dict=feed_dict)
The default configuration for batch size one:
# Image geometry: square images of fig_size x fig_size with fig_channel channels.
self.fig_size = 424
self.fig_channel = 1
# Network width and batch-norm moving-average decay.
self.conv_channel_base = 64
self.batch_norm_decay = 0.9
# Training schedule and loss weighting (weight of the L1 term in g_loss).
self.batch_size = 1
self.max_epoch = 20
self.learning_rate = 2e-4
self.l1_lambda = 100
I appreciate any insight you might have...
Upvotes: 3
Views: 1348
Reputation: 121
I think this is because of your use of batch normalization.
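With batch size = 1, BN only normalizes each image over its own spatial positions, so it behaves like instance normalization rather than true batch normalization.
With a small batch size greater than 1, the per-batch statistics are noisy and not representative of the whole dataset, so training can become unstable.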
Can you try training with batch size = 2 and no BN?
Upvotes: 1