Reputation: 53
I have read some issues about the same problem, but my issue looks somewhat different. I want to freeze a graph and then use it.
Here is a simple example of how I do this. First, I create a session and save both the checkpoint and the GraphDef:
# Build a tiny graph: sum = a + b (two variables), sum2 = sum + c (a placeholder).
a = tf.Variable(tf.constant(1.), name='a')
b = tf.Variable(tf.constant(2.), name='b')
c = tf.placeholder(tf.float32, shape=[1], name="c")
add = tf.add(a, b, 'sum')
add2 = tf.add(add, c, 'sum2')

dir_path = "<full_path>/simple_store"
with tf.Session() as sess:
    tf.initialize_all_variables().run()
    sess.run([add2], feed_dict={c: [7.]})
    # Save both the variable values (checkpoint) and the graph structure
    # (GraphDef); the freeze_graph command takes one of each as input.
    tf.train.Saver().save(sess, dir_path + "/" + 'simple.ckpt')
    tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir=dir_path, name='simple_as_text.pb')
Then I use the Bazel-built freeze_graph tool to freeze the graph like this:
# Merge the GraphDef (--input_graph) with the saved variable values (--input_checkpoint) into freeze_out.pb, with "sum2" as the output node.
../tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=simple_store/simple_as_text.pb --input_checkpoint=simple_store/simple.ckpt --output_graph=simple_store/freeze_out.pb --output_node_names=sum2
Then I load freeze_out.pb in Python and try to run it:
import tensorflow as tf
from tensorflow.core.framework import graph_pb2, cost_graph_pb2

graph_def = graph_pb2.GraphDef()
d = None
# NOTE: this creates a *new* placeholder named "c" in the default graph. The
# frozen graph also contains a node named "c", so on import that node is
# renamed (to "c_1"), and feeding this placeholder does not feed the imported
# graph -- this is what triggers the error below.
c = tf.placeholder(tf.float32, shape=[1], name="c")
feed_dict = {c: [5.]}
with tf.Session() as session:
    print("load graph")
    with open("<somepath>/simple_store/freeze_out.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name prefix.
    d = tf.import_graph_def(graph_def, return_elements=["sum2:0"], name='')
    print(session.run([d[0]], feed_dict=feed_dict))
And finally I get the following error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-3-6345b17fba3b> in <module>()
8 d = tf.import_graph_def(graph_def, return_elements=["sum2:0"], name='')
9 tf.initialize_all_variables().run()
---> 10 print(session.run([d[0]], feed_dict=feed_dict))
/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
370 try:
371 result = self._run(None, fetches, feed_dict, options_ptr,
--> 372 run_metadata_ptr)
373 if run_metadata:
374 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
634 try:
635 results = self._do_run(handle, target_list, unique_fetches,
--> 636 feed_dict_string, options, run_metadata)
637 finally:
638 # The movers are no longer used. Delete them.
/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
706 if handle is None:
707 return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
--> 708 target_list, options, run_metadata)
709 else:
710 return self._do_call(_prun_fn, self._session, handle, feed_dict,
/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
726 except KeyError:
727 pass
--> 728 raise type(e)(node_def, op, message)
729
730 def _extend_graph(self):
InvalidArgumentError: You must feed a value for placeholder tensor 'c_1' with dtype float and shape [1]
[[Node: c_1 = Placeholder[dtype=DT_FLOAT, shape=[1], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
Caused by op u'c_1', defined at:
File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main
"__main__", fname, loader, pkg_name)
File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
exec code in run_globals
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/__main__.py", line 3, in <module>
app.launch_new_instance()
File "/home/artem/.local/lib/python2.7/site-packages/traitlets/config/application.py", line 596, in launch_instance
app.start()
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 442, in start
ioloop.IOLoop.instance().start()
File "/home/artem/.local/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 162, in start
super(ZMQIOLoop, self).start()
File "/home/artem/.local/lib/python2.7/site-packages/tornado/ioloop.py", line 887, in start
handler_func(fd_obj, events)
File "/home/artem/.local/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
return fn(*args, **kwargs)
File "/home/artem/.local/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
self._handle_recv()
File "/home/artem/.local/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
self._run_callback(callback, msg)
File "/home/artem/.local/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
callback(*args, **kwargs)
File "/home/artem/.local/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
return fn(*args, **kwargs)
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
return self.dispatch_shell(stream, msg)
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
handler(stream, idents, msg)
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 391, in execute_request
user_expressions, allow_stdin)
File "/home/artem/.local/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 199, in do_execute
shell.run_cell(code, store_history=store_history, silent=silent)
File "/home/artem/.local/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2705, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/home/artem/.local/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2809, in run_ast_nodes
if self.run_code(code, result):
File "/home/artem/.local/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2869, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-3-6345b17fba3b>", line 8, in <module>
d = tf.import_graph_def(graph_def, return_elements=["sum2:0"], name='')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 274, in import_graph_def
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2260, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1230, in __init__
self._traceback = _extract_stack()
What did I do wrong? How should I correct this?
Upvotes: 1
Views: 3240
Reputation: 894
In my case, when I tried the above method, the node was automatically renamed. Here's what worked for me. First, I printed out the names of all the nodes using
# Names of all nodes in the default graph; ops created without an explicit name
# appear under auto-generated names such as "Placeholder" and "Placeholder_1".
[n.name for n in tf.get_default_graph().as_graph_def().node]
I had not named my placeholders while training my model, but by looking at the output of the above code I could see that they had been automatically named 'Placeholder' and 'Placeholder_1'.
Then I used the following lines to get the tensors:
# Look up the auto-named placeholders by tensor name (op name followed by ":0").
x = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
y = tf.get_default_graph().get_tensor_by_name("Placeholder_1:0")
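This gave me the required placeholders. For the question above, after importing the graph def (without first creating your own placeholder named "c"; otherwise the imported one is renamed to "c_1"), just doing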
# Returns the imported placeholder only if no other op named "c" existed before
# the import; otherwise the imported node was renamed (e.g. to "c_1").
c = tf.get_default_graph().get_tensor_by_name("c:0")
should work.
Upvotes: 0
Reputation: 126184
The tf.import_graph_def()
function maintains the structure of the imported graph, unless you pass the input_map
argument.
In the original graph you passed to freeze_graph
, the tensor named "sum2:0"
depends on a placeholder operation called "c"
which is in the same graph.
When you import the frozen graph using tf.import_graph_def()
, TensorFlow first imports the node named "c"
in freeze_out.pb
. However, because you have created another node named "c"
in your second program, the imported node is renamed to "c_1"
(an automatically generated unique name) and the imported version of "sum2"
is rewritten to depend on "c_1"
. Notably, it does not depend on the placeholder that you are feeding (which is named "c"
).
There are two solutions. The more straightforward solution is to extract the previously created placeholder from the imported graph, rather than creating a new one. You can do this by adding "c:0"
to the list of return_elements
:
graph_def = tf.GraphDef()
with open("<somepath>/simple_store/freeze_out.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Also extract the placeholder "c" from the imported graph instead of creating a
# new one: `c` and `d` are the imported graph's own "c:0" and "sum2:0" tensors,
# so feeding `c` feeds exactly the input that `d` depends on.
c, d = tf.import_graph_def(graph_def, return_elements=["c:0", "sum2:0"])

with tf.Session() as session:
    # `d` was unpacked above and is already the "sum2:0" tensor, so fetch it
    # directly (indexing it would slice the tensor's values instead).
    print(session.run([d], feed_dict={c: [5.]}))
Alternatively, you can remap the placeholder in the imported graph to use the placeholder in your new graph. (There is not much point in doing this substitution, but it can be useful when the new graph is more complex and includes some new preprocessing, for example.) This uses the input_map
argument to tf.import_graph_def()
:
graph_def = tf.GraphDef()
# Create a new placeholder that we will map into the imported graph.
# (N.B. This has no advantage, but could be useful if `c` were a more interesting
# function.)
c = tf.placeholder(tf.float32, shape=[1], name="c")
feed_dict = {c: [5.]}
with open("<somepath>/simple_store/freeze_out.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
# Also remap the placeholder in the imported graph to use the placeholder created
# above. Notice that the syntax is like the feed_dict, but this performs a static
# remapping of one tensor to another at graph construction time.
d = tf.import_graph_def(graph_def, input_map={"c:0": c}, return_elements=["sum2:0"])
with tf.Session() as session:
    # `d` is a list holding the single requested tensor ("sum2:0"); feed through
    # the `feed_dict` built above rather than repeating the literal.
    print(session.run([d[0]], feed_dict=feed_dict))
Upvotes: 3