Antonina

Reputation: 604

tensorflow - ValueError: The shape for decoder/while/Merge_12:0 is not an invariant for the loop

I use tf.contrib.seq2seq.dynamic_decode to train a decoder:

prediction, final_decoder_state, _ = dynamic_decode(
    custom_decoder
)

with a custom decoder

custom_decoder = CustomDecoder(decoder_cell, helper, decoder_init_state)

and a custom helper

helper = CustomTrainingHelper(batch_size, targets, stop_targets,
                              num_outs, outputs_per_step, 1.0, False)

dynamic_decode then raises this error:

Traceback (most recent call last):
  File "E:/tasks/text_to_speech/tts/tf_seq2seq.py", line 95, in <module>
    custom_decoder
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\decoder.py", line 304, in dynamic_decode
    swap_memory=swap_memory)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3224, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2956, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2930, in _BuildLoop
    next_vars.append(_AddNextAndBackEdge(m, v))
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 688, in _AddNextAndBackEdge
    _EnforceShapeInvariant(m, v)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 632, in _EnforceShapeInvariant
    (merge_var.name, m_shape, n_shape))
ValueError: The shape for decoder/while/Merge_12:0 is not an invariant for the loop. It enters the loop with shape (10, 1), but has shape (?, 1) after one iteration. Provide shape invariants using either the `shape_invariants` argument of tf.while_loop or set_shape() on the loop variables.

batch_size is equal to 10. As far as I understand, the problem is related to tf.while_loop and the batch size. How can I fix this error?

Upvotes: 0

Views: 216

Answers (1)

iga

Reputation: 3633

In general, this error means the following. By default, TensorFlow checks that the loop variables passed from one iteration of a while loop to the next do not change shape. In your case, the tensor decoder/while/Merge_12:0 entered the loop with shape (10, 1), but after one iteration its shape became (?, 1), meaning that TensorFlow can no longer statically infer the size of the first dimension.

If you know that the first dimension is really 10, you can call Tensor.set_shape on that tensor to tell TensorFlow so.
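To illustrate, here is a minimal sketch that reproduces the same kind of error in a standalone tf.while_loop and shows both remedies named in the error message. It does not use your CustomDecoder or CustomTrainingHelper. The names BATCH_SIZE, lose_static_batch_dim and build_loop are hypothetical, and tf.placeholder_with_default is only a stand-in for whichever op in your helper drops the static batch size. It assumes a TensorFlow version that provides tensorflow.compat.v1.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, as in a TF 1.x setup

BATCH_SIZE = 10  # assumption: matches batch_size in the question


def lose_static_batch_dim(x):
    # Hypothetical stand-in for an op whose output has an unknown
    # static batch dimension, i.e. shape (?, 1).
    return tf.placeholder_with_default(x, shape=[None, 1])


def build_loop(fix=None):
    """Builds and runs a 3-step loop; fix is None, 'set_shape' or 'shape_invariants'."""
    with tf.Graph().as_default():
        i0 = tf.constant(0)
        x0 = tf.zeros([BATCH_SIZE, 1])  # enters the loop with shape (10, 1)

        def cond(i, x):
            return i < 3

        def body(i, x):
            y = lose_static_batch_dim(x) + 1.0  # static shape is now (?, 1)
            if fix == "set_shape":
                # Remedy 1: assert the shape we know to be true.
                y.set_shape([BATCH_SIZE, 1])
            return i + 1, y

        invariants = None
        if fix == "shape_invariants":
            # Remedy 2: declare that the first dimension may vary.
            invariants = [i0.get_shape(), tf.TensorShape([None, 1])]

        _, result = tf.while_loop(cond, body, [i0, x0],
                                  shape_invariants=invariants)
        with tf.Session() as sess:
            return result.shape.as_list(), sess.run(result)
```

Calling build_loop() with no fix should raise the "not an invariant for the loop" ValueError, while build_loop("set_shape") keeps the static shape (10, 1) and build_loop("shape_invariants") relaxes it to (?, 1). In your case the loop is created inside dynamic_decode, so you cannot pass shape_invariants yourself. The set_shape call would presumably have to go on the tensors that your custom helper or decoder returns at each step.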

Upvotes: 1
