Reputation: 465
TLDR: How to use Variables from frozen tensorflow graphs on Android?
1. What I want to do
I have a TensorFlow model that keeps an internal state in multiple variables, created with:
state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)
This state is modified during inference:
tf.assign(state_var, new_value)
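Putting those two pieces together, a minimal sketch of the pattern (the shape and the update logic here are stand-ins for the real model):
import tensorflow as tf

shape = [1, 64]  # example shape; the real model uses its own

# Internal state kept as a non-trainable variable
state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32),
                        name='state', trainable=False)

# new_value stands in for whatever the network computes from the state
new_value = state_var + 1.0
update_state = tf.assign(state_var, new_value)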
I now want to deploy the model on Android. I was able to get the TensorFlow example app running; it loads a frozen model, which works fine.
2. Restoring variables from frozen graph does not work
However, when you freeze a graph using the freeze_graph script, all variables are converted to constants. This is fine for the weights of the network, but not for the internal state. Inference fails with the following message, which I interpret as "assign does not work on constant tensors":
java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.
Luckily, freeze_graph lets you blacklist variables so they are not converted to constants. However, this doesn't work either, because the frozen graph then contains uninitialized variables:
java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state
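For reference, the blacklist goes through the variable_names_blacklist parameter of the TF 1.x freeze_graph tool; a rough sketch, with file paths and output node names as placeholders (the state node names are the ones from the errors above):
from tensorflow.python.tools import freeze_graph

# Keep the state variables out of the constant-folding step
freeze_graph.freeze_graph(
    input_graph='graph.pbtxt',
    input_saver='',
    input_binary=False,
    input_checkpoint='model.ckpt',
    output_node_names='output',
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph='model.pb',
    clear_devices=True,
    initializer_nodes='',
    variable_names_blacklist='layer_1/state,layer_7/state')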
3. Restoring SavedModel does not work on Android
One last approach I have tried is the SavedModel format, which should contain both a frozen graph and the variables. Unfortunately, calling the restore method does not work on Android.
SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);
// produces error:
E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you
    at org.tensorflow.SavedModelBundle.load(Native Method)
4. How can I make this work?
I don't know what else I can try. Is there a way to keep and update an internal state when running a frozen graph on Android?
Upvotes: 0
Views: 726
Reputation: 465
I have solved this myself by going down a different route. To the best of my knowledge, the variable concept cannot be used on Android the way I was used to in Python (e.g. you cannot initialize variables and then have the network update its internal state during inference).
Instead, you can use placeholder and output nodes to preserve the state inside your Java code and feed it back to the network on every inference call.
Replace all tf.Variable occurrences with tf.placeholder; the shape stays the same. Then add an output node that exposes the updated state:
tf.identity(inputs, name='state_output')
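A minimal sketch of that rewrite (the shape and the update logic are stand-ins; 'inputs' is whatever tensor holds the new state in your graph):
import tensorflow as tf

shape = [1, 64]  # same shape the variable had

# The former tf.Variable becomes a placeholder that is fed on every call
state = tf.placeholder(tf.float32, shape, name='state')

# ... the network computes the updated state from the placeholder ...
inputs = state + 1.0  # stand-in for the real update logic

# Expose the updated state under a well-known name for fetching
state_output = tf.identity(inputs, name='state_output')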
During inference on Android, you then feed the initial state into the network.
float[] values = {0, 0, 0, ...}; // zeros of the correct shape
inferenceInterface.feed("state", values, ...);
After inference, you read the resulting internal state of the network:
float[] values = new float[output_shape];
inferenceInterface.fetch("state_output", values);
You then remember this output in Java and pass it into the "state" placeholder on the next inference call; a consolidated sketch follows.
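Putting the pieces together, a sketch of the round-trip on Android (the constructor form, the node names "input"/"output", and inputFrames are assumptions about your setup; shapes are examples):
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

// Inside an Activity or similar context with access to getAssets()
TensorFlowInferenceInterface inferenceInterface =
        new TensorFlowInferenceInterface(getAssets(), "file:///android_asset/model.pb");

int stateSize = 64;                   // flattened size of the state tensor
float[] state = new float[stateSize]; // initial state: all zeros

for (float[] input : inputFrames) {   // inputFrames: your per-step input data
    inferenceInterface.feed("input", input, 1, input.length);
    inferenceInterface.feed("state", state, 1, stateSize);
    inferenceInterface.run(new String[] {"output", "state_output"});
    inferenceInterface.fetch("state_output", state); // carry over to next call
}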
Upvotes: 1