Read a TFRecord during batch construction with a single tf.Session

I'm trying to adapt a data generator for Keras' model.fit_generator() method. The goal is to read the image at a given index from a TFRecord file while each batch is being constructed.

So I have my generator object:

# Batch generator whose `generate()` method yields, per batch, the items requested via `returns`.
# NOTE(review): several statements in this excerpt are cut off (signatures, branch bodies, assignments).
class DataGeneratorCustom:

    def __init__(self, ...):

    # Endless generator: every pass of the `while True` loop assembles one batch and yields
    # a list holding the items named in `returns`.
    # NOTE(review): the rest of the signature and most of the docstring are missing from this excerpt.
    def generate(self,
                 returns={'processed_images', 'encoded_labels'},

            The next batch as a tuple of items as defined by the `returns` argument.

        if self.dataset_size == 0:
            raise DatasetError("Cannot generate batches because you did not load a dataset.")

        # Warn if any of the set returns aren't possible.

        if self.labels is None:
            if any([ret in returns for ret in ['original_labels', 'processed_labels', 'encoded_labels', 'matched_anchors', 'evaluation-neutral']]):
                warnings.warn("Since no labels were given, none of 'original_labels', 'processed_labels', 'evaluation-neutral', 'encoded_labels', and 'matched_anchors' " +
                              "are possible returns, but you set `returns = {}`. The impossible returns will be `None`.".format(returns))
        elif label_encoder is None:
            if any([ret in returns for ret in ['encoded_labels', 'matched_anchors']]):
                warnings.warn("Since no label encoder was given, 'encoded_labels' and 'matched_anchors' aren't possible returns, " +
                              "but you set `returns = {}`. The impossible returns will be `None`.".format(returns))
        elif not isinstance(label_encoder, SSDInputEncoder):
            if 'matched_anchors' in returns:
                warnings.warn("`label_encoder` is not an `SSDInputEncoder` object, therefore 'matched_anchors' is not a possible return, " +
                              "but you set `returns = {}`. The impossible returns will be `None`.".format(returns))

        # Do a few preparatory things like maybe shuffling the dataset initially.

        if shuffle:
            objects_to_shuffle = [self.dataset_indices]
            if not (self.filenames is None):
            if not (self.labels is None):
            if not (self.image_ids is None):
            if not (self.eval_neutral is None):
            shuffled_objects = sklearn.utils.shuffle(*objects_to_shuffle)
            # Write the shuffled values back in place so the original list objects are updated.
            for i in range(len(objects_to_shuffle)):
                objects_to_shuffle[i][:] = shuffled_objects[i]

        if degenerate_box_handling == 'remove':
            box_filter = BoxFilter(check_overlap=False,

        # Override the labels formats of all the transformations to make sure they are set correctly.
        if not (self.labels is None):
            for transform in transformations:
                transform.labels_format = self.labels_format

        # Generate mini batches.

        # `current` is the offset of the next batch within the dataset.
        current = 0

        while True:

            batch_X, batch_y = [], []

            # Wrap around to the start once the whole dataset has been consumed.
            if current >= self.dataset_size:
                current = 0

            # Maybe shuffle the dataset if a full pass over the dataset has finished.

                if shuffle:
                    objects_to_shuffle = [self.dataset_indices]
                    if not (self.filenames is None):
                    if not (self.labels is None):
                    if not (self.image_ids is None):
                    if not (self.eval_neutral is None):
                    shuffled_objects = sklearn.utils.shuffle(*objects_to_shuffle)
                    for i in range(len(objects_to_shuffle)):
                        objects_to_shuffle[i][:] = shuffled_objects[i]

            # Get the images, (maybe) image IDs, (maybe) labels, etc. for this batch.

            # We prioritize our options in the following order:
            # 1) If we have the images already loaded in memory, get them from there.
            # 2) Else, if we have an TFRecord dataset, get the images from there.
            # 3) Else, if we have neither of the above, we'll have to load the individual image
            #    files from disk.
            batch_indices = self.dataset_indices[current:current+batch_size]
            if not (self.images is None):
                for i in batch_indices:
                if not (self.filenames is None):
                    batch_filenames = self.filenames[current:current+batch_size]
                    batch_filenames = None
            # elif not (self.hdf5_dataset is None):
            #     for i in batch_indices:
            #         batch_X.append(self.hdf5_dataset['images'][i].reshape(self.hdf5_dataset['image_shapes'][i]))
            elif not (self.tfrecord_dataset is None):
                for i in batch_indices:
                    # One TFRecord lookup per dataset index in the batch.
                    # NOTE(review): the returned image is not appended to batch_X in the lines shown here — confirm against the full source.
                    image, image_shape = self.tfrecord_extract_image(i)
                    # batch_X.append(self.hdf5_dataset['images'][i].reshape(self.hdf5_dataset['image_shapes'][i]))
                if not (self.filenames is None):
                    batch_filenames = self.filenames[current:current+batch_size]
                    batch_filenames = None
                batch_filenames = self.filenames[current:current+batch_size]
                for filename in batch_filenames:
                    with as image:
                        batch_X.append(np.array(image, dtype=np.uint8))

            # Get the labels for this batch (if there are any).
            if not (self.labels is None):
                batch_y = deepcopy(self.labels[current:current+batch_size])
                batch_y = None

            if not (self.eval_neutral is None):
                batch_eval_neutral = self.eval_neutral[current:current+batch_size]
                batch_eval_neutral = None

            # Get the image IDs for this batch (if there are any).
            if not (self.image_ids is None):
                batch_image_ids = self.image_ids[current:current+batch_size]
                batch_image_ids = None

            if 'original_images' in returns:
                batch_original_images = deepcopy(batch_X) # The original, unaltered images
            if 'original_labels' in returns:
                batch_original_labels = deepcopy(batch_y) # The original, unaltered labels

            # Advance the offset for the next batch.
            current += batch_size

            # Maybe perform image transformations.

            batch_items_to_remove = [] # In case we need to remove any images from the batch, store their indices in this list.
            batch_inverse_transforms = []

            for i in range(len(batch_X)):

                if not (self.labels is None):
                    # Convert the labels for this image to an array (in case they aren't already).
                    batch_y[i] = np.array(batch_y[i])
                    # If this image has no ground truth boxes, maybe we don't want to keep it in the batch.
                    if (batch_y[i].size == 0) and not keep_images_without_gt:

                # Apply any image transformations we may have received.
                if transformations:

                    inverse_transforms = []

                    for transform in transformations:

                        if not (self.labels is None):

                            # An inverter is only requested from transforms whose signature has a `return_inverter` parameter.
                            if ('inverse_transform' in returns) and ('return_inverter' in inspect.signature(transform).parameters):
                                batch_X[i], batch_y[i], inverse_transform = transform(batch_X[i], batch_y[i], return_inverter=True)
                                batch_X[i], batch_y[i] = transform(batch_X[i], batch_y[i])

                            if batch_X[i] is None: # In case the transform failed to produce an output image, which is possible for some random transforms.


                            if ('inverse_transform' in returns) and ('return_inverter' in inspect.signature(transform).parameters):
                                batch_X[i], inverse_transform = transform(batch_X[i], return_inverter=True)
                                batch_X[i] = transform(batch_X[i])


                # Check for degenerate boxes in this batch item.

                if not (self.labels is None):

                    # Column positions of the box coordinates, looked up in `self.labels_format`.
                    xmin = self.labels_format['xmin']
                    ymin = self.labels_format['ymin']
                    xmax = self.labels_format['xmax']
                    ymax = self.labels_format['ymax']

                    if np.any(batch_y[i][:,xmax] - batch_y[i][:,xmin] <= 0) or np.any(batch_y[i][:,ymax] - batch_y[i][:,ymin] <= 0):
                        if degenerate_box_handling == 'warn':
                            warnings.warn("Detected degenerate ground truth bounding boxes for batch item {} with bounding boxes {}, ".format(i, batch_y[i]) +
                                          "i.e. bounding boxes where xmax <= xmin and/or ymax <= ymin. " +
                                          "This could mean that your dataset contains degenerate ground truth boxes, or that any image transformations you may apply might " +
                                          "result in degenerate ground truth boxes, or that you are parsing the ground truth in the wrong coordinate format." +
                                          "Degenerate ground truth bounding boxes may lead to NaN errors during the training.")
                        elif degenerate_box_handling == 'remove':
                            batch_y[i] = box_filter(batch_y[i])
                            if (batch_y[i].size == 0) and not keep_images_without_gt:

            # Remove any items we might not want to keep from the batch.

            if batch_items_to_remove:
                # Pop in descending index order so that earlier removals do not shift the indices still to be removed.
                for j in sorted(batch_items_to_remove, reverse=True):
                    # This isn't efficient, but it hopefully shouldn't need to be done often anyway.
                    if batch_inverse_transforms: batch_inverse_transforms.pop(j)
                    if not (self.labels is None): batch_y.pop(j)
                    if not (self.image_ids is None): batch_image_ids.pop(j)
                    if not (self.eval_neutral is None): batch_eval_neutral.pop(j)
                    if 'original_images' in returns: batch_original_images.pop(j)
                    if 'original_labels' in returns and not (self.labels is None): batch_original_labels.pop(j)


            # CAUTION: Converting `batch_X` into an array will result in an empty batch if the images have varying sizes
            #          or varying numbers of channels. At this point, all images must have the same size and the same
            #          number of channels.
            batch_X = np.array(batch_X)
            if (batch_X.size == 0):
                raise DegenerateBatchError("You produced an empty batch. This might be because the images in the batch vary " +
                                           "in their size and/or number of channels. Note that after all transformations " +
                                           "(if any were given) have been applied to all images in the batch, all images " +
                                           "must be homogenous in size along all axes.")

            # If we have a label encoder, encode our labels.

            if not (label_encoder is None or self.labels is None):

                if ('matched_anchors' in returns) and isinstance(label_encoder, SSDInputEncoder):
                    batch_y_encoded, batch_matched_anchors = label_encoder(batch_y, diagnostics=True)
                    batch_y_encoded = label_encoder(batch_y, diagnostics=False)
                    batch_matched_anchors = None

                batch_y_encoded = None
                batch_matched_anchors = None

            # Compose the output.

            # Items are appended in this fixed order, regardless of the order in which `returns` lists them.
            ret = []
            if 'processed_images' in returns: ret.append(batch_X)
            if 'encoded_labels' in returns: ret.append(batch_y_encoded)
            if 'matched_anchors' in returns: ret.append(batch_matched_anchors)
            if 'processed_labels' in returns: ret.append(batch_y)
            if 'filenames' in returns: ret.append(batch_filenames)
            if 'image_ids' in returns: ret.append(batch_image_ids)
            if 'evaluation-neutral' in returns: ret.append(batch_eval_neutral)
            if 'inverse_transform' in returns: ret.append(batch_inverse_transforms)
            if 'original_images' in returns: ret.append(batch_original_images)
            if 'original_labels' in returns: ret.append(batch_original_labels)

            yield ret

    # Returns the decoded image stored at position `index` of `self.tfrecord_dataset`, together with its shape.
    # NOTE(review): the rest of the signature and the session `run` calls are missing from this excerpt.
    def tfrecord_extract_image(self,

        # tf.keras.backend.clear_session()
        # A fresh one-shot iterator is built on every call, so iteration always restarts at the first record.
        iterator = self.tfrecord_dataset.make_one_shot_iterator()
        next_record = iterator.get_next()

        # with tf.Graph().as_default():
        # with tf.keras.backend.get_session() as session:

        # Iterate with a tensorflow-session
        # with self.session.as_default() as default_session:

        # Jump to the record of the index
        # Presumably each loop pass evaluates `next_record` once to skip a record, making the lookup
        # linear in `index` — the actual call is not visible here; confirm.
        if index > 0:
            for i in range(index):
                # K.get_session().run(next_record)

        # Extract and return the image
        # image, labels, image_shape, labels_shape, image_id, eval_neutral =
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = K.get_session().run(next_record)
        image, labels, image_shape, labels_shape, image_id, eval_neutral =

        # Decode the fields
        # NOTE(review): `.eval()` is called without an explicit session, so it relies on a default session being active — confirm.
        image_shape = tf.decode_raw(image_shape, tf.int32)
        image_shape = image_shape.eval()
        image = tf.decode_raw(image, tf.uint8)
        image = image.eval()
        # The raw bytes decode to a flat array; restore the stored image shape.
        image = image.reshape(image_shape)

        return image, image_shape

This generator is passed to the model from outside, through fit_generator():

# Train the model on batches drawn from `train_generator` (the remaining arguments are cut off in this excerpt).
history = model.fit_generator(generator=train_generator,

The only piece of code that gives me problems is tfrecord_extract_image(). To read a record I need a tf.Session(), and indeed, by using the `with` keyword on a tf.Session(), I can read the TFRecord:

# Variant of tfrecord_extract_image that opens its own tf.Session on every call and returns the
# decoded image at position `index` together with its shape.
# NOTE(review): the rest of the signature and the session `run` calls are missing from this excerpt.
def tfrecord_extract_image(self,

    # tf.keras.backend.clear_session()
    # tf.keras.backend.clear_session()
    # A fresh one-shot iterator is built on every call, so iteration always restarts at the first record.
    iterator = self.tfrecord_dataset.make_one_shot_iterator()
    next_record = iterator.get_next()

    # A new session is created here on every call and closed when the `with` block exits.
    with tf.Session() as session:

        # Jump to the record of the index
        # Presumably each loop pass evaluates `next_record` once to skip a record — the call is not visible here; confirm.
        if index > 0:
            for i in range(index):
                # K.get_session().run(next_record)

        # Extract and return the image
        # image, labels, image_shape, labels_shape, image_id, eval_neutral =
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = K.get_session().run(next_record)
        image, labels, image_shape, labels_shape, image_id, eval_neutral =

        # Decode the fields
        # `.eval()` is called without an explicit session; inside this `with` block `session` is presumably the default one — confirm.
        image_shape = tf.decode_raw(image_shape, tf.int32)
        image_shape = image_shape.eval()
        image = tf.decode_raw(image, tf.uint8)
        image = image.eval()
        # The raw bytes decode to a flat array; restore the stored image shape.
        image = image.reshape(image_shape)

        return image, image_shape

So for every lookup I open a new session, but this causes a lot of problems when I run on Google ML: the cloud machine is forced to create a new GPU device instance at every batch step:

12/100 [==>...........................] - ETA: 9:46 - loss: 32.8772  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

13/100 [==>...........................] - ETA: 9:21 - loss: 32.8790  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

 14/100 [===>..........................] - ETA: 9:25 - loss: 32.5690  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  0:   N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

 15/100 [===>..........................] - ETA: 9:00 - loss: 32.9770  master-replica-0

So I tried to:

In every case I received the error:

Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=() dtype=string> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(), dtype=string) is not an element of this graph.)

How can I use a single tf.Session for my batch generator?

The problem was the sessions. Opening multiple sessions during training forces the system to instantiate a new GPU device each time.

A solution is to move the `with tf.Session() as session:` block so that it is opened once, before the `while True:` loop.

