Reputation: 1279
I'm trying to use the TensorFlow C API to load and execute a graph. It keeps failing and I can't figure out why.
I first use this Python script to create a very simple graph and save it to a file.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    input = tf.placeholder(tf.float32, [10, 3], name='input')
    output = tf.reduce_sum(input**2, name='output')
tf.train.write_graph(graph, '.', 'test.pbtxt')
Then I use this C++ code to load it in.
#include <fstream>
#include <iostream>
#include <string>
#include <c_api.h>

using namespace std;

int main() {
    ifstream graphFile("test.pbtxt");
    string graphText((istreambuf_iterator<char>(graphFile)), istreambuf_iterator<char>());
    TF_Buffer* buffer = TF_NewBufferFromString(graphText.c_str(), graphText.size());
    TF_Graph* graph = TF_NewGraph();
    TF_ImportGraphDefOptions* importOptions = TF_NewImportGraphDefOptions();
    TF_Status* status = TF_NewStatus();
    TF_GraphImportGraphDef(graph, buffer, importOptions, status);
    cout << TF_GetCode(status) << endl;
    return 0;
}
The status code it prints is 3, i.e. TF_INVALID_ARGUMENT. Which argument is invalid, and why? I verified that the file contents are loaded correctly into graphText, and all the other arguments are trivial.
Upvotes: 1
Views: 798
Reputation: 71
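First of all, I think you should write the graph with as_graph_def() and serialize it to binary. TF_GraphImportGraphDef expects a serialized (binary) GraphDef protobuf, whereas tf.train.write_graph writes the human-readable text format by default (as_text=True). Your test.pbtxt therefore cannot be parsed as a GraphDef, which is reported as TF_INVALID_ARGUMENT. (Passing as_text=False to write_graph also produces a binary file.) In your case: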
with open('test.pb', 'wb') as f:
    f.write(graph.as_graph_def().SerializeToString())
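On the C++ side, the loader then has to read test.pb as raw bytes rather than as text. Below is a minimal sketch, assuming the test.pb written above; readBinaryFile is a hypothetical helper name, not part of the TensorFlow API. It opens the file with std::ios::binary so no line-ending translation can alter the serialized protobuf:

```cpp
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>

// Hypothetical helper: reads the whole file at `path` as raw bytes.
// A serialized GraphDef is binary protobuf, so the file is opened with
// std::ios::binary to avoid any text-mode newline translation
// (which happens on some platforms, e.g. Windows).
std::string readBinaryFile(const std::string& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file) {
        throw std::runtime_error("cannot open " + path);
    }
    // Copy every byte, including '\0' and '\r', into the string.
    return std::string((std::istreambuf_iterator<char>(file)),
                       std::istreambuf_iterator<char>());
}
```

The returned string can take the place of graphText in the question's code, e.g. `std::string bytes = readBinaryFile("test.pb");` followed by `TF_NewBufferFromString(bytes.data(), bytes.size())` and TF_GraphImportGraphDef as before.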
Apart from that, I recommend not using the C API directly, as it makes memory leaks easy to introduce. Instead, I tried your code with cppflow, a C++ wrapper, and it works like a charm. I used the following code:
// Load model
Model model("../test.pb");

// Declare tensors by name
auto input = new Tensor(model, "input");
auto output = new Tensor(model, "output");

// Feed data: 30 = 10 x 3 values, matching the [10, 3] placeholder
std::vector<float> data(30, 1);
input->set_data(data);

// Run and show (with all-ones input, the sum of squares should be 30)
model.run(input, output);
std::cout << output->get_data<float>()[0] << std::endl;
Upvotes: 1