Reputation: 6538
I have trained a binary classifier model. The model class contains self.cost, self.initial_state, self.final_state and self.logits attributes. It is saved simply with tf.train.Saver:
saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
saver.save(session, 'model.ckpt')
After the model was trained I load it as:
with tf.variable_scope("Model", reuse=False):
    model = MODEL(config, is_training=False)
with tf.Session() as session:
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(session, 'model.ckpt')
However, my model.run function returns the cross-entropy loss, which is the last op in the graph. I don't need the loss; I need the model predictions for each batch element:
logits = tf.sigmoid(tf.nn.xw_plus_b(last_layer, self.output_w, self.output_b))
where last_layer is an 800x1 matrix, which I later reshape into a 32x25x1 (batch_size, sequence_length, 1) matrix. It is this matrix that contains the model's prediction values in the [0, 1] range.
So, how can I use this model to make a prediction for a single-element 1x1x1 matrix?
Upvotes: 0
Views: 88
Reputation: 32061
Add the ops needed to compute accuracy, something like the following (copied from the closest model I had at hand).
self.logits_flat = tf.argmax(logits, axis=1, output_type=tf.int32)
labels_flat = tf.argmax(labels, axis=1, output_type=tf.int32)
accuracy = tf.cast(tf.equal(self.logits_flat, labels_flat), tf.float32, name='accuracy')
Now when you run the model (either during test or training time) add accuracy to the sess.run call as:
sess.run([train_op, accuracy], feed_dict=...)
or
sess.run([accuracy, logits], feed_dict=...)
All you're doing when you call sess.run is telling TensorFlow to compute the value of whatever you ask for. You need to pass in any data it needs to perform those computations. TensorFlow is lazy: it won't perform any computation that isn't necessary to produce the results you request. E.g., if you run the second version of sess.run listed above, the optimizer will not run and hence your weights will not be updated.
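As a rough illustration of that laziness, here is a minimal sketch with a hypothetical one-layer stand-in for the model. The placeholder names, the 4-feature input shape, and the squared-error loss are all assumptions made for the example, not the asker's code. It uses the tf.compat.v1 API so that it should also run under TensorFlow 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.reset_default_graph()

# Hypothetical stand-in for the real model: one sigmoid output unit.
inputs = tf.placeholder(tf.float32, shape=[None, 4], name="inputs")
labels = tf.placeholder(tf.float32, shape=[None, 1], name="labels")
output_w = tf.get_variable("output_w", shape=[4, 1])
output_b = tf.get_variable("output_b", shape=[1],
                           initializer=tf.zeros_initializer())
predictions = tf.sigmoid(tf.nn.xw_plus_b(inputs, output_w, output_b))

# Assumed loss and optimizer, only so that a train_op exists in the graph.
loss = tf.reduce_mean(tf.square(predictions - labels))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_before = sess.run(output_w)

    # Fetch only the predictions for a batch of one element.
    # No labels are fed, because nothing requested depends on them.
    single_input = np.ones((1, 4), dtype=np.float32)
    single_prediction = sess.run(predictions,
                                 feed_dict={inputs: single_input})

    # train_op was not among the fetches, so the weights are unchanged.
    weights_after = sess.run(output_w)
```

Because the batch dimension of the placeholder is None in this sketch, the same graph accepts a batch of one element as well as a full batch. Whether that carries over to the real model depends on how its placeholders and state tensors were declared.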
Note that you can add these ops after the network has been trained: none of them create any variables, so they won't affect the save/restore process.
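The save/restore point can be sketched roughly as follows. The build_model helper, the input shape, and the 0.5 threshold used for predicted_class are hypothetical and only stand in for the real MODEL class. The checkpoint is written to a temporary directory rather than to 'model.ckpt'.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.reset_default_graph()


def build_model():
    """Hypothetical stand-in for MODEL: returns the input and sigmoid output."""
    with tf.variable_scope("Model"):
        x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
        w = tf.get_variable("output_w", shape=[4, 1])
        b = tf.get_variable("output_b", shape=[1],
                            initializer=tf.zeros_initializer())
        return x, tf.sigmoid(tf.nn.xw_plus_b(x, w, b))


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
sample = np.ones((1, 4), dtype=np.float32)

# First graph: stands in for training, then saves the variables.
x, probs = build_model()
saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    trained_probs = sess.run(probs, feed_dict={x: sample})
    saver.save(sess, ckpt_path)

# Second graph: rebuild the model and add an op that creates no variables.
tf.reset_default_graph()
x, probs = build_model()
predicted_class = tf.cast(probs > 0.5, tf.int32, name="predicted_class")

# The saver still sees exactly the same variables as before.
saver = tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    restored_probs, classes = sess.run([probs, predicted_class],
                                       feed_dict={x: sample})
```

The restore works here because the new op only consumes existing tensors. An op that created its own variables (for example an extra dense layer) would have no values in the checkpoint and would need to be initialized or excluded from the Saver.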
Upvotes: 1