ustcyue

TensorFlow Dataset.from_generator using a generator that yields tensors

I am trying to convert some code to the new tf.data Dataset API so that I can use a distribution strategy. Below is what I am trying to do.

import tensorflow as tf

# ex_lib and config come from my own codebase (not shown); ex_lib cannot be modified.

def dataset_generator():
    while True:
        # Both calls return tf.Tensor objects, not NumPy arrays.
        features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
        yield features, labels

def get_ssf_input_fn():
    def input_fn():
        return tf.data.Dataset.from_generator(
            dataset_generator,
            (tf.float32, tf.float32),
            ([None, config.image_height, config.image_width, config.image_channels],
             [None, 256]))

    return input_fn

The problem is that ex_lib.get_image_batch and ex_lib.get_feature_batch return tensors instead of NumPy arrays, and I cannot change the code in ex_lib. I also cannot convert the tensors to NumPy arrays here, because I have no access to the session (sess) at this point. With this code, it throws:

`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
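For contrast, here is a minimal sketch of what from_generator expects: a Python generator that yields NumPy arrays (or other values convertible to them), which tf.data then turns into tensors itself. The generator name, the 32x32x3 image size and the batch size of 4 are hypothetical stand-ins for ex_lib and config. Since ex_lib hands back graph tensors rather than arrays, this is exactly the route that is not available in the question.

```python
import numpy as np
import tensorflow as tf

IMAGE_SHAPE = (32, 32, 3)  # hypothetical stand-in for the config image size
FEATURE_DIM = 256

def numpy_batch_generator(batch_size=4):
    # Hypothetical generator: yields NumPy arrays, not tf.Tensor objects,
    # which is what from_generator can convert to the declared types.
    rng = np.random.default_rng(0)
    while True:
        images = rng.random((batch_size,) + IMAGE_SHAPE, dtype=np.float32)
        features = rng.random((batch_size, FEATURE_DIM), dtype=np.float32)
        yield images, features

dataset = tf.data.Dataset.from_generator(
    numpy_batch_generator,
    (tf.float32, tf.float32),
    ([None, *IMAGE_SHAPE], [None, FEATURE_DIM]),
)

# The generator is endless, so take() bounds how many batches are read.
for images, features in dataset.take(2):
    print(images.shape, features.shape)
```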

Is there still a way for my input_fn to return a Dataset?

Answers (1)

ustcyue

I was able to work around this problem with the following trick. Its performance is acceptable.

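# map() needs tensors back, so dataset_generator must return (not yield) the (features, labels) pair here
tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())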
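A self-contained sketch of the same trick, written for TensorFlow 2 eager mode (the question itself appears to be TF1, since it mentions a session). get_image_batch and get_feature_batch below are hypothetical stand-ins for the ex_lib functions that return random tensors of an assumed shape. The idea is that map() traces make_batch into a graph function, so the tensor-producing ops become part of the pipeline and run again for every dummy element that from_tensors(0).repeat() emits. The cast to float32 is an optional extra step to match the float32 types declared in the question, since the error shows the images arrive as uint8.

```python
import tensorflow as tf

BATCH_SIZE = 4
IMAGE_SHAPE = (32, 32, 3)  # hypothetical stand-in for the config image size
FEATURE_DIM = 256

def get_image_batch():
    # Hypothetical stand-in for ex_lib.get_image_batch(): returns a uint8 tensor.
    pixels = tf.random.uniform((BATCH_SIZE,) + IMAGE_SHAPE, maxval=256, dtype=tf.int32)
    return tf.cast(pixels, tf.uint8)

def get_feature_batch():
    # Hypothetical stand-in for ex_lib.get_feature_batch(): returns a float32 tensor.
    return tf.random.uniform((BATCH_SIZE, FEATURE_DIM))

def make_batch(_):
    # The dummy element is ignored; map() traces this function into a graph,
    # and the ops it creates execute again for each element.
    images = tf.cast(get_image_batch(), tf.float32)
    return images, get_feature_batch()

def input_fn():
    # from_tensors(0).repeat() is an endless stream of dummy elements;
    # map() replaces each one with a freshly computed (features, labels) pair.
    return tf.data.Dataset.from_tensors(0).repeat().map(make_batch)

for images, features in input_fn().take(2):
    print(images.dtype, images.shape, features.shape)
```

Because the stream never ends, something downstream, such as take() here, has to bound it.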
