Reputation: 708
In TensorFlow, I am trying to build a model to perform image super-resolution (i.e. a regression task) and analyze the results using TensorBoard. During training, I found that the mean squared error (MSE) bounces between 100 and 200 most of the time (even from the beginning) and never converges. I was hoping to add the following variables to tf.summary and analyze what is causing the problem:
graph_loss = get_graph_mean_square_error()
tf.summary.scalar('graph_loss', graph_loss)
# L2 penalty over all weight tensors (weights, regularization_param, clip_value and global_step are defined elsewhere)
regularization_loss = tf.add_n([tf.nn.l2_loss(weight) for weight in weights]) * regularization_param
tf.summary.scalar('reg_loss', regularization_loss)
tf.summary.scalar('overall_loss', regularization_loss + graph_loss)
for index in range(len(weights)):
    tf.summary.histogram("weight[%02d]" % index, weights[index])
optimizer = tf.train.AdamOptimizer()
# compute gradients of the overall loss, then clip each gradient value to [-clip_value, clip_value]
grad_and_vars = optimizer.compute_gradients(regularization_loss + graph_loss)
capped_grad_and_vars = [(tf.clip_by_value(grad, -clip_value, clip_value), var) for grad, var in grad_and_vars if grad is not None]
train_optimizer = optimizer.apply_gradients(capped_grad_and_vars, global_step)
for grad, var in grad_and_vars:
    if grad is not None:
        tf.summary.histogram(var.name + '/gradient', grad)
for grad, var in capped_grad_and_vars:
    tf.summary.histogram(var.name + '/capped_gradient', grad)
The model is a ResNet with skip connections, which contains several repeated [convolution -> batch normalization -> ReLU] layers. In the Distributions tab, I can see several graphs added with the following pattern:
There are a few things I was looking at and would like someone to shed some light on:
Using L2 loss for regularization
The value of regularization_param was set to 0.0001, and the reg_loss graph shows that it increases from 1.5 (roughly logarithmically) and converges around 3.5. In my case, graph_loss is between 100 and 200 while reg_loss is between 1.5 and 3.5.
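For reference, one way I could make this comparison visible in TensorBoard is to also log what fraction of the overall loss comes from the penalty (a small sketch reusing the tensors defined above):

# fraction of the overall loss contributed by the L2 penalty
reg_fraction = regularization_loss / (regularization_loss + graph_loss)
tf.summary.scalar('reg_fraction_of_overall_loss', reg_fraction)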
- Is the trend of the reg_loss graph what we are looking for (like a logarithmically increasing function)?
- Is reg_loss too small to penalize the model (100-200 vs. 1.5-3.5)?
- How do I know if I chose regularization_param correctly?
Addressing the vanishing gradients problem
I was thinking the MSE bouncing from the beginning to the end of training could be due to the vanishing gradients problem. I was hoping to use several techniques, such as a ResNet with skip connections, batch normalization, and gradient clipping (clip_by_value at 0.05), to address it. I am not too sure how to read the graph, but it looks to me as if the weights do not change for the first 22 layers over the first 20K steps, like this (I am not familiar with TensorBoard, so please correct me if I read/interpret it incorrectly):
I have split the training into several runs, restoring the checkpoint from the previous run each time. Here is the graph after 66K steps for the last few layers:
You can see that in the first 20K steps the weights still change on some layers, such as weight_36_ and weight_37_ in orange. However, after 50K steps, all the weights look flat, like weight_36_ (very thin) and weight_39_ (with little thickness) in green.
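To make this easier to read than the histograms, I could also log each variable's gradient norm as a scalar (a sketch reusing capped_grad_and_vars from above):

# scalar gradient norms are easier to compare across layers than histograms
for grad, var in capped_grad_and_vars:
    tf.summary.scalar(var.name + '/capped_gradient_norm', tf.norm(grad))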
Then I looked into the batch normalization graphs (note that capped_gradient is clip_by_value at 0.05), and it looks like there are some changes, as shown below:
Any other suggestions are welcome :)
Upvotes: 6
Views: 1444
Reputation: 376
- Is the trend of the reg_loss graph what we are looking for (like a logarithmically increasing function)?

Yes, it looks okay.

- Is reg_loss too small to penalize the model (100-200 vs. 1.5-3.5)?
- How do I know if I chose regularization_param correctly?
First, I would suggest varying the learning rate from 0.001 to 0.1 (which is the very first thing to investigate for a gradient clipping problem) and observing whether the average MSE decreases, in order to choose the best learning rate without reg_loss. Then you can add the regularization back by fine-tuning reg_loss.
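A rough sketch of such a sweep (build_model and run_training are hypothetical helpers that rebuild your graph and return the average MSE after a fixed number of steps):

for lr in [0.001, 0.003, 0.01, 0.03, 0.1]:
    tf.reset_default_graph()
    graph_loss = build_model()  # hypothetical helper: rebuilds the model, returns the MSE tensor
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(graph_loss)  # no reg_loss yet
    avg_mse = run_training(train_op, graph_loss)  # hypothetical helper: trains and returns the average MSE
    print('learning rate %g -> average MSE %.2f' % (lr, avg_mse))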
- Please, can someone explain whether the above graph looks correct? (I do not understand why there are some good values after each batch normalization, yet the weights do not seem to change)
- Which direction should I look in to address the MSE bouncing problem from the beginning to the end of training?
Please double-check whether you are taking the average MSE over each epoch. It can be normal to observe bouncing from batch to batch within an epoch, but if you average the MSE over each epoch, you may see it go down gradually.
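For example, something like this (assuming train_batches yields the feed dicts for one epoch, and sess, train_optimizer, graph_loss are the session and ops from your question):

batch_mses = []
for feed_dict in train_batches:
    # one training step; also fetch the per-batch MSE
    _, batch_mse = sess.run([train_optimizer, graph_loss], feed_dict=feed_dict)
    batch_mses.append(batch_mse)
print('epoch average MSE: %.2f' % (sum(batch_mses) / len(batch_mses)))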
Upvotes: 1
Reputation: 66
Things to Try:
Remove gradient clipping: You are clipping the gradient values at 0.05. I think an update of (0.05 * learning rate) yields very small weight updates, which is why most of the layers are not learning anything. If you clip the gradients of the last layer (the first from the output) to 0.05, then very small gradient values propagate back to the previous layer, and multiplication with the local gradients yields even smaller values. That is probably why you see only the last few layers learn something.
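If you still want some protection against exploding gradients, clipping by global norm scales all gradients together instead of capping each value at 0.05. A sketch reusing optimizer, grad_and_vars and global_step from your question (the 5.0 threshold is just an example):

grads_and_vars = [(g, v) for g, v in grad_and_vars if g is not None]
grads, variables = zip(*grads_and_vars)
# scale the whole set of gradients only if their global norm exceeds the threshold
clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)
train_optimizer = optimizer.apply_gradients(list(zip(clipped_grads, variables)), global_step)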
Remove L2 regularization: Try removing the regularization; if removing it solves the bouncing MSE problem, then you should tune the regularization parameter very carefully.
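For example (reusing graph_loss and global_step from your question):

# train on the data term only, without the L2 penalty, to see if the MSE still bounces
train_optimizer_no_reg = tf.train.AdamOptimizer().minimize(graph_loss, global_step=global_step)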
Upvotes: 2