Kshitij Vijayvergiya

Reputation: 11

TensorFlow Object Detection API: Print detected class as output

I am using the TF Object Detection API to detect objects in images, and it works fine: given an image, it draws the bounding boxes with a label and a confidence score. My question is how to print the detected class (as a string), i.e. not just on the image but also as output to the terminal.

This is the code for real-time detection:

import cv2
import numpy as np
import tensorflow as tf
from object_detection.utils import visualization_utils as viz_utils

# detect_fn (the loaded detection model) and category_index (built from the
# label map) are assumed to be defined earlier, as in the API tutorials.

cap = cv2.VideoCapture(0)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

while True:
    ret, frame = cap.read()
    if not ret:
        # No frame could be read from the camera.
        break
    image_np = np.array(frame)

    # Add a batch dimension and run the detector.
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0),
                                        dtype=tf.float32)
    detections = detect_fn(input_tensor)

    # Keep only the valid detections and drop the batch dimension.
    num_detections = int(detections.pop('num_detections'))
    detections = {key: value[0, :num_detections].numpy()
                  for key, value in detections.items()}
    detections['num_detections'] = num_detections

    # detection_classes should be ints.
    detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

    # Assumes detect_fn returns 0-based class ids while label map ids start at 1.
    label_id_offset = 1
    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections['detection_boxes'],
        detections['detection_classes'] + label_id_offset,
        detections['detection_scores'],
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=3,
        min_score_thresh=.5,
        agnostic_mode=False)

    cv2.imshow('object detection', cv2.resize(image_np_with_detections, (800, 600)))

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
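After the post-processing step in the loop, every entry of detections is a flat array covering only the valid detections, so detections['detection_classes'][i] is already the class id of detection i and needs no [0] or .numpy(). The sketch below reproduces that step on made-up NumPy stand-ins to show the resulting shapes; the 100-slot padding is an assumption about typical model-zoo detectors, not something the code above guarantees.

```python
import numpy as np

# Stand-ins for the raw detect_fn output (NumPy instead of tf.Tensor, so
# .numpy() is omitted). Every array has a batch dimension of 1 and is padded
# to a fixed number of slots; 100 slots is an assumption.
max_slots = 100
raw = {
    'detection_boxes': np.zeros((1, max_slots, 4), dtype=np.float32),
    'detection_classes': np.zeros((1, max_slots), dtype=np.float32),
    'detection_scores': np.zeros((1, max_slots), dtype=np.float32),
    'num_detections': np.array([3.0], dtype=np.float32),
}

# Essentially the loop's post-processing: read the count from the single
# batch entry, then drop the batch dimension and the padded slots.
num_detections = int(raw.pop('num_detections')[0])
detections = {key: value[0, :num_detections] for key, value in raw.items()}
detections['num_detections'] = num_detections
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

print(detections['detection_boxes'].shape)    # (3, 4)
print(detections['detection_classes'].shape)  # (3,)
```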

Upvotes: 1

Views: 1203

Answers (2)

variable

Reputation: 9736

Code:

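# Run this right after detections = detect_fn(input_tensor), before the
# question's post-processing replaces detections with trimmed NumPy arrays:
# the raw output still has a batch dimension and holds tf.Tensor values.
my_classes = detections['detection_classes'][0].numpy().astype(np.int64) + label_id_offset
my_scores = detections['detection_scores'][0].numpy()

min_score = 0.5

print([category_index[value]['name']
       for index, value in enumerate(my_classes)
       if my_scores[index] > min_score])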

Sample output:

['person', 'cell phone', 'remote']
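If you would rather print the names inside the question's loop after its post-processing, the arrays are already flat NumPy arrays, so [0] and .numpy() are not needed. Below is a minimal sketch of a helper (detected_class_names is a hypothetical name) whose defaults follow the max_boxes_to_draw=3 and min_score_thresh=.5 values passed to visualize_boxes_and_labels_on_image_array, so the terminal output should roughly match the drawn labels. It assumes the same label_id_offset of 1 and that category_index maps label-map ids to dicts with a 'name' key; the label map and detections in the example are made up.

```python
import numpy as np


def detected_class_names(detections, category_index, min_score=0.5,
                         max_boxes=3, label_id_offset=1):
    """Return the names of detections whose score exceeds min_score.

    detections is the post-processed dict from the question's loop (flat
    NumPy arrays, no batch dimension). At most max_boxes names are returned,
    taken in the order the model output them.
    """
    classes = detections['detection_classes'] + label_id_offset
    scores = detections['detection_scores']
    names = []
    for class_id, score in zip(classes, scores):
        if len(names) == max_boxes:
            break
        if score > min_score:
            # Fall back to 'N/A' for ids missing from the label map.
            names.append(category_index.get(int(class_id), {}).get('name', 'N/A'))
    return names


# Hypothetical label map and detections, for illustration only.
category_index = {1: {'id': 1, 'name': 'person'},
                  77: {'id': 77, 'name': 'cell phone'}}
detections = {
    'detection_classes': np.array([0, 76, 0, 76], dtype=np.int64),
    'detection_scores': np.array([0.91, 0.72, 0.40, 0.30], dtype=np.float32),
}
print(detected_class_names(detections, category_index))  # ['person', 'cell phone']
```

In the loop itself, print(detected_class_names(detections, category_index)) can go right after the visualization call.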

Upvotes: 0

user11530462


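The class names are stored in the category_index dictionary. Use the snippet below to print the names of the detected classes; it assumes classes and scores are the detection arrays with a leading batch dimension: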

print([category_index[int(value)]['name'] for index, value in enumerate(classes[0]) if scores[0, index] > 0.5])
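For reference, category_index maps each label-map id to a small dict with 'id' and 'name' keys, so looking up an id gives that dict rather than a plain string. In the API it is usually built with label_map_util.create_category_index_from_labelmap. A minimal sketch that prints each kept detection together with its score, assuming classes and scores have the leading batch dimension used in the snippet above; the label map, the values, and the format_detections name are all hypothetical.

```python
import numpy as np

# Hypothetical category_index in the shape the label-map utilities produce:
# label-map id -> {'id': ..., 'name': ...}.
category_index = {1: {'id': 1, 'name': 'person'},
                  3: {'id': 3, 'name': 'car'}}

# Hypothetical batched outputs (batch dimension of 1).
classes = np.array([[3.0, 1.0, 1.0]], dtype=np.float32)
scores = np.array([[0.88, 0.61, 0.20]], dtype=np.float32)


def format_detections(classes, scores, category_index, min_score=0.5):
    """Return 'name: score' strings for detections above min_score."""
    lines = []
    for class_id, score in zip(classes[0], scores[0]):
        if score > min_score:
            # Class ids may come back as floats, so cast before the lookup.
            entry = category_index.get(int(class_id))
            name = entry['name'] if entry else 'unknown'
            lines.append(f"{name}: {score:.2f}")
    return lines


for line in format_detections(classes, scores, category_index):
    print(line)
```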

Upvotes: 1
