Vectoria

Reputation: 25

Using a saved model for prediction in TensorFlow

I use this code to restore my model, but I don't know how to make predictions after restoring it. Which function should I use? I'm a beginner in TensorFlow and have no idea which parameters or functions get saved.

Here is my code for restoring the model from the meta graph:

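import tensorflow as tf
from sklearn.datasets import load_svmlight_file

sess = tf.Session()
# Rebuild the graph structure from the .meta file, then load the saved weights
saver = tf.train.import_meta_graph("/home/MachineLearning/model.ckpt.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))
print("Model restored with success ")
x_predict, y_predict = load_svmlight_file('/MachineLearning/to_predict.csv')
x_predict = x_predict.toarray()  # sparse matrix -> dense numpy array
sess.run([], feed_dict)  # I don't know how to use the predict function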

These are the results:

$python predict.py
Model restored with success 
Traceback (most recent call last):
  File "predict.py", line 23, in <module>
    sess.run([] ,feed_dict )
NameError: name 'feed_dict' is not defined

Upvotes: 2

Views: 4074

Answers (1)

David Parks

Reputation: 32111

You're almost there. TensorFlow is simply a math library. Your graph is a collection of math operations along with their dependencies (i.e. a graph, specifically a directed acyclic graph, or DAG).

When you loaded the graph and its associated variables (weights), you loaded all the definitions. Now you need to ask TensorFlow to compute some value in the graph. There are many values it could compute; the one you want is often named logits (a typical name for the output layer of a neural network). Note, though, that it could be named anything (especially if this isn't a neural network model), so you need to understand the model. You might also want to compute an operation named accuracy, which computes the accuracy on a particular batch of inputs (again, this depends on your model).

Note that you will need to provide TensorFlow with whatever it needs to perform these computations. There is generally a placeholder where you pass in your data. During training there is also a placeholder for your labels, which you don't need to feed for prediction, because logits does not depend on it. Computing accuracy, however, compares predictions against the true labels, so it does require feeding the labels placeholder as well.
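To make the "only compute what the fetched value depends on" idea concrete, here is a toy sketch in plain Python. It is NOT how TensorFlow is implemented, and the node names (x, labels, w, logits, accuracy) and weights are hypothetical; it only mimics the behavior described above: fetching logits never touches the labels placeholder, while fetching accuracy fails unless labels are fed.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Toy stand-in for a dataflow graph (NOT how TensorFlow is implemented).
# Each node is either a placeholder (None, must be fed) or a function that
# computes its value from other nodes via env(name).
Node = Optional[Callable[[Callable[[str], Any]], Any]]

graph: Dict[str, Node] = {
    "x": None,       # placeholder: input features
    "labels": None,  # placeholder: ground truth, only needed for accuracy
    "w": lambda env: [2.0, -1.0],  # hypothetical "trained" weights
    "logits": lambda env: sum(xi * wi for xi, wi in zip(env("x"), env("w"))),
    "accuracy": lambda env: float((env("logits") > 0) == env("labels")),
}


def run(graph: Dict[str, Node], fetch: str, feed: Dict[str, Any]) -> Tuple[Any, List[str]]:
    """Evaluate `fetch`, computing only the nodes it (transitively) depends on."""
    cache: Dict[str, Any] = dict(feed)
    evaluated: List[str] = []

    def value(name: str) -> Any:
        if name not in cache:
            fn = graph[name]
            if fn is None:
                raise KeyError(f"placeholder {name!r} must be fed")
            evaluated.append(name)
            cache[name] = fn(value)
        return cache[name]

    return value(fetch), evaluated


print(run(graph, "logits", {"x": [1.0, 3.0]}))  # (-1.0, ['logits', 'w'])
```

Fetching "accuracy" with only {"x": ...} in the feed raises KeyError for the labels placeholder, which is the toy analogue of TensorFlow complaining that a placeholder must be fed.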

But you will need to get references to these operations (logits and accuracy) and placeholders (x is a typical name). Since you loaded your graph from disk, you don't have those references in your Python code (note that an alternative way of loading the model is to re-run the code that builds the model, which gives you easy access to the references you need).

To get the right references, you can look them up by name. Here's how you would get a list of all the operations:

List of tensor names in graph in Tensorflow

Then to get a specific OP (operation) by name:

How to get a tensorflow op by name?
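As a minimal sketch of what those lookups return, the snippet below builds a tiny graph, lists its operation names, and fetches the same node once as an operation and once as a tensor. It assumes TensorFlow 2.x running the TF1-style API through tf.compat.v1 (under TF 1.x the same calls exist directly on tf), and the names x, w and logits are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF1-style graph mode when running under TF 2.x

graph = tf.Graph()
with graph.as_default():
    # Hypothetical tiny model; the names "x", "w" and "logits" are assumptions.
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf1.get_variable("w", shape=[3, 2], initializer=tf1.zeros_initializer())
    logits = tf.matmul(x, w, name="logits")

# Names of every operation in the graph.
op_names = [op.name for op in graph.get_operations()]
print(op_names)

# An operation is looked up by its bare name...
logits_op = graph.get_operation_by_name("logits")
# ...whereas "logits:0" names the tensor that is that operation's first output.
logits_tensor = graph.get_tensor_by_name("logits:0")
print(logits_op.name, logits_tensor.name, logits_tensor.op.name == logits_op.name)
```

The exact list of operation names depends on the TensorFlow version (variables create several helper ops), so inspect it rather than assuming it.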

So you'll have something like this:

graph = tf.get_default_graph()
# "name:0" refers to the first output tensor of the operation called "name"
logits = graph.get_tensor_by_name("logits:0")
x = graph.get_tensor_by_name("x:0")
accuracy = graph.get_tensor_by_name("accuracy:0")

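Note that the :0 is an output index, not a de-duplication suffix: an operation named logits can produce several output tensors, and logits:0 refers to its first one. get_operation_by_name takes the bare operation name ("logits") and returns the operation itself, while get_tensor_by_name takes the "name:index" form and returns a tensor, which is what you feed data into and fetch values from. Now you have all the references you need, and you can use sess.run to perform a specific computation, providing the input data and the tensors you'd like to have computed: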

# Only the input placeholder needs to be fed to compute logits
predictions = sess.run(logits, feed_dict={x: your_input_data_in_numpy_format})
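Putting the pieces together, here is a minimal end-to-end sketch: a stand-in "training" script saves a checkpoint, and a separate prediction step restores it with import_meta_graph and looks tensors up by name. It assumes TensorFlow 2.x with the tf.compat.v1 API; the tensor names x and logits, the fixed weights, and the temporary checkpoint directory are all hypothetical stand-ins for your real model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF1-style graph mode under TF 2.x

ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "model.ckpt")

# --- Training side (stand-in): build, "train", save. ---
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.get_variable("w", initializer=tf.constant([[1.0, 0.0], [0.0, 2.0]]))
    logits = tf.matmul(x, w, name="logits")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_path)  # writes model.ckpt.meta, .index, .data-*

# --- Prediction side: only the files on disk are available. ---
with tf.Graph().as_default() as graph:
    with tf1.Session() as sess:
        saver = tf1.train.import_meta_graph(ckpt_path + ".meta")
        saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        # Look up the tensors (note the ":0") by name.
        x = graph.get_tensor_by_name("x:0")
        logits = graph.get_tensor_by_name("logits:0")
        x_predict = np.array([[3.0, 4.0]], dtype=np.float32)
        predictions = sess.run(logits, feed_dict={x: x_predict})

print(predictions)  # [[3. 8.]]
```

In the question's script, x_predict from load_svmlight_file(...).toarray() would take the place of the hard-coded array, provided its shape matches the x placeholder.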

The names of these elements will vary in your implementation; I've used the most common names. If they weren't given meaningful names, it'll be hard to identify them, and you'll need to look through the original code that produced the graph. In fact, if they weren't named properly, looking them up by name is so painful that it's probably better to just re-run the code that produced the original graph rather than import the meta graph. Notice that saver.restore only restores the actual data (the variable values); import_meta_graph is the optional piece, which can be replaced by simply rebuilding the graph programmatically.
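A minimal sketch of that alternative, assuming TensorFlow 2.x with tf.compat.v1: a hypothetical build_model() function is called both when saving and when predicting, so the Python references come straight from the code and saver.restore only has to load the variable values. The variable values set with assign are a stand-in for real training.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF1-style graph mode under TF 2.x


def build_model():
    """Hypothetical model-building code, run both for training and prediction."""
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.get_variable("w", shape=[2, 2], initializer=tf1.zeros_initializer())
    logits = tf.matmul(x, w, name="logits")
    return x, w, logits


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training side: build, set non-trivial weights (stand-in for training), save.
with tf.Graph().as_default():
    x, w, logits = build_model()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(w.assign([[1.0, 0.0], [0.0, 2.0]]))
        saver.save(sess, ckpt_path)

# Prediction side: rebuild the graph instead of calling import_meta_graph;
# saver.restore then only loads the saved variable values.
with tf.Graph().as_default():
    x, w, logits = build_model()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        saver.restore(sess, ckpt_path)  # no initializer needed after restore
        predictions = sess.run(logits, feed_dict={x: [[3.0, 4.0]]})

print(predictions)  # [[3. 8.]]
```

This only works if the rebuilt graph creates variables with the same names and shapes as the saved ones, which is automatic when the same function builds both.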

Upvotes: 3
