JP Kim

Reputation: 741

How to read a TensorFlow memory-mapped graph file on Android?

With TensorFlow 1.0.1, loading an optimized or quantized graph on Android works fine with the TensorFlowImageClassifier.create method, for example:

            classifier = TensorFlowImageClassifier.create(
                    c.getAssets(),
                    MODEL_FILE,
                    LABEL_FILE,
                    IMAGE_SIZE,
                    IMAGE_MEAN,
                    IMAGE_STD,
                    INPUT_NAME,
                    OUTPUT_NAME);

But according to Pete Warden's blog (https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/), a memory-mapped graph is recommended on mobile to avoid memory-related crashes.

I built the memory-mapped graph using

bazel-bin/tensorflow/contrib/util/convert_graphdef_memmapped_format \
--in_graph=/tf_files/rounded_graph.pb \
--out_graph=/tf_files/mmapped_graph.pb

and it was created without problems, but when I tried to load the file with TensorFlowImageClassifier.create(...), it reported that the file is not a valid graph file.

On iOS, the file can be loaded with

LoadMemoryMappedModel(
        model_file_name, model_file_type, &tf_session, &tf_memmapped_env);

because there is a function for reading memory-mapped graphs.

So I assume there is a similar function on Android, but I couldn't find it.

Could someone explain how to load a memory-mapped graph on Android?

Upvotes: 2

Views: 1324

Answers (1)

Pete Warden

Reputation: 2878

Since the file from the memmapped tool is no longer a standard GraphDef protobuf, you need to make some changes to the loading code. You can see an example of this in the iOS Camera demo app, the LoadMemoryMappedModel() function: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm#L159

The same code (with the Objective-C calls for getting the filenames substituted) can be used on other platforms too. Because we’re using memory mapping, we need to start by creating a special TensorFlow environment object that’s set up with the file we’ll be using:

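// As in the iOS demo, memmapped_env is a pointer to the unique_ptr that owns
// the environment, so the later memmapped_env->get() calls return the raw Env.
std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env_holder;
std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env = &memmapped_env_holder;
memmapped_env->reset(
    new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
tensorflow::Status mmap_status =
    (memmapped_env->get())->InitializeFromFile(file_path);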

You then need to pass in this environment to subsequent calls, like this one for loading the graph.

tensorflow::GraphDef tensorflow_graph;
tensorflow::Status load_graph_status = ReadBinaryProto(
    memmapped_env->get(),
    tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
    &tensorflow_graph);

You also need to create the session with a pointer to the environment you’ve created:

tensorflow::SessionOptions options;
options.config.mutable_graph_options()
    ->mutable_optimizer_options()
    ->set_opt_level(::tensorflow::OptimizerOptions::L0);
options.env = memmapped_env->get();

tensorflow::Session* session_pointer = nullptr;
tensorflow::Status session_status =
    tensorflow::NewSession(options, &session_pointer);

One thing to notice here is that we’re also disabling automatic optimizations, since in some cases these will fold constant sub-trees, which creates copies of tensor values that we don’t want and uses up more RAM.

This setup also means it's hard to use a model stored as an APK asset on Android, since assets are compressed and don't have normal filenames. Instead you'll need to copy your file out of the APK to a normal filesystem location.
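A minimal sketch of that copy step in plain Java, assuming the model is available as an InputStream. The class name ModelFileCopier and the method copyToFile are hypothetical names chosen for illustration; they are not part of TensorFlow or the Android SDK.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical helper: writes a model stream to a regular file so it can be memory mapped. */
public final class ModelFileCopier {
    private ModelFileCopier() {}

    /**
     * Copies everything from {@code in} to {@code targetDir/fileName}, overwriting
     * any existing file. The caller remains responsible for closing {@code in}.
     */
    public static File copyToFile(InputStream in, File targetDir, String fileName)
            throws IOException {
        // Create the destination directory if it does not exist yet.
        if (!targetDir.isDirectory() && !targetDir.mkdirs()) {
            throw new IOException("Could not create directory: " + targetDir);
        }
        File target = new File(targetDir, fileName);
        try (OutputStream out = new FileOutputStream(target)) {
            byte[] buffer = new byte[8192];
            int read;
            // Read until end of stream, writing only the bytes actually read.
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
        }
        return target;
    }
}
```

On Android, the stream would typically come from context.getAssets().open("mmapped_graph.pb") and the directory from context.getFilesDir(). The absolute path of the returned file is what the native code would then receive as file_path. How that path reaches the C++ side (for example through JNI) is outside this sketch.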

Once you’ve gone through these steps, you can use the session and graph as normal, and you should see a reduction in loading time and memory usage.

Upvotes: 3
