Reputation: 63
As far as I know, a vanilla RNN computes state = tanh(w * input + u * pre_state + b) and output = state * w_out. But for tf.nn.rnn_cell.BasicRNNCell I only get to set unit_num (I think it's the dimension of the state), and the API page says: "Most basic RNN: output = new_state = activation(W * input + U * state + B)". So can I assume that in this function state = output, and that it only has w, u and b, but no w_out?
Upvotes: 1
Views: 970
Reputation: 12795
The "vanilla" RNN you describe computes the new hidden state and then applies an output projection to compute the output. TensorFlow separates these two steps, "compute the new hidden state" and "compute the output projection". BasicRNNCell just returns the hidden state as its output; another class, OutputProjectionWrapper, can then apply a projection to it (multiplying by w_out is exactly such a projection). To get the behavior you want, you need to do:
tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.BasicRNNCell(...), num_output_units)
It also lets you use a different number of units in your hidden state and in your output projection.
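To make the split concrete, here is a minimal NumPy sketch of the two pieces. It is not TensorFlow's implementation, only an illustration of the formulas from the question. The function names, shapes and the bias b_out are my own assumptions (the formula in the question has only w_out).

```python
import numpy as np

def basic_rnn_step(x, h_prev, W, U, b):
    """One step of a basic RNN cell: output and new state are the same vector."""
    h_new = np.tanh(x @ W + h_prev @ U + b)
    return h_new, h_new  # (output, new_state)

def projected_rnn_step(x, h_prev, W, U, b, W_out, b_out):
    """The same cell followed by a linear output projection.

    b_out is an assumed bias term, not part of the question's formula.
    """
    output, h_new = basic_rnn_step(x, h_prev, W, U, b)
    return output @ W_out + b_out, h_new

# Hypothetical sizes, chosen only for illustration.
input_dim, num_units, num_output_units = 4, 3, 2

rng = np.random.default_rng(0)
W = rng.normal(size=(input_dim, num_units))
U = rng.normal(size=(num_units, num_units))
b = np.zeros(num_units)
W_out = rng.normal(size=(num_units, num_output_units))
b_out = np.zeros(num_output_units)

x = rng.normal(size=(1, input_dim))   # a batch containing one input vector
h0 = np.zeros((1, num_units))         # initial hidden state

cell_output, cell_state = basic_rnn_step(x, h0, W, U, b)
proj_output, proj_state = projected_rnn_step(x, h0, W, U, b, W_out, b_out)

print(cell_output.shape)  # has num_units columns, identical to the state
print(proj_output.shape)  # has num_output_units columns
```

The bare cell returns the hidden state twice, so its output size is always num_units. The projected version keeps the same state but maps the output to num_output_units, which is why the two sizes can differ.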
Upvotes: 3