kangaroo

Reputation: 427

TensorFlow: How to save trained model parameters to a file that can be imported into other frameworks?

I'd like to pass the parameters of a trained model (the weights and biases of the convolution and fully connected layers) to other frameworks or languages, including iOS and Torch, by parsing the saved file.

I tried tf.train.write_graph(session.graph_def, '', 'graph.pb'), but it seems to include only the graph architecture, without the weights and biases. If so, is creating a checkpoint file (saver.save(session, "model.ckpt")) the best way? Is the ckpt file format easy to parse in Swift or other languages?

Please let me know if you have any suggestions.

Upvotes: 3

Views: 1454

Answers (1)

Gustavo Bezerra

Reputation: 11044

Instead of parsing a .ckpt file, you can evaluate the tensor (in your case, the weights of a convolutional layer) and get its values as a numpy array. Here is a quick toy example (tested on r0.10 - there might be some small API changes in newer versions):

import tensorflow as tf
import numpy as np

# Toy model: z = w * x + b, with a 2x1 input, 2x2 weights and a 2x1 bias
x = tf.placeholder(np.float32, [2,1])
w = tf.Variable(tf.truncated_normal([2,2], stddev=0.1))
b = tf.Variable(tf.constant(1.0, shape=[2,1]))
z = tf.matmul(w, x) + b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # Evaluating a variable returns its current value as a numpy array
    w_val, z_val = sess.run([w, z], feed_dict={x: np.arange(2).reshape(2,1)})
    print(w_val)
    print(z_val)

Output:

[[-0.02913031  0.13549708]
 [ 0.13807134  0.03763327]]
[[ 1.13549709]
 [ 1.0376333 ]]
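
Once the values are numpy arrays, they can be written to a format that is easy to parse elsewhere. Below is a minimal sketch of one possible approach, using JSON. The helper names export_weights and load_weights, the file name weights.json, and the stand-in arrays w_val and b_val are all assumptions for illustration; in practice the arrays would come from sess.run as above. JSON is only one option, and for large models a binary format would likely be more compact.

```python
import json
import os
import tempfile

import numpy as np


def export_weights(params, path):
    """Write a dict of name -> numpy array to a JSON file (hypothetical helper)."""
    payload = {
        name: {
            "shape": list(arr.shape),
            "dtype": str(arr.dtype),
            # Flatten in row-major (C) order so the reader can rebuild the shape
            "data": arr.ravel().tolist(),
        }
        for name, arr in params.items()
    }
    with open(path, "w") as f:
        json.dump(payload, f)


def load_weights(path):
    """Read the JSON file back into numpy arrays (hypothetical helper)."""
    with open(path) as f:
        payload = json.load(f)
    return {
        name: np.array(entry["data"], dtype=entry["dtype"]).reshape(entry["shape"])
        for name, entry in payload.items()
    }


# Stand-in values; in a real script these would be returned by sess.run
w_val = np.array([[-0.02913031, 0.13549708],
                  [0.13807134, 0.03763327]], dtype=np.float32)
b_val = np.ones((2, 1), dtype=np.float32)

# Hypothetical output location
path = os.path.join(tempfile.gettempdir(), "weights.json")
export_weights({"w": w_val, "b": b_val}, path)
restored = load_weights(path)
```

The resulting file stores each parameter's shape, dtype, and flattened values, so a reader in Swift or Lua only needs a JSON parser and the row-major ordering convention to rebuild the arrays.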

If you have trouble getting a reference to your tensor (say it is nested inside a higher-level "layer" operation), try finding it by name. More info here: Tensorflow: How to get a tensor by name?

If you want to see how the weights change during training, you can also save the values you are interested in to tf.Summary objects and parse them later: Parsing `summary_str` byte string evaluated on tensorflow summary object

Upvotes: 1
