Reputation: 651
I am trying to convert some code to the new Dataset API so that I can use a distribution strategy. Below is what I am trying to do:
def dataset_generator():
    while True:
        features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
        yield features, labels

def get_ssf_input_fn():
    def input_fn():
        return tf.data.Dataset.from_generator(
            dataset_generator,
            (tf.float32, tf.float32),
            ([None, config.image_height, config.image_width, config.image_channels],
             [None, 256]))
    return input_fn
The problem is that ex_lib.get_image_batch and ex_lib.get_feature_batch each return a tensor instead of a NumPy array, and I cannot change the code in ex_lib. I also cannot convert the tensors to NumPy arrays here, since I have no access to the sess at this point. With this code, it throws:
`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
Is there a way to let my input_fn return a Dataset instead?
Upvotes: 1
Views: 1005
Reputation: 651
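I was able to work around this problem with the following trick, and its efficiency is OK. Note that the function passed to map must return tensors, so dataset_generator here needs to return (features, labels) directly rather than yield them in a loop.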
tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())
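For completeness, here is a minimal sketch of how that trick might fit into the input_fn from the question. The two get_*_batch functions below are hypothetical stand-ins for ex_lib (I am assuming they build a uint8 image tensor and a float32 feature tensor, as the error message suggests), and the sizes are made-up placeholders for the config values. The cast to float32 is my assumption about what the model expects.

```python
import tensorflow as tf

# Assumed placeholder values standing in for config.* and the batch size.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 8, 8, 3
BATCH_SIZE = 4

def get_image_batch():
    # Hypothetical stand-in for ex_lib.get_image_batch: returns a uint8 tensor.
    ints = tf.random.uniform(
        [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS],
        maxval=256, dtype=tf.int32)
    return tf.cast(ints, tf.uint8)

def get_feature_batch():
    # Hypothetical stand-in for ex_lib.get_feature_batch: returns a float32 tensor.
    return tf.random.uniform([BATCH_SIZE, 256], dtype=tf.float32)

def make_batch(_):
    # Return the tensors directly (no `yield`), so Dataset.map can use them.
    # The dummy input element is ignored.
    images = tf.cast(get_image_batch(), tf.float32)
    labels = get_feature_batch()
    return images, labels

def get_ssf_input_fn():
    def input_fn():
        # Infinite dataset of dummy elements; map swaps each one for a batch
        # produced by the tensor-returning library calls.
        return tf.data.Dataset.from_tensors(0).repeat().map(make_batch)
    return input_fn
```

Because the library calls run inside the map function, their ops become part of the dataset's own graph, so no session is needed to turn the tensors into NumPy arrays first.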
Upvotes: 1