lhlmgr

Reputation: 2187

Dynamic function call, depending on condition in tensorflow graph

I'm trying to implement a dynamic_rnn_decoder. However, I get an exception, because after the second element the Tensors in the cell have already been created. Thus I want to set reuse=True after the first iteration. Is there an op which dynamically calls a function depending on a condition (something like fn_dyn = tf.cond(cond, fn1, fn2))?

Hence, I want to implement this dynamically:

if i > 0:
    variable_scope.get_variable_scope().reuse_variables()
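For reference, this is how tf.cond is invoked in the TF 1.x API: it takes a scalar boolean predicate and two zero-argument callables, and returns the output of whichever branch the predicate selects at run time. Note, however, that both branch functions are traced once at graph-construction time (a minimal sketch, not part of the code below):

import tensorflow as tf

time = tf.placeholder(tf.int32, shape=[])

# Both lambdas are traced when the graph is built; only the branch
# selected by the predicate actually runs in the session.
result = tf.cond(time > 0,
                 lambda: tf.constant(1),  # taken when time > 0
                 lambda: tf.constant(0))  # taken when time == 0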

A simplified _time_step function for _dynamic_rnn_loop could then look like this:

def _time_step(time, output_ta_t, *state):
    input_t = input_ta.read(time)
    # Restore some shape information
    input_t.set_shape([const_batch_size, const_depth])

    # Pack state back up for use by cell
    state = (_packed_state(structure=state_size, state=state)
             if state_is_tuple else state[0])

    def call_with_previous(feed_previous_t):
        if feed_previous_t:
            prev = output_ta_t.read(time - 1)

            if output_projection is not None:
                prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])

            cell_input = math_ops.reduce_max(prev, 1)
            print(cell_input.get_shape())

            cell_input.set_shape([const_batch_size, const_depth])
        else:
            cell_input = input_t

        def call_cell_t(cell_input_t, state_t):
            # set reuse after the first call
            output_t, state_t = cell(cell_input_t, state_t)
            variable_scope.get_variable_scope().reuse_variables()

            return output_t, state_t

        return lambda: call_cell_t(cell_input, state)

    # >>> doesn't work
    call_cell = tf.cond(tf.equal(time, tf.constant(0, dtype=tf.int32)),
                        call_with_previous(False),
                        call_with_previous(True))

    if sequence_length is not None:
        (output, new_state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=state_size,
            skip_conditionals=True)
    else:
        (output, new_state) = call_cell()

    # Pack state if using state tuples
    new_state = (tuple(_unpacked_state(new_state)) if state_is_tuple else (new_state,))

    output_ta_t = output_ta_t.write(time, output)

    return (time + 1, output_ta_t) + new_state

Thanks, cheers!

Upvotes: 1

Views: 800

Answers (1)

Eugene Brevdo

Reputation: 899

while_loop only calls the underlying body function once, to build the graph; it is not called dynamically for every time step. If you're getting an error when getting the variable, it's because you also access the variable elsewhere in your code.

In this case, it looks like it's because of your cond statement: it causes two calls to cell(). Try to factor the code so the cell call is outside the cond.
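A sketch of that factoring, reusing the names from the question's snippet (cell, input_t, output_ta_t, time, state, output_projection, const_batch_size and const_depth are assumed to be in scope): only the cell input is selected inside the cond, so cell() is called exactly once and its variables are created exactly once.

def input_from_previous():
    # Feed the previous output back in as the next input.
    prev = output_ta_t.read(time - 1)
    if output_projection is not None:
        prev = nn_ops.xw_plus_b(prev, output_projection[0],
                                output_projection[1])
    prev = math_ops.reduce_max(prev, 1)
    prev.set_shape([const_batch_size, const_depth])
    return prev

# Both branches return a tensor of the same shape and dtype, as
# tf.cond requires; the single cell call sits outside the cond.
cell_input = tf.cond(tf.equal(time, 0),
                     lambda: input_t,      # first step: ground-truth input
                     input_from_previous)  # later steps: previous output
call_cell = lambda: cell(cell_input, state)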

Alternatively, as a hack, wrap the cell call in a try/except block. If you get a variable-access error, set the scope to reuse variables and call the cell again.
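A minimal sketch of that hack, again assuming the names from the question's snippet (in TF 1.x the reuse failure surfaces as a ValueError):

def call_cell():
    try:
        return cell(cell_input, state)
    except ValueError:
        # "Variable ... already exists": flip the scope to reuse
        # and call the cell a second time.
        variable_scope.get_variable_scope().reuse_variables()
        return cell(cell_input, state)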

Source: I wrote dynamic_rnn.

Upvotes: 1
