sgu

Reputation: 1373

What's the purpose of keras.backend.function()?

The Keras manual doesn't say too much:

keras.backend.function(inputs, outputs, updates=None)

Instantiates a Keras function.
Arguments
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
**kwargs: Passed to tf.Session.run.
Returns
Output values as Numpy arrays.

The TensorFlow source code, which is actually quite short, shows that K.function(...) returns a Function object which, when called, evaluates the outputs and the updates using the inputs. The interesting part is how it handles the updates, which I don't follow. Any explanations, examples, or pointers that help in understanding K.function(...) are appreciated! Here is the relevant part of the TensorFlow source code:

class Function(object):
  """Runs a computation graph.
  Arguments:
      inputs: Feed placeholders to the computation graph.
      outputs: Output tensors to fetch.
      updates: Additional update ops to be run at function call.
      name: a name to help users identify what this function does.
  """

  def __init__(self, inputs, outputs, updates=None, name=None,
               **session_kwargs):
    updates = updates or []
    if not isinstance(inputs, (list, tuple)):
      raise TypeError('`inputs` to a TensorFlow backend function '
                      'should be a list or tuple.')
    if not isinstance(outputs, (list, tuple)):
      raise TypeError('`outputs` of a TensorFlow backend function '
                      'should be a list or tuple.')
    if not isinstance(updates, (list, tuple)):
      raise TypeError('`updates` in a TensorFlow backend function '
                      'should be a list or tuple.')
    self.inputs = list(inputs)
    self.outputs = list(outputs)
    with ops.control_dependencies(self.outputs):
      updates_ops = []
      for update in updates:
        if isinstance(update, tuple):
          p, new_p = update
          updates_ops.append(state_ops.assign(p, new_p))
        else:
          # assumed already an op
          updates_ops.append(update)
      self.updates_op = control_flow_ops.group(*updates_ops)
    self.name = name
    self.session_kwargs = session_kwargs

  def __call__(self, inputs):
    if not isinstance(inputs, (list, tuple)):
      raise TypeError('`inputs` should be a list or tuple.')
    feed_dict = {}
    for tensor, value in zip(self.inputs, inputs):
      if is_sparse(tensor):
        sparse_coo = value.tocoo()
        indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                                  np.expand_dims(sparse_coo.col, 1)), 1)
        value = (indices, sparse_coo.data, sparse_coo.shape)
      feed_dict[tensor] = value
    session = get_session()
    updated = session.run(
        self.outputs + [self.updates_op],
        feed_dict=feed_dict,
        **self.session_kwargs)
    return updated[:len(self.outputs)]


def function(inputs, outputs, updates=None, **kwargs):
  """Instantiates a Keras function.
  Arguments:
      inputs: List of placeholder tensors.
      outputs: List of output tensors.
      updates: List of update ops.
      **kwargs: Passed to `tf.Session.run`.
  Returns:
      Output values as Numpy arrays.
  Raises:
      ValueError: if invalid kwargs are passed in.
  """
  if kwargs:
    for key in kwargs:
      if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
          key not in tf_inspect.getargspec(Function.__init__)[0]):
        msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
               'backend') % key
        raise ValueError(msg)
  return Function(inputs, outputs, updates=updates, **kwargs)

Upvotes: 40

Views: 28125

Answers (3)

Pallavi

Reputation: 596

Here is my understanding of keras.backend.function, explained with the help of a code snippet.

The relevant part of the snippet is as follows:

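final_conv_layer = get_output_layer(model, "conv5_3")
# Map the model input to two outputs: the conv5_3 output and the final predictions.
get_output = K.function([model.layers[0].input],
                        [final_conv_layer.output, model.layers[-1].output])
# Both the inputs and the returned outputs are lists.
[conv_outputs, predictions] = get_output([img])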

In this code, the conv5_3 layer is extracted from a model (line 1). In the call to K.function(), the first argument is the input to this model and the second is a list of two outputs: the output of the convolutional layer and the softmax output of the last layer.

As per the Keras/TensorFlow manual, this function runs the computation graph that we have created in the code, feeding it the inputs given in the first parameter and returning one output for each tensor listed in the second parameter. Thus, conv_outputs is the output of final_conv_layer and predictions is the output of model.layers[-1], i.e. the last layer of the model.
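The calling convention described above (a list of inputs in, one returned value per requested output) can be imitated without Keras. What follows is only a plain-Python analogy, not how Keras is implemented: `make_function`, `hidden`, and `probs` are hypothetical names, and ordinary callables stand in for graph tensors.

```python
import math


def make_function(input_names, output_fns):
    """Toy stand-in for K.function (hypothetical, for illustration only).

    input_names: names of the "placeholders".
    output_fns: callables that each compute one output from the fed values.
    """
    def run(values):
        if not isinstance(values, (list, tuple)):
            raise TypeError('`inputs` should be a list or tuple.')
        # Pair each placeholder with the value fed for it, like feed_dict.
        feed = dict(zip(input_names, values))
        # Evaluate every requested output and return them in the order given.
        return [fn(feed) for fn in output_fns]
    return run


# A two-stage "model": an intermediate activation and a final prediction.
def hidden(feed):
    return [max(0.0, 2.0 * v - 1.0) for v in feed["x"]]  # ReLU(2x - 1)


def probs(feed):
    exps = [math.exp(v) for v in hidden(feed)]
    total = sum(exps)
    return [e / total for e in exps]  # softmax over the hidden values


get_output = make_function(["x"], [hidden, probs])
conv_like, predictions = get_output([[0.0, 1.0, 2.0]])
```

Here `conv_like` plays the role of `conv_outputs` (an intermediate result) and `predictions` is the final output; both come back from a single call. Unlike a real computation graph, this toy recomputes `hidden` when it evaluates `probs`.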

Upvotes: 27

ZhaoPan Song

Reputation: 31

Think of it as a function wrapper. In Keras with the TensorFlow backend, it wraps a list of tensors as input and can run operations on the weights of the network (such as backpropagation updates). It is especially useful in reinforcement learning, where the loss is computed after the actions (the model outputs) are taken and model.fit is too coarse-grained to incorporate such an op.
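To make the role of `updates` more concrete, here is a framework-free sketch of the order of operations suggested by the `Function` source quoted in the question: the outputs are evaluated first (the update ops are created under `control_dependencies(self.outputs)`), each `(p, new_p)` pair is then applied as an assignment, and only the outputs are returned. `ToyVariable`, `ToyFunction`, `loss`, and `new_w` are hypothetical names used for illustration, and the hand-written gradient step is an assumption about what a training update might look like.

```python
class ToyVariable:
    """Hypothetical mutable state, standing in for a network weight."""

    def __init__(self, value):
        self.value = value


class ToyFunction:
    """Mimics the order of operations in the quoted Function class."""

    def __init__(self, outputs, updates=None):
        self.outputs = list(outputs)
        # Each update is a (variable, new_value_fn) pair, like (p, new_p).
        self.updates = list(updates or [])

    def __call__(self, inputs):
        # 1. Evaluate every output from the fed inputs and the current state.
        results = [fn(inputs) for fn in self.outputs]
        # 2. Compute all new values, then assign them (a simplification
        #    chosen for this toy so the updates do not see each other).
        new_values = [(var, fn(inputs)) for var, fn in self.updates]
        for var, new in new_values:
            var.value = new
        # 3. Only the outputs are returned; the updates are side effects.
        return results


w = ToyVariable(0.0)
lr = 0.1


def loss(inputs):
    x, y = inputs
    return (w.value * x - y) ** 2


def new_w(inputs):
    x, y = inputs
    grad = 2.0 * (w.value * x - y) * x  # d(loss)/dw, derived by hand
    return w.value - lr * grad


train_step = ToyFunction(outputs=[loss], updates=[(w, new_w)])
first_loss = train_step([1.0, 2.0])[0]   # loss computed with the old weight
second_loss = train_step([1.0, 2.0])[0]  # loss after one update has been applied
```

Each call returns the loss measured before that call's update takes effect, while the weight changes as a side effect. This is the pattern a custom training step relies on when `model.fit` does not fit the problem.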

Upvotes: 0

Jack Wang

Reputation: 151

I think this function is mainly used to extract intermediate results. See the Keras documentation on "How can I obtain the output of an intermediate layer?":

One simple way is to create a new Model that will output the layers that you are interested in:

from keras.models import Model

model = ...  # create the original model

layer_name = 'my_layer'
intermediate_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model.predict(data)

Another way is to build a Keras function, which will return the output of a certain layer given a certain input. For example:

from keras import backend as K

# with a Sequential model
get_3rd_layer_output = K.function([model.layers[0].input],
                                  [model.layers[3].output])
layer_output = get_3rd_layer_output([x])[0]

You can use the returned function object, get_3rd_layer_output, to get the intermediate output of model.layers[3].

Upvotes: 15
