CLRW97

Calculating FLOPS for Keras Models (TF 2.x)

I found two solutions to calculate FLOPS of Keras models (TF 2.x):

[1] https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-849439287

[2] https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-841975359

At first glance, both seem to work perfectly when testing with tf.keras.applications.ResNet50(). The resulting FLOPS are identical and correspond to the FLOPS of the ResNet paper.

But then I built a small GRU model and found different FLOPS for the two methods:

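```python
from tensorflow.keras import Sequential
from tensorflow.keras.layers import GRU, Dense
from tensorflow.keras.metrics import RootMeanSquaredError, MeanAbsoluteError

# recurrentCells, outputDimension and self._modelName are defined elsewhere (not shown)
model = Sequential(name=self._modelName)
model.add(GRU(recurrentCells, input_shape=(10, 8), recurrent_dropout=0.2, return_sequences=True))
model.add(GRU(recurrentCells, recurrent_dropout=0.2))
model.add(Dense(outputDimension, activation='tanh'))
model.compile(loss='mse', optimizer='adam', metrics=[RootMeanSquaredError(), MeanAbsoluteError()])
```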

This gives 13206 FLOPS with method [1] and 18306 FLOPS with method [2], which is really confusing.

Does anyone know how to correctly calculate FLOPS of recurrent Keras models in TF 2.x?

EDIT

I found another piece of information:

[3] https://github.com/tensorflow/tensorflow/issues/36391#issuecomment-596055100

It seems that when a model contains LSTM or GRU layers, we need to pass one more argument, lower_control_flow=False, to make convert_variables_to_constants_v2 work.

When adding this argument to convert_variables_to_constants_v2, the outputs of [1] and [2] are the same when using my GRU example.
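A minimal sketch of the freeze-and-profile pattern discussed in that issue, with lower_control_flow exposed as a parameter. It is not a copy of either linked comment: the helper name count_flops is hypothetical, and it assumes TF 2.x, a single-input model, and the internal module tensorflow.python.framework.convert_to_constants (a private path that may change between releases).

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph,
)


def count_flops(model, lower_control_flow=True):
    """Hypothetical helper: freeze a single-input Keras model and profile its FLOPs.

    lower_control_flow is forwarded to the freezing step; the linked issue
    suggests False for models with GRU/LSTM layers.
    """
    # Trace the model with batch size 1 so every op has a static shape.
    inp = model.inputs[0]
    spec = tf.TensorSpec([1, *inp.shape[1:]], inp.dtype)
    concrete_func = tf.function(lambda x: model(x)).get_concrete_function(spec)

    # Replace variables with constants, optionally keeping While/If ops un-lowered.
    _, graph_def = convert_variables_to_constants_v2_as_graph(
        concrete_func, lower_control_flow=lower_control_flow
    )

    # Import the frozen GraphDef and run the TF1 profiler's float-operation view.
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        info = tf.compat.v1.profiler.profile(graph=graph, cmd="op", options=opts)
    return info.total_float_ops
```

For the GRU model above this would be called as `count_flops(model, lower_control_flow=False)`; comparing it with the default `True` shows how much the argument changes the reported count.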

The TensorFlow source documents this argument as follows (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/convert_to_constants.py):

lower_control_flow: Boolean indicating whether or not to lower control flow ops such as If and While. (default True)

Can someone explain what this argument does and why it makes the two results match?

Answers (1)

MasterOne Piece

Neither of them is correct. You need to calculate the FLOPs of RNN models in TensorFlow manually. All the modules on GitHub only handle CNN models, not RNNs. In PyTorch you can calculate FLOPs for RNN models perfectly. 13206 and 18306 are too low for a GRU; they should be at least in the millions.
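A rough sketch of such a manual count for the GRU model in the question. Assumptions: only the two matrix multiplications per timestep are counted (input kernel and recurrent kernel, each computing all three gates at once), one multiply-add counts as 2 FLOPs, and bias adds, activations, element-wise gate arithmetic and dropout are ignored. The function names and the example sizes are hypothetical, since the question does not give recurrentCells or outputDimension.

```python
def gru_matmul_flops(timesteps, input_dim, units):
    """Approximate matmul FLOPs of one GRU layer over a full sequence."""
    # Input kernel (input_dim x 3*units) and recurrent kernel (units x 3*units),
    # counting each multiply-add as 2 FLOPs.
    per_step = 2 * input_dim * 3 * units + 2 * units * 3 * units
    # The recurrence runs once per timestep.
    return timesteps * per_step


def dense_flops(input_dim, output_dim):
    """Matmul FLOPs of a Dense layer applied to a single vector."""
    return 2 * input_dim * output_dim


def estimate_model_flops(recurrent_cells, output_dimension, timesteps=10, features=8):
    """GRU -> GRU -> Dense, matching input_shape=(10, 8) in the question."""
    first = gru_matmul_flops(timesteps, features, recurrent_cells)
    second = gru_matmul_flops(timesteps, recurrent_cells, recurrent_cells)
    head = dense_flops(recurrent_cells, output_dimension)
    return first + second + head


# Hypothetical sizes: 32 recurrent cells, 1 output.
print(estimate_model_flops(32, 1))  # 199744
```

Under these assumptions the count grows linearly with the sequence length, because every GRU layer repeats its two matrix multiplications at each timestep.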
