vmmgame
vmmgame

Reputation: 41

Pascal VOC dataset loading and YOLO implementation error

I am trying to implement the original YOLO model from scratch, following the 2016 research paper. I have built the loss function as a subclass of keras.losses.Loss. However, I keep getting this error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[98], line 3
      1 # Example test case
      2 # Generate example input and ground truth using the generator
----> 3 for image, bboxes, labels in train_df.take(1):  # Take one batch
      4     # Convert bboxes and labels to the 7x7x30 format
      5     y_true = convert_to_yolo_format(bboxes, labels)
      6     break  # Get only one sample

File ~/Desktop/env/lib/python3.12/site-packages/tensorflow/python/data/ops/iterator_ops.py:826, in OwnedIterator.__next__(self)
    824 def __next__(self):
    825   try:
--> 826     return self._next_internal()
    827   except errors.OutOfRangeError:
    828     raise StopIteration

File ~/Desktop/env/lib/python3.12/site-packages/tensorflow/python/data/ops/iterator_ops.py:776, in OwnedIterator._next_internal(self)
    773 # TODO(b/77291417): This runs in sync mode as iterators use an error status
    774 # to communicate that there is no more data to iterate over.
    775 with context.execution_mode(context.SYNC):
--> 776   ret = gen_dataset_ops.iterator_get_next(
    777       self._iterator_resource,
    778       output_types=self._flat_output_types,
    779       output_shapes=self._flat_output_shapes)
    781   try:
    782     # Fast path for the case `self._structure` is not a nested structure.
    783     return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access

File ~/Desktop/env/lib/python3.12/site-packages/tensorflow/python/ops/gen_dataset_ops.py:3086, in iterator_get_next(iterator, output_types, output_shapes, name)
   3084   return _result
   3085 except _core._NotOkStatusException as e:
-> 3086   _ops.raise_from_not_ok_status(e, name)
   3087 except _core._FallbackException:
   3088   pass

File ~/Desktop/env/lib/python3.12/site-packages/tensorflow/python/framework/ops.py:6002, in raise_from_not_ok_status(e, name)
   6000 def raise_from_not_ok_status(e, name) -> NoReturn:
   6001   e.message += (" name: " + str(name if name is not None else ""))
-> 6002   raise core._status_to_exception(e) from None

InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_3_device_/job:localhost/replica:0/task:0/device:CPU:0}} Cannot batch tensors with different shapes in component 1. First element had shape [2,4] and element 1 had shape [1,4]. [Op:IteratorGetNext] name:

When trying to run this code:

# Reproduction script: pull one batch from the dataset, build a YOLO target
# from it, and evaluate the custom loss against a random prediction.
# NOTE: according to the traceback, the InvalidArgumentError is raised by the
# `for` statement below while the dataset assembles its first batch, i.e.
# before convert_to_yolo_format or YOLOLoss are ever called.
for image, bboxes, labels in train_df.take(1):  # Take one batch
    # Convert bboxes and labels to the 7x7x30 format
    y_true = convert_to_yolo_format(bboxes, labels)
    break  # Get only one sample

# Random stand-in prediction with shape (1, 7, 7, 30); values are in [0, 1)
y_pred = np.random.rand(1, 7, 7, 30)

# Instantiate YOLOLoss
yolo_loss = YOLOLoss()

# Invoke the loss by calling the instance directly
total_loss = yolo_loss(y_true, y_pred)  # You can directly call the instance
print("Total Loss:", total_loss)

I believe that this is due to the different number of bounding boxes per image. I am loading the data using tf.data.Dataset.from_generator, and here is my generator function:

def pascal_voc_generator(image_dir, annotation_dir, image_set_file):
    """Yield ``(image, bboxes, labels)`` for every image id in a Pascal VOC split file.

    Args:
        image_dir: Directory containing ``<image_id>.jpg`` files. May be a
            ``str``, ``bytes`` or ``os.PathLike``; string arguments forwarded
            through ``tf.data.Dataset.from_generator(args=...)`` arrive as
            ``bytes``.
        annotation_dir: Directory containing ``<image_id>.xml`` annotation
            files. Same accepted types as ``image_dir``.
        image_set_file: Text file listing one image id per line. Blank lines
            are ignored.

    Yields:
        A 3-tuple ``(image, bboxes, labels)``:

        * ``image`` - the array produced by ``img_to_array`` for the image
          loaded with ``target_size=(448, 448)``, divided by 255.0.
        * ``bboxes`` - a list with one ``[center_x, center_y, width, height]``
          entry per ``<object>`` element in the annotation.
        * ``labels`` - a list with one ``label_map`` value per ``<object>``.

    Notes:
        ``bboxes`` and ``labels`` have one entry per annotated object, so
        their length differs from image to image. ``Dataset.batch`` requires
        equal shapes within a batch; variable-length components need
        ``Dataset.padded_batch`` (or ragged batching) in the pipeline that
        consumes this generator.

        Ids whose image or annotation file is missing are reported and
        skipped. Any error while loading or parsing one sample is printed and
        that sample is skipped, so a single bad file does not stop iteration.
    """
    # os.fsdecode handles str, bytes and PathLike uniformly. The previous
    # str(...).replace("b'", "") approach stripped every apostrophe and thus
    # corrupted legitimate paths that contain one.
    image_dir = os.fsdecode(image_dir)
    annotation_dir = os.fsdecode(annotation_dir)

    with open(image_set_file, 'r') as f:
        ## Get list of image ids for the split (skip empty lines)
        image_ids = [line.strip() for line in f if line.strip()]

    for image_id in image_ids:
        ## For each id, get the corresponding image and annotation file
        image_path = os.path.normpath(
            os.path.join(image_dir, f"{image_id}.jpg"))
        annotation_path = os.path.normpath(
            os.path.join(annotation_dir, f"{image_id}.xml"))

        if not os.path.exists(image_path):
            print(f"Image file not found: {image_path}")
            continue

        if not os.path.exists(annotation_path):
            print(f"Annotation file not found: {annotation_path}")
            continue

        # Only loading/parsing is guarded; the yield lives in the `else`
        # branch so an exception raised by the consumer is not misreported
        # as a data error for this image.
        try:
            ## Load the image
            image = tf.keras.preprocessing.image.load_img(image_path, target_size=(448,448), keep_aspect_ratio=True)
            image = tf.keras.preprocessing.image.img_to_array(image)
            image = image / 255.0 ## Normalize

            ## Parse the XML file
            root = ET.parse(annotation_path).getroot()
            bboxes = []
            labels = []
            for obj in root.findall('object'):
                bbox = obj.find('bndbox')
                xmin = int(bbox.find('xmin').text)
                ymin = int(bbox.find('ymin').text)
                xmax = int(bbox.find('xmax').text)
                ymax = int(bbox.find('ymax').text)

                # NOTE(review): these values are taken straight from the
                # annotation file, while the image above is resized to
                # 448x448. Confirm that a later step rescales the boxes.
                center_x, center_y = get_center_coords([xmin, ymin, xmax, ymax])
                width = xmax - xmin
                height = ymax - ymin

                ## Convert label to numeric value
                label_name = obj.find('name').text.lower()
                label = label_map[label_name]

                bboxes.append([center_x, center_y, width, height])
                labels.append(label)
        except Exception as err:
            print(f'''Error ocurred:
                Image id: {image_id}
                Error Message: {err}
            ''')
        else:
            yield image, bboxes, labels

I have looked for solutions on Stack Overflow and other online forums but couldn't find anything helpful. I also tried using ChatGPT, but that just caused more issues. If you have any suggestions, please let me know.

Upvotes: 0

Views: 25

Answers (0)

Related Questions