wheresmycookie

Reputation: 763

Modifying a TensorFlow graph to output an intermediate value after training

I'm really new to TF, and so this is my disclaimer that what I'm asking might not make much sense. I'd appreciate any corrections to my understanding. I'm happy to provide more code / information if necessary.

I'm working from the following tutorial: https://www.oreilly.com/learning/perform-sentiment-analysis-with-lstms-using-tensorflow.

I've added name scopes (tf.name_scope) to the variables, placeholders, etc. to help me understand what's going on. Instead of posting all of the code, I thought an image of the graph might be enough for this question:

[Image: the model's computation graph]

There are several things about this graph that I still don't understand, so as a side note: if anybody knows good resources for building intuition about these graphs, I'd appreciate the pointers.

My understanding

It looks like the graph currently accepts a feed of input_data and labels in order to calculate the error during training. I believe that "accuracy" is currently the output (since nothing else seems to take it as an input?). It makes sense to me that the cost takes the current predictions and the ground-truth labels as its inputs.
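One way to check this reading of the graph is to treat it as a plain dependency map and walk upstream from whichever node you want to fetch. The sketch below uses a hypothetical, heavily simplified version of the tutorial's graph (the `lstm` and `optimizer` node names and the exact edges are assumptions, not read off the real graph); it lists which placeholders each node needs and which nodes are sinks, i.e. consumed by no other node.

```python
# Hypothetical, simplified picture of the tutorial graph: each node maps to the
# nodes it consumes. The real LSTM graph contains many more intermediate ops.
GRAPH = {
    "input_data": [],                      # placeholder
    "labels": [],                          # placeholder
    "lstm": ["input_data"],                # assumed stand-in for the LSTM layers
    "predictions": ["lstm"],
    "cost": ["predictions", "labels"],
    "accuracy": ["predictions", "labels"],
    "optimizer": ["cost"],                 # assumed name for the training op
}
PLACEHOLDERS = {"input_data", "labels"}


def required_feeds(fetch, graph=GRAPH):
    """Return the set of placeholders that must be fed to compute `fetch`."""
    needed, stack, seen = set(), [fetch], set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node in PLACEHOLDERS:
            needed.add(node)
        stack.extend(graph[node])  # keep walking upstream
    return needed


# Sinks are nodes that no other node takes as an input.
consumed = {src for inputs in GRAPH.values() for src in inputs}
sinks = sorted(node for node in GRAPH if node not in consumed)

print(sorted(required_feeds("predictions")))  # ['input_data']
print(sorted(required_feeds("accuracy")))     # ['input_data', 'labels']
print(sinks)                                  # ['accuracy', 'optimizer']
```

In a TensorFlow graph there is no single designated output: a node like `accuracy` is simply one that nothing else consumes, and any node can be fetched with sess.run.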

Because this comes from a tutorial, training works well, although I couldn't have built it myself just yet. I'll set that aside for a moment while I try to build some intuition here.

My question

I'm now interested in calling sess.run() on my graph with only input_data and viewing the results of "predictions". That seems reasonable: I don't even have labels when I'm using this model in, say, a production system. The whole point is to get back predictions.

What steps might I take so that I can call sess.run and get back the new desired output? I'd still need to be able to train the model, though. What process might I use to train with both placeholders and then reduce it to one for predicting?

Upvotes: 0

Views: 182

Answers (1)

zephyrus

Reputation: 1266

The argument of sess.run (the "fetches") is always a reference to a node in the graph, i.e. in what you have provided an image of: a tensor or an operation, or a list or dict of them.

TensorFlow only computes the part of the graph that lies upstream of the node you fetch, so you only need to feed the placeholders that this node actually depends on, not every input in the graph. Your question appears to be how to get predictions from the network without providing the truth labels (what you want the network to learn during training). This is the quintessential "testing" (inference) scenario.
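The "only upstream values" behaviour can be illustrated with a toy evaluator. This is a minimal pure-Python sketch, not TensorFlow's actual implementation: `Node`, `run` and the one-weight model are hypothetical stand-ins, and the toy raises a KeyError where TensorFlow 1.x would instead raise an InvalidArgumentError about an unfed placeholder.

```python
class Node:
    """A toy graph node; a node with fn=None plays the role of a placeholder."""

    def __init__(self, name, fn=None, inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)


def run(fetch, feed_dict):
    """Evaluate `fetch`, visiting only the nodes it (transitively) depends on."""
    cache = {}

    def evaluate(node):
        if node in cache:
            return cache[node]
        if node.fn is None:  # placeholder: its value must come from feed_dict
            if node not in feed_dict:
                raise KeyError(f"You must feed a value for placeholder {node.name!r}")
            value = feed_dict[node]
        else:
            value = node.fn(*(evaluate(n) for n in node.inputs))
        cache[node] = value
        return value

    return evaluate(fetch)


# Hypothetical stand-in for the tutorial's model: a single "trained" weight.
input_data = Node("input_data")
labels = Node("labels")
weight = Node("weight", fn=lambda: 0.5)
predictions = Node(
    "predictions",
    fn=lambda xs, w: [1 if w * x > 0 else 0 for x in xs],
    inputs=[input_data, weight],
)
accuracy = Node(
    "accuracy",
    fn=lambda preds, ys: sum(p == y for p, y in zip(preds, ys)) / len(ys),
    inputs=[predictions, labels],
)

batch = [1.0, -2.0, 3.0]
print(run(predictions, {input_data: batch}))                  # [1, 0, 1]
print(run(accuracy, {input_data: batch, labels: [1, 1, 1]}))  # 0.6666666666666666
```

Fetching `predictions` never touches `labels`, so it runs with only `input_data` fed, while fetching `accuracy` without labels fails. That is why no separate "reduced" graph is needed: the same graph serves training (fetch the training op and feed both placeholders) and inference (fetch the predictions and feed only the input).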

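Without more information about your code, it seems like you should be able to do something like the following. Note that feed_dict is a dict mapping each placeholder to a value (`:` rather than `=`), and that a new session starts with uninitialized variables, so the trained weights must be restored first unless you run this in the session you trained in. Here `saver`, `checkpoint_path` and `batch_input` are hypothetical names for your own tf.train.Saver, checkpoint path and input batch:

with tf.Session() as sess:
  saver.restore(sess, checkpoint_path)  # load the trained weights into this session
  # Feed only the placeholder that `predictions` depends on; labels are not needed
  predictions_eval = sess.run(predictions, feed_dict={input_data: batch_input})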

Upvotes: 1
