I'm trying to execute these functions:
```python
import tensorflow as tf

# preprocess_sentence, tokenizer, model, START_TOKEN, END_TOKEN and MAX_LENGTH
# are defined earlier in the notebook.

def evaluate(sentence):
    sentence = preprocess_sentence(sentence)

    sentence = tf.expand_dims(
        START_TOKEN + tokenizer.encode(sentence) + END_TOKEN, axis=0)

    output = tf.expand_dims(START_TOKEN, 0)

    for i in range(MAX_LENGTH):
        predictions = model(inputs=[sentence, output], training=False)

        # select the last word from the seq_len dimension
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # return the result if the predicted_id is equal to the end token
        if tf.equal(predicted_id, END_TOKEN[0]):
            break
        #check()
        #tf.cond(tf.equal(predicted_id, END_TOKEN[0]),true_fn=break,false_fn=lambda: tf.no_op())

        # concatenate the predicted_id to the output which is given to the decoder
        # as its input.
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0)


def predict(sentence):
    prediction = evaluate(sentence)

    predicted_sentence = tokenizer.decode(
        [i for i in prediction if i < tokenizer.vocab_size])

    print('Input: {}'.format(sentence))
    print('Output: {}'.format(predicted_sentence))

    return predicted_sentence
```
However, I get the following error:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
I understand that I have to rewrite the if condition as a tf.cond(). However, I don't know how to write break in TensorFlow. I'm also not sure which if condition is causing the problem, since exactly the same function works properly in this notebook:
https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/transformer_chatbot.ipynb#scrollTo=_NURhwYz5AXa
Any help?
There is nothing wrong with the break statement. The problem is elsewhere.
```python
if tf.equal(predicted_id, END_TOKEN[0]):
    break
```
raises an error about using a Python bool in tensor ops. Because the condition is already written with tf.equal, this can be confusing, but the cause is simple: the error is thrown by the Python `if (boolean):` syntax itself, not by tf.equal.
You have to replace this Python bool syntax with tensor-style code, depending on what you are trying to achieve. Remember that the condition returns a tensor of boolean values; read that tensor and then proceed accordingly. For example, the following runs without the error, but it breaks unconditionally, regardless of the value of the condition (tf.equal returns a tensor, which is never None):
```python
if tf.equal(predicted_id, END_TOKEN[0]) is not None:
    break
```
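In tensor style, the early exit usually moves into the loop condition of tf.while_loop instead of being a Python break. The sketch below is a hypothetical, self-contained stand-in for the decoding loop: NEXT_TOKEN, START_ID, END_ID and MAX_LENGTH are made-up values, and a lookup table replaces the real model. It keeps looping only while the pending prediction is not the end token, and maximum_iterations plays the role of range(MAX_LENGTH).

```python
import tensorflow as tf

# Hypothetical token ids and length limit (not taken from the notebook).
START_ID, END_ID, MAX_LENGTH = 1, 2, 10

# Hypothetical stand-in for the model: NEXT_TOKEN[t] is the greedy
# prediction that follows token t.
NEXT_TOKEN = tf.constant([0, 3, 2, 4, 5, 2], dtype=tf.int32)


@tf.function  # traced into a graph, like graph-mode TF code
def greedy_decode():
    output = tf.constant([START_ID], dtype=tf.int32)
    next_id = tf.gather(NEXT_TOKEN, START_ID)

    def not_finished(output, next_id):
        # Replaces `if predicted_id == END: break`: keep looping only
        # while the pending prediction is not the end token.
        return tf.not_equal(next_id, END_ID)

    def step(output, next_id):
        # Append the pending prediction, then predict the following token.
        output = tf.concat([output, tf.expand_dims(next_id, 0)], axis=0)
        return output, tf.gather(NEXT_TOKEN, next_id)

    output, _ = tf.while_loop(
        not_finished,
        step,
        loop_vars=(output, next_id),
        # `output` grows every iteration, so its length is left unknown.
        shape_invariants=(tf.TensorShape([None]), next_id.shape),
        maximum_iterations=MAX_LENGTH,  # plays the role of range(MAX_LENGTH)
    )
    return output


print(greedy_decode().numpy())  # [1 3 4 5]
```

The end token itself is never appended, which matches the original loop breaking before the concat.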
The code in the notebook works because it uses TF 2.0, which has eager execution enabled by default. In older versions you can turn it on with tf.enable_eager_execution().
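The difference can be reproduced in TF 2.x, where eager execution is on by default. In this sketch (the function name is_end_token and the END_ID value are made up for illustration), the same Python if on a tensor works eagerly, but is rejected once the function is traced into a graph with AutoGraph turned off, which is roughly the situation of TF 1.x graph code.

```python
import tensorflow as tf

END_ID = 2  # hypothetical end-token id


def is_end_token(predicted_id):
    # Python `if` on a tensor: this needs the tensor's concrete value.
    if tf.equal(predicted_id, END_ID):
        return True
    return False


# Eager mode (the TF 2.x default): the tensor holds a value, so `if` works.
print(tf.executing_eagerly())                # True
eager_result = is_end_token(tf.constant(2))  # True

# Graph mode without AutoGraph: the tensor is symbolic, so `if` is rejected.
graph_fn = tf.function(is_end_token, autograph=False)
try:
    graph_fn(tf.constant(2))
    graph_error = None
except TypeError as err:  # OperatorNotAllowedInGraphError subclasses TypeError
    graph_error = type(err).__name__

print(eager_result, graph_error)  # True OperatorNotAllowedInGraphError
```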
Alternatively, you can use break in graph mode without writing tf.cond if you use tf.function or tf.autograph (AutoGraph converts a tensor-dependent if and break into graph control flow), but they place some restrictions on the code you can run.
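A rough sketch of the AutoGraph route, assuming TF 2.x (first_index_of is a made-up example, not code from the notebook): inside a tf.function, a loop over tf.range becomes a tf.while_loop, and a tensor-dependent `if ...: break` is rewritten into graph control flow, so the Python break can stay as written. Typical restrictions are that loop variables must be defined before the loop and keep the same dtype; a variable whose shape changes between iterations (like output in the question) additionally needs a shape invariant, for example via tf.autograph.experimental.set_loop_options.

```python
import tensorflow as tf


@tf.function  # AutoGraph is enabled by default in tf.function
def first_index_of(values, target):
    # Loop variables must exist before the loop, with a fixed dtype.
    index = tf.constant(-1)
    # Iterating over tf.range (a tensor) makes AutoGraph emit a tf.while_loop.
    for i in tf.range(tf.size(values)):
        # A tensor condition: AutoGraph turns this `if` into graph control
        # flow and the `break` into an exit condition of the while loop.
        if tf.equal(values[i], target):
            index = i
            break
    return index


values = tf.constant([5, 7, 9, 7])
print(first_index_of(values, 7).numpy())  # 1
print(first_index_of(values, 4).numpy())  # -1
```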