Reputation: 926
I have a problem resuming training after saving my model. My loss decreases from 6 to 3, for example, and at that point I save the model. When I restore it and continue training, the loss restarts from 6, so it seems that the restoration doesn't really work. I don't understand why, because when I print the weights they appear to be loaded properly. I use an Adam optimizer. Thanks in advance. Here is my code:
batch_size = self.batch_size
num_classes = self.num_classes
n_hidden = 50 #700
n_layers = 1 #3
truncated_backprop = self.seq_len
dropout = 0.3
learning_rate = 0.001
epochs = 200
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [batch_size, truncated_backprop], name='x')
    y = tf.placeholder(tf.int32, [batch_size, truncated_backprop], name='y')
with tf.name_scope('weights'):
    W = tf.Variable(np.random.rand(n_hidden, num_classes), dtype=tf.float32)
    b = tf.Variable(np.random.rand(1, num_classes), dtype=tf.float32)
inputs_series = tf.split(x, truncated_backprop, 1)
labels_series = tf.unstack(y, axis=1)
with tf.name_scope('LSTM'):
    cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, state_is_tuple=True)
    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout)
    cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layers)
    states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, \
        dtype=tf.float32)
logits_series = [tf.matmul(state, W) + b for state in states_series]
prediction_series = [tf.nn.softmax(logits) for logits in logits_series]
losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) \
    for logits, labels, in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
tf.summary.scalar('total_loss', total_loss)
summary_op = tf.summary.merge_all()
loss_list = []
writer = tf.summary.FileWriter('tf_logs', graph=tf.get_default_graph())
all_saver = tf.train.Saver()
with tf.Session() as sess:
    #sess.run(tf.global_variables_initializer())
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph('./models/tf_models/rnn_model.meta')
    saver.restore(sess, './models/tf_models/rnn_model')
    for epoch_idx in range(epochs):
        xx, yy = next(self.get_batch)
        batch_count = len(self.D.chars) // batch_size // truncated_backprop
        for batch_idx in range(batch_count):
            batchX, batchY = next(self.get_batch)
            summ, _total_loss, _train_step, _current_state, _prediction_series = sess.run(\
                [summary_op, total_loss, train_step, current_state, prediction_series],
                feed_dict = {
                    x : batchX,
                    y : batchY
                })
            loss_list.append(_total_loss)
            writer.add_summary(summ, epoch_idx * batch_count + batch_idx)
            if batch_idx % 5 == 0:
                print('Step', batch_idx, 'Batch_loss', _total_loss)
            if batch_idx % 50 == 0:
                all_saver.save(sess, 'models/tf_models/rnn_model')
        if epoch_idx % 5 == 0:
            print('Epoch', epoch_idx, 'Last_loss', loss_list[-1])
Upvotes: 4
Views: 790
Reputation: 926
My problem was a bug in how I built the labels: they were changing between two runs, so the restored weights no longer matched them. It works now. Thank you for the help.
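The post doesn't show the actual fix, so the following is only a guess at one common cause: a character-to-id mapping built by iterating over an unordered `set`, whose order for strings can differ from one Python process to the next. A minimal sketch that builds the mapping in sorted order and stores it next to the checkpoint; `build_vocab`, `save_vocab`, `load_vocab`, `encode` and the file name `vocab.json` are hypothetical names, not taken from the question:

```python
import json
import os
import tempfile


def build_vocab(text):
    """Map each distinct character to an integer id, in sorted (stable) order."""
    return {ch: idx for idx, ch in enumerate(sorted(set(text)))}


def save_vocab(vocab, path):
    """Write the mapping to disk so a later run can reuse the same ids."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(vocab, f, ensure_ascii=False)


def load_vocab(path):
    """Read back a mapping written by save_vocab."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def encode(text, vocab):
    """Turn a string into the list of label ids the model is trained on."""
    return [vocab[ch] for ch in text]


# Demo: the ids are the same before and after a save/load round trip.
vocab = build_vocab("hello world")
vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.json")  # hypothetical location
save_vocab(vocab, vocab_path)
restored = load_vocab(vocab_path)
print(encode("hello", vocab) == encode("hello", restored))
```

When resuming training, loading the saved mapping instead of rebuilding it keeps the labels aligned with the restored weights.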
Upvotes: 0
Reputation: 21
I had the same problem. In my case the model was being restored correctly, but the loss kept starting really high again and again. The cause was that my batch retrieval was not random: I had three classes, A, B and C, and my data was being fed as all of A, then all of B, then all of C. I don't know if that is your problem, but you should make sure that every batch you give to your model contains all of your classes, so in your case each batch should have about batch_size/num_classes inputs per class. I changed that and everything worked perfectly :)
Check that you are feeding your model correctly.
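A minimal sketch of shuffled batching, assuming the samples and labels fit in memory as two parallel lists. `shuffled_batches` is a hypothetical helper, not part of the code in the question, and shuffling only makes mixed-class batches likely rather than guaranteeing an exact per-class balance:

```python
import random


def shuffled_batches(samples, labels, batch_size, seed=None):
    """Yield (samples, labels) batches drawn in a random order."""
    rng = random.Random(seed)
    order = list(range(len(samples)))
    rng.shuffle(order)
    # Drop the last partial batch so every batch has the same size.
    for start in range(0, len(order) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        yield [samples[i] for i in idx], [labels[i] for i in idx]


# Toy data ordered by class, as in the A, then B, then C case described above.
samples = list(range(12))
labels = ["A"] * 4 + ["B"] * 4 + ["C"] * 4
batches = list(shuffled_batches(samples, labels, batch_size=4, seed=0))
for batch_samples, batch_labels in batches:
    print(batch_samples, batch_labels)
```

Calling the generator again at the start of each epoch, with a different seed or none, gives a new order every time.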
Upvotes: 1