pfc

Reputation: 1910

How to calculate the flops of a tensorflow model loaded from pb file

I have a model saved in a pb file and want to calculate its FLOPs. My example code is as follows:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

pb_path = 'themodel.pb'

run_meta = tf.RunMetadata()
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(pb_path,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

The output is strange. My model has tens of layers, but the profiler reports only 18 FLOPs in total. I'm fairly sure the model is loaded correctly, because I can print the name of every node in the graph as follows:

print([n.name for n in tf.get_default_graph().as_graph_def().node])

The output lists exactly the network I expect.

What's wrong with my code?

Thank you!

Upvotes: 5

Views: 2895

Answers (2)

pfc

Reputation: 1910

I think I found the cause of the problem and a solution. The following code prints the FLOPs of a given pb file.

import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

pb_path = 'mymodel.pb'

run_meta = tf.RunMetadata()
with tf.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")
        print('model loaded!')
    all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
    # for k in all_keys:
    #   print(k)

    with tf.Session() as sess:
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

The code in the question printed only 18 FLOPs because, when I generated the pb file, I set the input image shape to [None, None, 3]. If I change it to a fixed shape, say [500, 500, 3], the printed FLOPs are correct.
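To illustrate why a fixed input shape matters, here is a rough pure-Python sketch of the arithmetic. It is not the profiler's actual code. It assumes (my understanding, not something confirmed in this thread) that the FLOP count of a convolution is derived from its static output shape and that ops with unknown dimensions are left out of the total. The function name conv2d_flops and the 3x3, 64-filter layer are hypothetical.

```python
from math import prod


def conv2d_flops(output_shape, kernel_h, kernel_w, in_channels):
    """Estimate FLOPs of one convolution from its static output shape.

    Assumption: each output element costs kernel_h * kernel_w * in_channels
    multiply-adds, counted as 2 FLOPs each. Returns None when any dimension
    is unknown, mimicking an op that cannot be counted.
    """
    if any(dim is None for dim in output_shape):
        return None  # shape not fully defined, nothing to count
    return 2 * prod(output_shape) * kernel_h * kernel_w * in_channels


# Hypothetical first layer: 3x3 kernel, 3 input channels, 64 filters,
# stride 1 with "same" padding, so output height/width equal the input's.
unknown = conv2d_flops((1, None, None, 64), 3, 3, 3)  # input [None, None, 3]
fixed = conv2d_flops((1, 500, 500, 64), 3, 3, 3)      # input [500, 500, 3]

print(unknown)
print("{:,}".format(fixed))
```

Under these assumptions, only ops whose cost does not depend on the image size would contribute to the total when the input is [None, None, 3], which would be consistent with a tiny number such as 18.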

Upvotes: 1

Allen Lavoie

Reputation: 5808

Not sure how it would compute any performance measure without knowing the inputs and outputs: maybe it needs CallableOptions? I'd use trace_next_step and a Session rather than computing those manually.

Upvotes: 0
