StalkerMuse

Reputation: 1029

How to calculate the FLOPs of a CNN

I want to design a convolutional neural network that uses no more GPU resources than AlexNet. I want to use FLOPs to measure this, but I don't know how to calculate them. Is there any tool to do it?

Upvotes: 23

Views: 28469

Answers (2)

al2

Reputation: 159

Tobias Scheck's answer works if you are using TensorFlow 1.x, but if you are using TensorFlow 2.x you should use the following code:

import tensorflow as tf

def get_flops(model_h5_path):
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
        with session.as_default():
            # Load the model into the TF1-style default graph so the profiler can see its ops.
            model = tf.keras.models.load_model(model_h5_path)

            run_meta = tf.compat.v1.RunMetadata()
            # Ask the profiler to count floating-point operations.
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)

            return flops.total_float_ops

The above function takes the path of a saved model in h5 format. You can save your model and use the function this way:

model.save('path_to_my_model.h5')
tf.compat.v1.reset_default_graph()
print(get_flops('path_to_my_model.h5'))

Note that we call tf.compat.v1.reset_default_graph() so that FLOPs do not accumulate each time the function is called.
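As a cross-check on the profiler's total, the FLOPs of individual layers can also be estimated by hand. The sketch below is plain Python, not a TensorFlow API. It assumes each multiply-accumulate counts as 2 FLOPs (as far as I know, this matches how TensorFlow's profiler counts Conv2D and MatMul ops), ignores bias, activation and pooling ops, and uses a hypothetical toy network. Some papers report multiply-accumulates (MACs) instead, which are half of these numbers.

```python
def conv_output_size(in_size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one axis."""
    return (in_size + 2 * padding - kernel) // stride + 1


def conv2d_flops(h_out, w_out, c_in, c_out, k_h, k_w):
    """FLOPs of a standard Conv2D (bias excluded): every output element needs
    k_h * k_w * c_in multiply-accumulates, each counted as 2 FLOPs."""
    return 2 * h_out * w_out * c_out * (k_h * k_w * c_in)


def dense_flops(n_in, n_out):
    """FLOPs of a fully connected layer (bias excluded), 2 FLOPs per weight."""
    return 2 * n_in * n_out


# Hypothetical toy network: 32x32x3 input -> 3x3 conv with 16 filters,
# stride 1, padding 1 -> flatten -> dense layer with 10 outputs.
h = w = conv_output_size(32, kernel=3, stride=1, padding=1)
conv = conv2d_flops(h, w, c_in=3, c_out=16, k_h=3, k_w=3)
fc = dense_flops(h * w * 16, 10)
total = conv + fc
print(f"conv: {conv:,}  dense: {fc:,}  total: {total:,}")
```

Running it prints `conv: 884,736  dense: 327,680  total: 1,212,416`. Summing such per-layer numbers for AlexNet and for a candidate design gives a rough budget comparison, keeping in mind that FLOPs are only a proxy for GPU memory and runtime.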

Upvotes: 0

Tobias Scheck

Reputation: 633

If you use Keras with TensorFlow as the backend, you can try the following example. It calculates the FLOPs for MobileNet.

import tensorflow as tf
import keras.backend as K
from keras.applications.mobilenet import MobileNet

run_meta = tf.RunMetadata()
with tf.Session(graph=tf.Graph()) as sess:
    K.set_session(sess)
    # Build MobileNet on a fixed input shape (batch size 1, 32x32 RGB).
    net = MobileNet(alpha=.75, input_tensor=tf.placeholder('float32', shape=(1,32,32,3)))

    # Total floating-point operations in the graph.
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)

    # Total number of trainable parameters.
    opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    params = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)

    print("{:,} --- {:,}".format(flops.total_float_ops, params.total_parameters))
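The second number printed is the trainable parameter count. MobileNet is built mainly from depthwise separable convolutions (a per-channel k x k depthwise conv followed by a 1x1 pointwise conv), which is why its FLOP and parameter counts are low. Here is a hand-count sketch of one such block versus a standard convolution. It assumes "same" padding with stride 1, no bias, and 2 FLOPs per multiply-accumulate; the layer sizes are hypothetical.

```python
def standard_conv(h, w, c_in, c_out, k):
    """(FLOPs, params) of a k x k Conv2D with h x w output, no bias."""
    flops = 2 * h * w * c_out * k * k * c_in
    params = k * k * c_in * c_out
    return flops, params


def depthwise_separable_conv(h, w, c_in, c_out, k):
    """(FLOPs, params) of a k x k depthwise conv followed by a 1x1 pointwise conv, no bias."""
    dw_flops = 2 * h * w * c_in * k * k   # one k x k filter per input channel
    pw_flops = 2 * h * w * c_in * c_out   # 1x1 conv that mixes channels
    params = k * k * c_in + c_in * c_out
    return dw_flops + pw_flops, params


# Hypothetical layer: 16x16 feature map, 32 -> 64 channels, 3x3 kernel.
std = standard_conv(16, 16, 32, 64, 3)
sep = depthwise_separable_conv(16, 16, 32, 64, 3)
print(f"standard:  {std[0]:,} FLOPs, {std[1]:,} params")
print(f"separable: {sep[0]:,} FLOPs, {sep[1]:,} params")
print(f"FLOP ratio: {sep[0] / std[0]:.3f}")  # equals 1/c_out + 1/k**2
```

With these sizes the separable block needs 1,196,032 FLOPs and 2,336 parameters, versus 9,437,184 FLOPs and 18,432 parameters for the standard convolution, a FLOP ratio of 1/c_out + 1/k^2 (about 0.127).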

Upvotes: 13
