kwotsin

Reputation: 2923

TensorFlow: Will more memory be consumed if I handle secondary computations within a graph?

I have a trained ImageNet model from Google (inception-resnet-v2). The model gives me two outputs: the logits, and a dictionary called end_points from which I can extract the final layer after softmax activation, i.e. a tensor called predictions. However, this does not directly give me the class label, which I need for predictions. To get it, I could add label = tf.argmax(predictions, 1) to the graph after I define the train_op, so that it doesn't affect the original computation.

Alternatively, I can use np.argmax(sess.run(predictions), 1), which is computed outside the graph.
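Both options compute the same thing: the index of the largest probability in each row of predictions. A minimal NumPy sketch of that step, assuming made-up logits for a batch of 3 examples over 4 classes (the softmax helper and the data are hypothetical stand-ins for the model's outputs, not part of the original question):

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Hypothetical logits: 3 examples (rows) x 4 classes (columns).
logits = np.array([[2.0, 0.5, -1.0, 0.1],
                   [0.0, 3.0, 1.0, 0.2],
                   [-0.5, 0.3, 0.1, 4.0]])
predictions = softmax(logits)  # plays the role of the softmax end point

# Axis 1 runs over classes, so argmax returns one class index per example.
labels = np.argmax(predictions, 1)
print(labels)  # [0 1 3]

# Softmax preserves the ordering within each row, so argmax of the raw
# logits gives the same labels.
assert (labels == np.argmax(logits, 1)).all()
```

Because softmax is monotonic within a row, taking the argmax of the logits directly yields the same class labels.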

My question is: if I choose the first approach, would it consume more memory and affect my computation (in terms of the batch_size I can use)? Is it safer and better to just compute the labels outside the graph?

Upvotes: 2

Views: 222

Answers (1)

Yaroslav Bulatov

Reputation: 57913

When you issue multiple .run calls, the graph definition is cached. If you modify the graph, the session has to re-encode the newly added part and send it to the runtime again. So there may be a bit of extra memory used by graph_def.SerializeToString the first time you run the modified graph, but it shouldn't affect the .run steps after that.
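The caching behaviour can be illustrated with a toy model. This is a hypothetical sketch: ToyGraph and ToySession are made-up stand-ins, not TensorFlow classes, and they only mimic the idea that a version counter decides whether new ops have to be sent to the runtime.

```python
class ToyGraph:
    """Hypothetical stand-in for a graph: a list of ops and a version counter."""

    def __init__(self):
        self.ops = []
        self.version = 0

    def add_op(self, name):
        self.ops.append(name)
        self.version += 1  # every new op bumps the version


class ToySession:
    """Hypothetical stand-in for a session that caches the graph it has sent."""

    def __init__(self, graph):
        self._graph = graph
        self._current_version = 0
        self.serializations = 0  # how many times ops were shipped to the runtime

    def _extend_graph(self):
        # Only ship ops added since the last send; otherwise do nothing.
        if self._graph.version > self._current_version:
            new_ops = self._graph.ops[self._current_version:]
            self._current_version = self._graph.version
            self.serializations += 1
            return new_ops
        return []

    def run(self):
        return self._extend_graph()


g = ToyGraph()
g.add_op("predictions")
g.add_op("train_op")
sess = ToySession(g)
first = sess.run()   # both existing ops are sent
second = sess.run()  # graph unchanged: nothing is sent
g.add_op("argmax")
third = sess.run()   # only the newly added op is sent
```

Adding the argmax op costs one extra serialization on the next run; later runs find the versions equal and skip it.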

The relevant logic is in session.py; note the line that checks self._graph.version > self._current_version:

  def _extend_graph(self):
    # Ensure any changes to the graph are reflected in the runtime.
    with self._extend_lock:
      # Re-serialize only if ops were added since the last extension.
      if self._graph.version > self._current_version:
        # pylint: disable=protected-access
        graph_def, self._current_version = self._graph._as_graph_def(
            from_version=self._current_version,
            add_shapes=self._add_shapes)
        # pylint: enable=protected-access

        with errors.raise_exception_on_not_ok_status() as status:
          tf_session.TF_ExtendGraph(
              self._session, graph_def.SerializeToString(), status)
        self._opened = True
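Beyond the one-time serialization cost, the extra op's runtime footprint is its output tensor, and a TF 1.x session only executes the ops needed for the tensors you fetch, so label is not computed on steps that fetch only train_op. A rough back-of-the-envelope sketch, assuming 1001 classes (the TF-slim ImageNet convention of 1000 classes plus background) and an arbitrary batch_size of 32; the numbers are illustrative, not measured:

```python
# Illustrative assumptions, not measurements.
BATCH_SIZE = 32
NUM_CLASSES = 1001  # TF-slim ImageNet checkpoints: 1000 classes + background
FLOAT32_BYTES = 4   # softmax predictions are float32
INT64_BYTES = 8     # tf.argmax returns int64 by default


def tensor_bytes(shape, itemsize):
    """Bytes needed to hold a dense tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * itemsize


predictions_bytes = tensor_bytes((BATCH_SIZE, NUM_CLASSES), FLOAT32_BYTES)
label_bytes = tensor_bytes((BATCH_SIZE,), INT64_BYTES)

print(predictions_bytes)  # 128128
print(label_bytes)        # 256
```

The label tensor is a few hundred bytes, negligible next to the network's activations, so it should not change the batch_size you can use.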

Upvotes: 1
