Taras Sereda

Reputation: 420

Tensorflow tfprof LSTMCell

I'm using tfprof to get the number of FLOPs needed for the model's forward pass. My model is a 3-layer LSTM followed by a fully connected layer. I've observed that the number of computations grows linearly with the sequence length for the fully connected layer, while it doesn't change for the LSTM layers. How is that possible?

tfprof report for a 1-timestamp forward pass:

==================Model Analysis Report======================
_TFProfRoot (0/2.71m flops)
  rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/MatMul (1.05m/1.05m flops)
  rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/MatMul (1.05m/1.05m flops)
  rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/MatMul (606.21k/606.21k flops)
  fc_layer/MatMul (1.54k/1.54k flops)
  rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  fc_layer/BiasAdd (3/3 flops)

tfprof report for a 2-timestamp forward pass:

==================Model Analysis Report======================
_TFProfRoot (0/2.71m flops)
  rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/MatMul (1.05m/1.05m flops)
  rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/MatMul (1.05m/1.05m flops)
  rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/MatMul (606.21k/606.21k flops)
  fc_layer/MatMul (3.07k/3.07k flops)
  rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/BiasAdd (1.02k/1.02k flops)
  fc_layer/BiasAdd (6/6 flops)

Upvotes: 0

Views: 350

Answers (1)

Peter

Reputation: 108

tfprof performs static analysis of your graph and calculates the floating-point operations for each graph node.

I assume you are using dynamic_rnn or something similar that uses tf.while_loop. In that case, a graph node appears in the graph once but is actually run multiple times at runtime.

tfprof has no way to statically figure out how many steps ("timestamps", in your words) will be run, so it counts each node's float operations only once. That is why the LSTM cells' FLOP counts don't change between your two reports: they are inside the while loop. The fully connected layer sits outside the loop, so it appears in the graph once per timestep and its count grows.

A workaround for now is to multiply the per-step FLOP counts of the loop-body nodes by the number of timesteps yourself.
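A minimal sketch of that workaround, using the per-node counts from the 1-timestamp report above. The helper function and the exact numbers are illustrative assumptions, not part of tfprof's API; the idea is simply to scale nodes under the `rnn/while/` scope, which tfprof counts once, while leaving nodes outside the loop as reported:

```python
def total_flops(per_node_flops, timesteps):
    """Scale tfprof's static FLOP counts for nodes inside a tf.while_loop.

    Nodes whose names live under the while-loop scope execute once per
    timestep at runtime, so their statically reported count is multiplied
    by `timesteps`; all other nodes are kept as reported.
    """
    total = 0
    for name, flops in per_node_flops.items():
        if name.startswith("rnn/while/"):
            total += flops * timesteps
        else:
            total += flops
    return total

# Per-node counts taken from the 1-timestamp report above (in FLOPs).
per_node = {
    "rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/MatMul": 606_210,
    "rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/MatMul": 1_050_000,
    "rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/MatMul": 1_050_000,
    "rnn/while/multi_rnn_cell/cell_0/lstm_cell/lstm_cell_1/BiasAdd": 1_020,
    "rnn/while/multi_rnn_cell/cell_1/lstm_cell/lstm_cell_1/BiasAdd": 1_020,
    "rnn/while/multi_rnn_cell/cell_2/lstm_cell/lstm_cell_1/BiasAdd": 1_020,
    # fc_layer is outside the loop, so tfprof already scales it per timestep.
    "fc_layer/MatMul": 1_540,
    "fc_layer/BiasAdd": 3,
}

print(total_flops(per_node, 2))  # estimate for a 2-timestamp forward pass
```

For multi-timestep runs you would also add the fc_layer counts for each additional step (or read them from the report for that sequence length, since they already grow with it).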

Upvotes: 2
