sunside

Reputation: 8249

Freezing/exporting a part of a TensorFlow graph

My question is related to this one here about persisting graphs: I wonder if it is possible to only export a certain part of a graph, e.g. a subgraph prefixed by a given scope, using TensorFlow 0.12 or newer.
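Selecting a subgraph by scope amounts to filtering node names by prefix, but nodes inside the scope usually still read from nodes outside it. Below is a minimal sketch on a toy graph. It assumes a plain dict that maps each (hypothetical) node name to its list of input references, loosely mirroring GraphDef's NodeDef.input with its "name:port" and "^name" control-dependency forms. It only illustrates the idea and does not touch a real GraphDef:

```python
def node_name(ref):
    """Strip a '^' control-dependency prefix and a ':port' suffix."""
    return ref.lstrip("^").split(":")[0]


def filter_by_scope(graph, scope):
    """Keep nodes whose name starts with `scope/`.

    Returns the kept toy sub-graph plus the set of outside nodes it still reads from.
    """
    prefix = scope.rstrip("/") + "/"
    kept = {name: inputs for name, inputs in graph.items() if name.startswith(prefix)}
    external = {
        node_name(ref)
        for inputs in kept.values()
        for ref in inputs
        if node_name(ref) not in kept
    }
    return kept, external


# Hypothetical toy graph: node name -> list of input references
graph = {
    "input/reader": [],
    "preprocess/normalize": ["input/reader"],
    "subgraph/w": [],
    "subgraph/matmul": ["preprocess/normalize:0", "subgraph/w"],
    "subgraph/y": ["subgraph/matmul"],
    "subgraph/dropout": ["subgraph/y", "^subgraph/w"],
}

kept, external = filter_by_scope(graph, "subgraph")
print(sorted(kept))
print(external)
```

The leftover external reference (preprocess/normalize here) is what a prefix filter alone cannot resolve: something, such as a placeholder, has to stand in for it.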

I'm currently using a combination of tf.train.import_meta_graph(), graph_util.convert_variables_to_constants() and tf.train.write_graph() to export ("freeze") a graph into a protocol buffer file, which I can then load back using tf.import_graph_def(). During the export I can specify which nodes are considered required outputs of the graph, so that nothing upstream of them is thrown away, while during import I can rewire certain parts of the graph to other operations using input_map.

This all works fine, but it has no notion of unnecessary inputs: because only the outputs are specified, the entire upstream of the output_nodes is written to the file as well, i.e. everything related to input and preprocessing.
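This behaviour can be reproduced on a toy graph. Keeping "everything the outputs need" is a backwards reachability search from the output nodes, so input and preprocessing nodes are kept as long as the outputs depend on them, while unrelated ops (training ops, say) are dropped. The dict-based graph and the helper names are hypothetical; this is a sketch of the idea, not TensorFlow's implementation:

```python
from collections import deque


def node_name(ref):
    """Strip a '^' control-dependency prefix and a ':port' suffix."""
    return ref.lstrip("^").split(":")[0]


def upstream_closure(graph, outputs):
    """Return every node reachable backwards from `outputs`, outputs included."""
    seen = set()
    queue = deque(node_name(o) for o in outputs)
    while queue:
        name = queue.popleft()
        if name in seen:
            continue
        seen.add(name)
        # Follow every input edge (data and control) further upstream
        queue.extend(node_name(ref) for ref in graph[name])
    return seen


# Hypothetical toy graph: node name -> list of input references
graph = {
    "input/reader": [],
    "preprocess/normalize": ["input/reader"],
    "subgraph/w": [],
    "subgraph/matmul": ["preprocess/normalize:0", "subgraph/w"],
    "subgraph/y": ["subgraph/matmul"],
    "subgraph/dropout": ["subgraph/y", "^subgraph/w"],
    "train/loss": ["subgraph/y"],
    "train/optimizer": ["train/loss"],
}

kept = upstream_closure(graph, ["subgraph/y", "subgraph/dropout"])
print(sorted(kept))
```

The train/ nodes disappear because no output depends on them, but input/reader and preprocess/normalize survive, which matches the problem described above.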

Currently, exporting looks like this:

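import tensorflow as tf
from tensorflow.python.framework import graph_util

# Nodes that must survive the export; everything they depend on is kept too
output_nodes = ['subgraph/y', 'subgraph/dropout']

# Rebuild the graph structure from the latest checkpoint's .meta file
checkpoint = tf.train.latest_checkpoint(input_path)
importer = tf.train.import_meta_graph(checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph()
gd = graph.as_graph_def()

with tf.Session() as sess:
    # Load the trained variable values into the session
    importer.restore(sess, checkpoint)
    # Replace variables with constants and drop nodes the outputs don't need
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, gd, output_nodes)
    tf.train.write_graph(output_graph_def, 'dir', 'file.pb', as_text=False)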

While importing looks like this:

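import tensorflow as tf

# Read the frozen graph written by the export step
with tf.gfile.GFile(input_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# input_map could rewire inputs to existing tensors; None imports the graph as-is
elems = tf.import_graph_def(
    graph_def,
    input_map=None,
    return_elements=output_nodes,
    name='imported'
)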
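The input_map argument is what makes rewiring possible at import time: references to a mapped name are redirected to something that already exists in the importing graph, while everything else is placed under the name prefix. Here is a toy sketch of that renaming. The dict graph, the "features" entry point and the my_pipeline/batch node are all hypothetical, ports and control inputs are ignored, and this is not how tf.import_graph_def is implemented:

```python
def import_with_input_map(graph, input_map, name="imported"):
    """Copy a toy graph under the `name/` prefix, rewiring mapped inputs.

    References listed in `input_map` point at the given existing node instead;
    every other reference is renamed into the imported scope.
    """
    return {
        f"{name}/{node}": [input_map.get(ref, f"{name}/{ref}") for ref in inputs]
        for node, inputs in graph.items()
    }


# Hypothetical exported subgraph whose entry point is a node called "features"
exported = {
    "features": [],
    "subgraph/w": [],
    "subgraph/matmul": ["features", "subgraph/w"],
    "subgraph/y": ["subgraph/matmul"],
}

# Feed the imported subgraph from a node that already exists in the new graph
imported = import_with_input_map(exported, {"features": "my_pipeline/batch"})
for node, inputs in sorted(imported.items()):
    print(node, inputs)
```

In this toy the mapped node itself is still copied (imported/features), but nothing reads from it any more.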

Is there a way to filter out and/or remove the parts I do not require during or before exporting? To be clear, I know I can ignore them after loading, but I do not want to export them in the first place. Would using collections help at some point?
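One way to get the behaviour asked for is to treat the first node you want to keep as a cut point: search backwards from the outputs, but stop at the cut and turn the cut node into a source with no inputs (in TensorFlow terms, a placeholder). The sketch below does this on a hypothetical toy dict graph; it only illustrates the pruning logic and does not operate on a real GraphDef:

```python
from collections import deque


def node_name(ref):
    """Strip a '^' control-dependency prefix and a ':port' suffix."""
    return ref.lstrip("^").split(":")[0]


def prune_between(graph, inputs, outputs):
    """Keep nodes needed by `outputs`, stopping the backwards search at `inputs`.

    Each cut node is kept with an empty input list, standing in for a
    placeholder, so nothing upstream of it survives.
    """
    cut = {node_name(i) for i in inputs}
    pruned = {}
    queue = deque(node_name(o) for o in outputs)
    while queue:
        name = queue.popleft()
        if name in pruned:
            continue
        if name in cut:
            pruned[name] = []  # placeholder-like source
            continue
        pruned[name] = list(graph[name])
        queue.extend(node_name(ref) for ref in graph[name])
    return pruned


# Hypothetical toy graph: node name -> list of input references
graph = {
    "input/reader": [],
    "preprocess/normalize": ["input/reader"],
    "subgraph/w": [],
    "subgraph/matmul": ["preprocess/normalize:0", "subgraph/w"],
    "subgraph/y": ["subgraph/matmul"],
    "subgraph/dropout": ["subgraph/y", "^subgraph/w"],
    "train/loss": ["subgraph/y"],
}

pruned = prune_between(graph, ["preprocess/normalize"], ["subgraph/y", "subgraph/dropout"])
print(sorted(pruned))
```

input/reader is gone and preprocess/normalize remains only as an input-less stand-in. In a real GraphDef that node would have to become a placeholder of the right dtype and shape, which this sketch does not model.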

Upvotes: 3

Views: 2019

Answers (1)

Steven

Reputation: 5162

The built-in graph saving tools need to be reworked, in my opinion.

The best solution I've found is to create a list of the variables you want saved then save them during the session. Afterwards you have the freedom to reload the weights and change the graph as needed.

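import numpy as np

# Collect the variables you want to persist, in a fixed order
parameters = []
parameters += [weights, biases]
...
def load_weights(self, weight_file, sess):
    # np.load on an .npz file gives a dict-like archive of named arrays
    weights = np.load(weight_file)
    # The sorted key order must match the order of self.parameters
    keys = sorted(weights.keys())
    for i, k in enumerate(keys):
        print(i, k, np.shape(weights[k]))
        # Copy the stored array into the matching variable
        sess.run(self.parameters[i].assign(weights[k]))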
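One caveat with reloading by sorted(weights.keys()): the sorted key order has to match the order of self.parameters. If the file was written with np.savez(path, *arrays), the keys are arr_0, arr_1, ..., and string sorting puts arr_10 before arr_2 once there are more than ten arrays. A minimal sketch that avoids this by saving under zero-padded keys; save_parameters and load_parameters are hypothetical helpers, and in TensorFlow the values list would come from something like sess.run(parameters):

```python
import os
import tempfile

import numpy as np


def save_parameters(path, values):
    """Save arrays in list order under zero-padded keys so sorted() keeps that order."""
    np.savez(path, **{f"param_{i:04d}": v for i, v in enumerate(values)})


def load_parameters(path):
    """Return the saved arrays as a list, in the order they were saved."""
    with np.load(path) as archive:
        return [archive[k] for k in sorted(archive.files)]


# Stand-in values; in a session these would be the evaluated variables
values = [np.full((2, 2), i, dtype=np.float32) for i in range(12)]
path = os.path.join(tempfile.mkdtemp(), "weights.npz")
save_parameters(path, values)
restored = load_parameters(path)
print([int(a[0, 0]) for a in restored])
```

Each restored array can then be assigned to the variable at the same index, as load_weights does.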

Credit for the idea goes to Davi Frossard here.

Upvotes: 1
