Haitao Leng

Reputation: 63

For tf.nn.rnn_cell.BasicRNNCell, what's the difference between the state and the output?

As I understand it, a basic RNN computes

state = tanh(W * input + U * prev_state + b)
output = state * W_out

But for tf.nn.rnn_cell.BasicRNNCell I only pass num_units (which I think is the dimension of the state), and the API page says:

Most basic RNN: output = new_state = activation(W * input + U * state + B)

So can I assume that in this cell state = output? And does the cell have only W, U and b, but no W_out?
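A minimal pure-Python sketch of one step of the cell the API page describes, assuming tanh as the activation and nested lists for the matrices. The helpers `matvec` and `basic_rnn_step` are hypothetical, not TensorFlow functions (TensorFlow stores W and U as one concatenated kernel, but the arithmetic is equivalent). It shows that only W, U and b are involved and that the value returned as the output is the new state itself:

```python
import math

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def basic_rnn_step(x, h, W, U, b):
    """Hypothetical helper: output = new_state = tanh(W x + U h + b).

    W has shape (num_units, input_dim), U has shape (num_units, num_units).
    """
    new_state = [math.tanh(a + c + d)
                 for a, c, d in zip(matvec(W, x), matvec(U, h), b)]
    output = new_state  # no separate W_out: the output IS the new state
    return output, new_state

# 2 input features, num_units = 3
W = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
U = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.5]]
b = [0.0, 0.1, -0.1]

h = [0.0, 0.0, 0.0]
for x in ([1.0, 0.0], [0.0, 1.0]):
    out, h = basic_rnn_step(x, h, W, U, b)
print(len(out), out == h)  # 3 True
```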

Upvotes: 1

Views: 970

Answers (1)

Ishamael

Reputation: 12795

The "vanilla" RNN you describe computes the new hidden state and then applies an output projection to compute the output. TensorFlow separates these two parts: "compute the new hidden state" and "compute the output projection". BasicRNNCell simply returns the hidden state as its output, so yes, output = state and there is no W_out. Another class, OutputProjectionWrapper, can then apply a projection to that output (multiplying by w_out is just such a projection). To get the behavior you want, do:

tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.BasicRNNCell(...), num_output_units)
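To illustrate what the wrapper adds, here is a pure-Python sketch, not TensorFlow's code. It runs the same basic cell step and then applies an extra linear projection W_out plus a bias b_out to the cell's output, which I believe matches TensorFlow's wrapper also including a bias. The unprojected state is passed through unchanged to the next step. The names `output_projection_step`, `W_out` and `b_out` are hypothetical; the real wrapper creates and trains its own projection variables.

```python
import math

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def basic_rnn_step(x, h, W, U, b):
    # output = new_state = tanh(W x + U h + b)
    new_state = [math.tanh(a + c + d)
                 for a, c, d in zip(matvec(W, x), matvec(U, h), b)]
    return new_state, new_state

def output_projection_step(x, h, W, U, b, W_out, b_out):
    """Hypothetical helper: run the basic cell, then project only its output."""
    cell_output, new_state = basic_rnn_step(x, h, W, U, b)
    projected = [o + bo for o, bo in zip(matvec(W_out, cell_output), b_out)]
    return projected, new_state  # the recurrent state is NOT projected

# num_units = 3 (hidden state), num_output_units = 2 (projection)
W = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
U = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.5]]
b = [0.0, 0.1, -0.1]
W_out = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
b_out = [0.0, 0.0]

h = [0.0, 0.0, 0.0]
for x in ([1.0, 0.0], [0.0, 1.0]):
    y, h = output_projection_step(x, h, W, U, b, W_out, b_out)
print(len(y), len(h))  # 2 3
```

Because only the output goes through the projection, the hidden state and the output can have different sizes (3 and 2 here).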

It also lets you use a different number of units in the hidden state and in the output projection.

Upvotes: 3
