dsalaj

Reputation: 3197

How to extract weights and other variable values from a TensorFlow checkpoint without restoring the graph?

Given a checkpoint file, but neither the meta graph nor the code that produced the network, I want to extract the values of the variables stored in the checkpoint.

So, without restoring the graph, how do I extract the values stored in the checkpoint? It would be fine to convert everything in the checkpoint to a dictionary of numpy arrays or something similar.

Upvotes: 1

Views: 581

Answers (1)

dsalaj

Reputation: 3197

Found the solution:

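import tensorflow as tf

# The path is the checkpoint prefix (the common part of the .index / .data-* file names)
reader = tf.train.NewCheckpointReader("/path/to/checkpoint")
shapes_dict = reader.get_variable_to_shape_map()  # maps each variable name to its shape
first_name = list(shapes_dict.keys())[0]  # list() is needed in Python 3, where keys() is not indexable
extracted_values = reader.get_tensor(first_name)  # returns the stored value as a numpy array
# array([[ 0.       , -1.8053141],
#        [-1.5647348,  0.       ]], dtype=float32)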

tf.train.NewCheckpointReader is not really documented in the current API documentation (r1.12), but you can see a usage example in the source code here.
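Since the question asks for a dictionary of numpy arrays, the same two reader methods can simply be looped over every variable name. This is a minimal sketch, not part of the original answer: checkpoint_to_dict is a hypothetical helper name, and it assumes reader is the object returned by tf.train.NewCheckpointReader (any object exposing get_variable_to_shape_map() and get_tensor(name) will do).

```python
import numpy as np


def checkpoint_to_dict(reader):
    """Return {variable_name: numpy array} for every variable in a checkpoint.

    Assumes `reader` behaves like the object returned by
    tf.train.NewCheckpointReader(...): it must provide
    get_variable_to_shape_map() and get_tensor(name).
    """
    # The shape map's keys are exactly the variable names stored in the checkpoint;
    # sorting them gives a deterministic order in the resulting dict.
    names = sorted(reader.get_variable_to_shape_map())
    # get_tensor already returns numpy values; np.asarray makes that explicit.
    return {name: np.asarray(reader.get_tensor(name)) for name in names}
```

For example, weights = checkpoint_to_dict(tf.train.NewCheckpointReader("/path/to/checkpoint")) gives one entry per stored variable, which may include optimizer slot variables or a global step counter if the checkpoint contains them.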

Upvotes: 2
