Reputation: 8249
My question is related to this one here about persisting graphs: I wonder whether it is possible to export only a certain part of a graph, e.g. a subgraph prefixed by a given scope, using TensorFlow 0.12 or newer.
I'm currently using a combination of tf.train.import_meta_graph(), graph_util.convert_variables_to_constants() and tf.train.write_graph() to export ("freeze") a graph into a protocol buffer file, which I can then load back using tf.import_graph_def(). During the export I can specify which nodes are considered required outputs of the graph, so that the nodes upstream of them are not thrown away, while during import I can rewire certain parts of the graph to other operations using input_map.
This all works fine, but it lacks a notion of unnecessary inputs: the entire upstream of the output_nodes is written to the file as well, i.e. everything related to input and preprocessing.
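To make the problem concrete, here is a toy model in plain Python (not the TensorFlow API): a graph is a dict mapping each node name to the names of its inputs, and pruning keeps every node reachable upstream from the outputs. As far as I know, this is conceptually what the pruning step inside graph_util.convert_variables_to_constants() does (it calls graph_util.extract_sub_graph()). All node names below are hypothetical.

```python
from collections import deque

# Toy graph: node name -> list of input node names (hypothetical names).
GRAPH = {
    "input/filename_queue": [],
    "input/reader": ["input/filename_queue"],
    "preprocess/resize": ["input/reader"],
    "subgraph/x": ["preprocess/resize"],
    "subgraph/W": [],
    "subgraph/y": ["subgraph/x", "subgraph/W"],
    "subgraph/dropout": ["subgraph/y"],
    "train/loss": ["subgraph/y"],
}


def upstream(graph, outputs):
    """Return every node the outputs transitively depend on (outputs included)."""
    keep, queue = set(), deque(outputs)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(graph[name])
    return keep


kept = upstream(GRAPH, ["subgraph/y", "subgraph/dropout"])
print(sorted(kept))
```

Only train/loss is dropped; the input and preprocessing nodes survive because subgraph/x depends on them, which is exactly the unwanted part of the exported file.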
Currently, exporting looks like this:
import tensorflow as tf
from tensorflow.python.framework import graph_util

output_nodes = ['subgraph/y', 'subgraph/dropout']
checkpoint = tf.train.latest_checkpoint(input_path)
importer = tf.train.import_meta_graph(checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph()
gd = graph.as_graph_def()

with tf.Session() as sess:
    importer.restore(sess, checkpoint)
    # Replace variables with constants and prune to what output_nodes need
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, gd, output_nodes)
    tf.train.write_graph(output_graph_def, 'dir', 'file.pb', as_text=False)
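One way to drop the unneeded upstream is to name not only the outputs but also the node(s) where the exported subgraph should start, stop the traversal there, and turn those nodes into dependency-free inputs. As far as I know, TensorFlow ships a helper that does this on a real GraphDef, strip_unused() in tensorflow/python/tools/strip_unused_lib.py, which replaces the given input nodes with placeholders; verify that it exists in your version. The sketch below only models the idea on the same kind of toy dict graph, with hypothetical node names.

```python
from collections import deque

# Toy graph: node name -> list of input node names (hypothetical names).
GRAPH = {
    "input/filename_queue": [],
    "input/reader": ["input/filename_queue"],
    "preprocess/resize": ["input/reader"],
    "subgraph/x": ["preprocess/resize"],
    "subgraph/W": [],
    "subgraph/y": ["subgraph/x", "subgraph/W"],
    "subgraph/dropout": ["subgraph/y"],
    "train/loss": ["subgraph/y"],
}


def prune_between(graph, inputs, outputs):
    """Keep nodes upstream of `outputs`, but stop at `inputs`.

    Each input node is kept without dependencies (like a placeholder),
    so everything that fed it is dropped.
    """
    pruned, queue = {}, deque(outputs)
    while queue:
        name = queue.popleft()
        if name in pruned:
            continue
        if name in inputs:
            pruned[name] = []  # cut here: becomes a placeholder-like node
            continue
        pruned[name] = list(graph[name])
        queue.extend(graph[name])
    return pruned


sub = prune_between(GRAPH, inputs={"subgraph/x"},
                    outputs=["subgraph/y", "subgraph/dropout"])
print(sorted(sub))
```

With subgraph/x as the cut point, only nodes under subgraph/ remain, and at import time subgraph/x can be fed from elsewhere through input_map.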
Importing looks like this:
with tf.gfile.GFile(input_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

elems = tf.import_graph_def(
    graph_def,
    input_map=None,
    return_elements=output_nodes,
    name='imported'
)
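The renaming and rewiring that tf.import_graph_def() performs can be modelled the same way: every imported node gets the name prefix ('imported/' here), and any input listed in input_map is replaced by an existing node instead. This is a plain-Python sketch with hypothetical node names; the real input_map maps tensor names in the GraphDef (such as 'subgraph/x:0') to tf.Tensor objects.

```python
def import_with_input_map(nodes, input_map, name="imported"):
    """Prefix imported node names and rewire mapped inputs.

    nodes: node name -> list of input names (as in a frozen GraphDef).
    input_map: original input name -> name of an existing node to use instead.
    """
    imported = {}
    for node, inputs in nodes.items():
        # Mapped inputs point at the existing node; others get the prefix.
        rewired = [input_map.get(i, f"{name}/{i}") for i in inputs]
        imported[f"{name}/{node}"] = rewired
    return imported


frozen = {
    "subgraph/x": [],
    "subgraph/W": [],
    "subgraph/y": ["subgraph/x", "subgraph/W"],
    "subgraph/dropout": ["subgraph/y"],
}
result = import_with_input_map(frozen, {"subgraph/x": "my_pipeline/batch"})
print(result["imported/subgraph/y"])
```

Note that the mapped node itself is still imported; it simply no longer feeds anything.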
Is there a way to filter and/or remove the parts I do not require during or before exporting? To be clear, I know I can ignore them after loading, but I do not want to export them in the first place. Would using collections help at some point?
Upvotes: 3
Views: 2019
Reputation: 5162
In my opinion, the built-in graph saving tools need to be reworked.
The best solution I've found is to create a list of the variables you want saved, then save their values during the session. Afterwards you are free to reload the weights and change the graph as needed.
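import numpy as np

# While building the model, collect the variables to be saved (in a fixed order)
self.parameters = []
self.parameters += [weights, biases]
...

def load_weights(self, weight_file, sess):
    weights = np.load(weight_file)
    # Assumes the sorted key order matches the order of self.parameters
    keys = sorted(weights.keys())
    for i, k in enumerate(keys):
        print(i, k, np.shape(weights[k]))
        sess.run(self.parameters[i].assign(weights[k]))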
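The answer does not show the saving half. A minimal sketch, assuming numpy and that the variable values have already been fetched inside the session (e.g. values = sess.run(self.parameters)): write them with np.savez under zero-padded index keys, so that sorted(weights.keys()) in load_weights() returns them in the same order as self.parameters. The function name save_weights and the "p0000" key format are hypothetical.

```python
import os
import tempfile

import numpy as np


def save_weights(values, weight_file):
    """Save a list of arrays so that sorted() recovers their original order.

    values: list of numpy arrays, in the same order as self.parameters.
    Zero-padded keys matter: without padding, "10" would sort before "2".
    """
    arrays = {"p{:04d}".format(i): v for i, v in enumerate(values)}
    np.savez(weight_file, **arrays)


# Usage with made-up arrays standing in for fetched weights and biases.
values = [np.full(2, float(i)) for i in range(12)]
path = os.path.join(tempfile.mkdtemp(), "weights.npz")
save_weights(values, path)

loaded = np.load(path)
keys = sorted(loaded.keys())  # same ordering step as in load_weights()
```

Any naming scheme works as long as sorting the keys reproduces the order in which the variables were appended to the parameter list.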
Credit for the idea goes to Davi Frossard here.
Upvotes: 1