Ushnish

Reputation: 81

TensorFlow GRU cell error when fetching activations with variable sequence length

I want to run a GRU cell on time series data and cluster the series according to the activations in the last layer. I made one small change to the GRU cell implementation:

def __call__(self, inputs, state, scope=None):
  """Gated recurrent unit (GRU) with nunits cells."""
  with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
    with vs.variable_scope("Gates"):  # Reset gate and update gate.
      # We start with bias of 1.0 to not reset and not update.
      r, u = array_ops.split(1, 2, linear([inputs, state], 2 * self._num_units, True, 1.0))
      r, u = sigmoid(r), sigmoid(u)
    with vs.variable_scope("Candidate"):
      c = tanh(linear([inputs, r * state], self._num_units, True))
    new_h = u * state + (1 - u) * c

    # store the activations, everything else is the same
    self.activations = [r, u, c]
  return new_h, new_h

After this, I concatenate the activations in the property setter below, so that the script which calls the GRU cell can read them back:

@property
def activations(self):
    return self._activations


@activations.setter
def activations(self, activations_array):
    print "PRINT THIS"         
    concactivations = tf.concat(concat_dim=0, values=activations_array, name='concat_activations')
    self._activations = tf.reshape(tensor=concactivations, shape=[-1], name='flatten_activations')

I invoke the GRU cell like this:

outputs, state = rnn.rnn(cell=cell, inputs=x, initial_state=initial_state, sequence_length=s)

Here s is an array with one entry per batch element, giving the number of timesteps in that element's sequence.

Finally, I fetch the activations with:

fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict)

When I run this, I get the following error:

Traceback (most recent call last):
  File "xxx.py", line 162, in <module>
    fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 315, in run
    return self._run(None, fetches, feed_dict)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 511, in _run
    feed_dict_string)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 564, in _do_run
    target_list)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 588, in _do_call
    six.reraise(e_type, e_value, e_traceback)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 571, in _do_call
    return fn(*args)
  File "/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 555, in _run_fn
    return tf_session.TF_Run(session, feed_dict, fetch_list, target_list)
tensorflow.python.pywrap_tensorflow.StatusNotOK: Invalid argument: The tensor returned for RNN/cond_396/ClusterableGRUCell/flatten_activations:0 was not valid.

Can someone give some insight into how to fetch the activations from a GRU cell at the last step when passing variable-length sequences? Thanks.

Upvotes: 4

Views: 356

Answers (1)

Alexandre Passos

Reputation: 5206

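To fetch the activations from the last step, make the activations part of your state, which is returned by tf.nn.rnn (the rnn.rnn call in the question). The tensor name in the error, RNN/cond_396/..., suggests that the cell is being built inside a conditional that rnn adds when sequence_length is given, so a tensor stored on the cell object from inside that branch is probably not valid to fetch directly.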
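As a rough illustration of the "activations as part of the state" idea, here is a minimal NumPy sketch rather than TensorFlow code. The names ActivationGRUCell and run_rnn are hypothetical, the weights are random, and the masking loop only approximates what rnn does with sequence_length; it is meant to show the shape of the approach, not the library's implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ActivationGRUCell(object):
    """Hypothetical GRU cell whose state is [h, r, u, c] concatenated."""

    def __init__(self, num_inputs, num_units, seed=0):
        rng = np.random.RandomState(seed)
        self.num_units = num_units
        # Random weights for illustration only; gate bias starts at 1.0.
        self.w_gates = rng.randn(num_inputs + num_units, 2 * num_units) * 0.1
        self.b_gates = np.ones(2 * num_units)
        self.w_cand = rng.randn(num_inputs + num_units, num_units) * 0.1
        self.b_cand = np.zeros(num_units)

    @property
    def state_size(self):
        # Hidden state plus the three activations, each num_units wide.
        return 4 * self.num_units

    def __call__(self, inputs, state):
        h = state[:, :self.num_units]  # only h feeds the recurrence
        gates = sigmoid(
            np.concatenate([inputs, h], axis=1).dot(self.w_gates) + self.b_gates)
        r, u = np.split(gates, 2, axis=1)
        c = np.tanh(
            np.concatenate([inputs, r * h], axis=1).dot(self.w_cand) + self.b_cand)
        new_h = u * h + (1 - u) * c
        # The activations travel with the state instead of living on the cell.
        new_state = np.concatenate([new_h, r, u, c], axis=1)
        return new_h, new_state


def run_rnn(cell, inputs, sequence_length):
    """inputs: [time, batch, features]. Returns the final state per sequence."""
    max_time, batch = inputs.shape[0], inputs.shape[1]
    state = np.zeros((batch, cell.state_size))
    for t in range(max_time):
        _, new_state = cell(inputs[t], state)
        # Past a sequence's end, keep its previous state unchanged.
        active = (t < sequence_length)[:, None]
        state = np.where(active, new_state, state)
    return state
```

With this layout, final_state[:, num_units:] holds r, u and c from each sequence's last valid step, so there is no need to fetch a tensor stored on the cell itself. In the TensorFlow version the equivalent would be to report a larger state_size and slice the returned state, assuming the rest of the graph only uses the h part.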

Upvotes: 0
