I am trying to print the result of the final batch inside a tf.function:
import tensorflow as tf

def small_data():
    for i in range(10):
        yield 3, 2

data = tf.data.Dataset.from_generator(small_data, (tf.int32, tf.int32))

def result(data):
    """
    Pseudocode for a model that returns multiple layer outputs.
    :param data: one element of the dataset
    :return: a tuple of layer outputs
    """
    return tf.random.normal(shape=[1, 2]), tf.random.normal(shape=[1, 2]), data[0]

@tf.function
def train(dataset):
    batch_result = None
    for batch in dataset:
        batch_result = result(batch)
    tf.print("Final batch result is", batch_result)

train(dataset=data)
Error
raise ValueError("None values not supported.")
ValueError: None values not supported.
The result function is actually a Keras model that returns layer outputs of different shapes. If I remove the batch_result = None assignment and move the tf.print inside the loop, it prints the result for every batch, but I want to print it only for the last batch. I also don't know in advance how many records will be fed to the loop. I have tried several variations, but none of them worked. How can I achieve this in TensorFlow 2.0?
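Inside a tf.function, AutoGraph turns the for loop over the dataset into a graph loop, and every variable the loop modifies must already be defined before the loop with the same structure and dtypes it will have inside it. None is not a tensor, so it cannot serve as that initial value, which is where "None values not supported" comes from. You have to mimic the expected form of batch_result. This works: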
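@tf.function
def train(dataset):
    # Initialise batch_result with the output for the first element, so it
    # already has the structure, dtypes and shapes used inside the loop
    batch_result = result(next(iter(dataset.take(1))))
    for batch in dataset:
        batch_result = result(batch)
    tf.print("Final batch result is", batch_result)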
A bit hackish, but this might work:
@tf.function
def train(dataset):
    batch_result = result(next(dataset.__iter__()))
    for batch in dataset:
        batch_result = result(batch)
    tf.print("Final batch result is", batch_result)
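If running the model an extra time on the first element is undesirable, the same "mimic the expected form" idea can also be met with placeholder tensors whose structure, dtypes and shapes match what result returns. This is a minimal sketch, assuming the stand-in result function from the question (two float32 outputs of shape [1, 2] plus one int32 scalar); it also passes output_shapes to from_generator so the scalar shape is known statically. With a real Keras model, the placeholder shapes would have to match its actual layer outputs, which is what makes calling result once convenient.

```python
import tensorflow as tf

def small_data():
    for _ in range(10):
        yield 3, 2

# Declaring scalar shapes lets the placeholder below match batch[0] exactly.
data = tf.data.Dataset.from_generator(
    small_data,
    output_types=(tf.int32, tf.int32),
    output_shapes=(tf.TensorShape([]), tf.TensorShape([])),
)

def result(batch):
    # Stand-in for the Keras model from the question.
    return tf.random.normal(shape=[1, 2]), tf.random.normal(shape=[1, 2]), batch[0]

@tf.function
def train(dataset):
    # Placeholders with the same structure, dtypes and shapes as result(...),
    # so AutoGraph has a valid tensor initial value for the loop variable.
    batch_result = (
        tf.zeros([1, 2], dtype=tf.float32),
        tf.zeros([1, 2], dtype=tf.float32),
        tf.constant(0, dtype=tf.int32),
    )
    for batch in dataset:
        batch_result = result(batch)
    # Runs once, after the loop, so only the last batch's result is printed.
    tf.print("Final batch result is", batch_result)
    return batch_result

final = train(data)
```

Calling train(data) prints a single line for the last element; its int32 entry is 3, the first value yielded by small_data.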