Reputation: 2187
I'm trying to implement a dynamic_rnn_decoder. However, I get an exception because, after the second element, the tensors in the cell have already been created. Thus I want to set reuse=True after the first iteration.
Is there an op that dynamically calls a function depending on a condition (like fn_dyn = tf.cond(cond, fn1, fn2))?
Hence, I want to implement this dynamically:

    if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
A simplified _time_step function for _dynamic_rnn_loop could look something like this:

    def _time_step(time, output_ta_t, *state):
        input_t = input_ta.read(time)
        # Restore some shape information
        input_t.set_shape([const_batch_size, const_depth])

        # Pack state back up for use by cell
        state = (_packed_state(structure=state_size, state=state)
                 if state_is_tuple else state[0])

        def call_with_previous(feed_previous_t):
            if feed_previous_t:
                prev = output_ta_t.read(time - 1)
                if output_projection is not None:
                    prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
                cell_input = math_ops.reduce_max(prev, 1)
                print(cell_input.get_shape())
                cell_input.set_shape([const_batch_size, const_depth])
            else:
                cell_input = input_t

            def call_cell_t(cell_input_t, state_t):
                # set reuse after first call
                output_t, state_t = cell(cell_input_t, state_t)
                variable_scope.get_variable_scope().reuse_variables()
                return output_t, state_t

            return lambda: call_cell_t(cell_input, state)

        # >>> doesn't work
        call_cell = tf.cond(tf.equal(time, tf.constant(0, dtype=tf.int32)),
                            call_with_previous(False),
                            call_with_previous(True))

        if sequence_length is not None:
            (output, new_state) = _rnn_step(
                time=time,
                sequence_length=sequence_length,
                min_sequence_length=min_sequence_length,
                max_sequence_length=max_sequence_length,
                zero_output=zero_output,
                state=state,
                call_cell=call_cell,
                state_size=state_size,
                skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        new_state = (tuple(_unpacked_state(new_state)) if state_is_tuple else (new_state,))

        output_ta_t = output_ta_t.write(time, output)

        return (time + 1, output_ta_t) + new_state

Thanks, cheers!
Upvotes: 1
Views: 800
Reputation: 899
while_loop only calls the underlying body function once, when the graph is built, not dynamically for every time step. If you're getting an error when getting the variable, it's because you also access the variable elsewhere in your code.
In this case, it looks like it's because of your cond statement: it causes two calls to cell(). Try to factor this so the cell call is outside the cond.
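To make the "two calls" point concrete, here is a small pure-Python sketch rather than real TensorFlow code. It assumes only that tf.cond runs both branch functions once while the graph is being built; toy_cond, CountingCell and the two step functions are hypothetical names invented for this illustration.

```python
def toy_cond(pred, true_fn, false_fn):
    """Hypothetical stand-in for tf.cond at graph-build time.

    Both branch functions are run once, then one result is selected.
    """
    true_out = true_fn()
    false_out = false_fn()
    return true_out if pred else false_out


class CountingCell:
    """Toy cell that counts how often it is called (i.e. 'built')."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x, state):
        self.calls += 1
        out = x + state
        return out, out  # (output, new_state)


def step_cell_inside_cond(cell, time, input_t, prev_output, state):
    # The cell appears in both branches, so it gets called twice.
    return toy_cond(time == 0,
                    lambda: cell(input_t, state),
                    lambda: cell(prev_output, state))


def step_cell_outside_cond(cell, time, input_t, prev_output, state):
    # Only the choice of input is conditional; the cell is called once.
    cell_input = toy_cond(time == 0, lambda: input_t, lambda: prev_output)
    return cell(cell_input, state)


inside, outside = CountingCell(), CountingCell()
result_inside = step_cell_inside_cond(inside, 0, 1.0, 5.0, 10.0)
result_outside = step_cell_outside_cond(outside, 0, 1.0, 5.0, 10.0)
print(inside.calls, outside.calls)
```

Both variants compute the same result, but only the second calls the cell a single time, so its variables would be created once and no reuse flag would be needed.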
Alternatively, as a hack, put the cell call inside a try/except block. If you get a variable access error, just set reuse on the variable scope and call it again.
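A rough sketch of that try/except pattern, again in plain Python so it runs without TensorFlow. ToyScope, toy_cell and call_cell_with_fallback are hypothetical stand-ins; the only assumption carried over from TensorFlow 1.x is that creating an already existing variable without reuse raises a ValueError.

```python
class ToyScope:
    """Hypothetical stand-in for a variable scope (not a real TensorFlow API)."""

    def __init__(self):
        self.reuse = False
        self.variables = {}

    def get_variable(self, name, init):
        if name in self.variables:
            if not self.reuse:
                # Second creation without reuse is rejected.
                raise ValueError("Variable %s already exists" % name)
            return self.variables[name]
        if self.reuse:
            raise ValueError("Variable %s does not exist" % name)
        self.variables[name] = init
        return init


def toy_cell(scope, x, state):
    """Toy cell: output = w * x + state; the new state equals the output."""
    w = scope.get_variable("w", 2.0)
    out = w * x + state
    return out, out


def call_cell_with_fallback(scope, x, state):
    """Try the cell; on a variable access error, enable reuse and retry."""
    try:
        return toy_cell(scope, x, state)
    except ValueError:
        scope.reuse = True
        return toy_cell(scope, x, state)


scope = ToyScope()
first = call_cell_with_fallback(scope, 1.0, 0.0)        # creates "w"
second = call_cell_with_fallback(scope, 3.0, first[1])  # falls back to reuse
print(first, second)
```

In real code the retry would wrap the cell(...) call and the reuse switch would be variable_scope.get_variable_scope().reuse_variables(), as in the question.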
Source: I wrote dynamic_rnn.
Upvotes: 1