
Getting generic object predictions from the FasterRCNN architecture using Detectron2

Before it became a class, it was an object.


Introduction


With the widespread availability of training data and computational resources, computer vision and object detection have become increasingly accessible. However, specific real-world use cases such as food detection require new classes all too often. Isn’t it annoying for your food to get cold while you wait for the data gathering process to finish?


In this blog post we present an extension to Facebook’s Detectron2 framework to detect new, unseen objects, thus speeding up the data gathering process. We illustrate this below, using the detection output of a model pre-trained on the COCO dataset that has not yet seen either this type of bun or Bounty chocolate bars.


Detection example for model pretrained on COCO — the model has not been trained on buns or Bounty chocolate bars.
 

Context and Motivation

Finding a robust definition for the ideal object detection dataset is difficult, as the use cases and success metrics can vary, e.g. accuracy vs. generalisation. Typically, such a dataset has a balanced class distribution, enough images per class and high-quality annotations. However, assembling such datasets is an even more daunting task.

Data doesn’t come for free and it costs both time and money.

For example, adding a new class to an already curated dataset means ensuring we have enough annotations and class instances, as well as preserving class balance, so that performance is not jeopardised. Thus, both acquisition and labelling present a bottleneck.

Google’s AI Data Labelling Service charges $25 for 1,000 classification annotations and $49 for 1,000 bounding box annotations. For a robust dataset of 10,000 annotations, already having the object localisation translates into a cost saving of $240 (10,000 annotations × ($49 − $25) per 1,000).

Our solution can help speed up this process considerably by reducing the number of actions needed to obtain object bounding boxes, and by improving the on-boarding process of a new class in the detection pipeline.


 

Implementation

Detectron2 is Facebook’s AI Research framework for implementing Computer Vision algorithms. Designed to switch between tasks with ease, going from object detection to semantic segmentation or keypoint detection with a small change in a config file, Detectron2 offers state-of-the-art implementations of algorithms such as FasterRCNN and RetinaNet. As described in this great explanation of the inner workings of Detectron2, the FasterRCNN-FPN model is composed of the following building blocks:

  • Backbone network

  • Region proposal network

  • ROI Heads (Box Head)

FasterRCNN computes a score for each RPN region, which defines the confidence of an object being present in that region. The regions with the best objectness scores are classified and turned into class predictions if the class score exceeds the class confidence threshold. Predictions that fall below the class confidence threshold but remain above the objectness score threshold are what we refer to as generic object predictions.
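
As a rough illustration of this decision rule (the function name and threshold values below are illustrative, not part of the Detectron2 API):

def categorise_prediction(class_score: float, objectness_logit: float,
                          class_score_thresh: float = 0.5,
                          objectness_thresh: float = 4.5) -> str:
    # Hypothetical sketch of the rule described above
    if class_score >= class_score_thresh:
        return "class prediction"            # confident about a specific class
    if objectness_logit >= objectness_thresh:
        return "generic object prediction"   # an object is present, class unknown
    return "discarded"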

Our proposed solution extracts the objectness score of predictions from the head of the model. We use these scores to get bounding box predictions for classes unseen by the model in its training set. A prerequisite is that the new object representations should live within the same representation space as the trained classes, i.e. the solution can find an apple’s location in an image with a model trained on a vegetable dataset that has seen tomatoes and potatoes. This gets us one click away from obtaining the image detection annotations, skipping the process of drawing a bounding box.

To achieve our goal, we have to modify ROI Heads so that it outputs generic predictions. In Detectron2, ROI Heads is represented by the StandardROIHeads class, which contains the FastRCNNOutputLayers class that predicts bounding boxes and classification scores based on the region proposals.

Region proposals come from the region proposal network, and each carries an objectness logit that represents how likely an object is to be there; within StandardROIHeads they are pooled and passed through FastRCNNConvFCHead before reaching the output layers. Because we want to obtain the generic predictions based on these objectness scores, we create our own class GenericFastRCNNOutputLayers that subclasses FastRCNNOutputLayers (a sketch of how the pieces fit together follows the two snippets below). The roles of this class are to:

  • Obtain class predictions based on the detection scores, filter them using an established score threshold, and then apply NMS.


def _get_class_predictions(self, boxes, scores, image_shape):
    num_bbox_reg_classes = boxes.shape[1] // 4
    # Convert to Boxes to use the `clip` function ...
    boxes = Boxes(boxes.reshape(-1, 4))
    boxes.clip(image_shape)
    boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4)  # R x C x 4

    # Filter results based on detection scores
    filter_mask = scores > self.class_score_thresh_test
    # R' x 2. First column contains indices of the R predictions;
    # second column contains indices of classes.
    class_inds = filter_mask.nonzero()
    if num_bbox_reg_classes == 1:
        boxes = boxes[class_inds[:, 0], 0]
    else:
        boxes = boxes[filter_mask]
    scores = scores[filter_mask]

    # Apply per-class NMS
    keep_class = batched_nms(boxes, scores, class_inds[:, 1],
                             self.class_nms_thresh_test)
    if self.topk_per_image_test >= 0:
        keep_class = keep_class[:self.topk_per_image_test]
    boxes, scores, class_inds = (boxes[keep_class], scores[keep_class],
                                 class_inds[keep_class])

    return boxes, scores, class_inds

  • Obtain generic predictions based on the objectness logits, filter them by an objectness score threshold, discard the results that overlap with class predictions, then apply NMS on the final generic predictions. Usually, objectness logit values above 3 indicate the presence of an object in the image.


def _get_generic_predictions(
        self, proposals: Instances, class_boxes: torch.FloatTensor,
        class_scores: torch.FloatTensor, class_inds: torch.FloatTensor,
        generic_idx: int
) -> (torch.FloatTensor, torch.FloatTensor, torch.IntTensor):
    # Per object
    objectness = proposals.objectness_logits.reshape(
        (proposals.objectness_logits.shape[0], 1))
    obj_boxes = proposals.proposal_boxes.tensor

    # Filter by objectness threshold
    filter_object_mask = objectness > self.objectness_score_thresh_test
    filter_obj_inds = filter_object_mask.nonzero()
    obj_boxes = obj_boxes[filter_obj_inds[:, 0]]

    # Filter generic objects that overlap with class predictions
    generic_mask = self._find_generic_objects_suppression_mask(
        class_boxes, obj_boxes, self.objectness_nms_thresh_test)
    objectness = objectness[filter_object_mask]
    generic_boxes = obj_boxes[generic_mask]
    generic_inds = filter_obj_inds[generic_mask]
    generic_scores = objectness[generic_mask]

    # Attribute generic id to selected predictions
    generic_inds[:, 1] = generic_idx

    # Apply NMS to generic predictions
    nms_filtered = batched_nms(generic_boxes, generic_scores,
                               generic_inds[:, 1],
                               self.objectness_nms_thresh_test)
    generic_boxes = generic_boxes[nms_filtered]
    generic_inds = generic_inds[nms_filtered]
    generic_scores = generic_scores[nms_filtered]

    # Keep top detections - detected classes have priority
    if self.topk_per_image_test >= 0:
        remaining_objects = self.topk_per_image_test - len(class_boxes)
        # Sort by descending objectness so the highest-scoring generic
        # predictions are the ones we keep
        sorted_generic = torch.argsort(generic_scores, descending=True)
        sorted_generic = sorted_generic[:remaining_objects]
        generic_boxes = generic_boxes[sorted_generic]
        generic_inds = generic_inds[sorted_generic]
        generic_scores = generic_scores[sorted_generic]

    return generic_boxes, generic_scores, generic_inds
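
For context, here is a rough sketch of how these two helpers could be wired together inside GenericFastRCNNOutputLayers for a single image. The exact wiring lives in our fork; the method name, the way the generic class index is chosen and the merging step below are illustrative rather than the fork's actual code.

import torch
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.structures import Boxes, Instances


class GenericFastRCNNOutputLayers(FastRCNNOutputLayers):
    def _inference_single_image(self, boxes, scores, proposals, image_shape,
                                generic_idx):
        # Regular class predictions: score threshold, then per-class NMS
        class_boxes, class_scores, class_inds = self._get_class_predictions(
            boxes, scores, image_shape)

        # Generic predictions: objectness threshold, suppression against the
        # class predictions, then NMS
        generic_boxes, generic_scores, generic_inds = (
            self._get_generic_predictions(proposals, class_boxes,
                                          class_scores, class_inds,
                                          generic_idx))

        # Merge both kinds of detections into a single Instances object
        result = Instances(image_shape)
        result.pred_boxes = Boxes(torch.cat([class_boxes, generic_boxes]))
        result.scores = torch.cat([class_scores, generic_scores])
        result.pred_classes = torch.cat(
            [class_inds[:, 1], generic_inds[:, 1]])
        return result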

The parameters we have added to the config of our generic object detection model are:

cfg.MODEL.ROI_HEADS.OBJECTNESS_NMS_THRESH_TEST
  • the NMS IoU cutoff above which we discard generic object predictions that overlap with other generic object detections; of the overlapping predictions, we keep the one with the highest score. We also use this threshold to suppress generic object predictions that overlap with class instance predictions. Class instance predictions take priority over generic object predictions since they are more specific, so we remove any generic object prediction whose IoU with a class instance exceeds the threshold.

cfg.MODEL.ROI_HEADS.OBJECTNESS_SCORE_THRESH_TEST
  • used as a confidence threshold for objectness values returned by the network. It is similar in functionality to SCORE_THRESH_TEST but for objectness values rather than class prediction probabilities.
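
For reference, a minimal sketch of setting these thresholds in code (assuming our fork, where GenericROIHeads reads these keys; the values mirror the demo defaults described below):

from detectron2.config import get_cfg

# Build the generic-detection config and set the new thresholds introduced
# by the fork.
cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Generic-Detection/faster_rcnn_R_50_FPN_1x.yaml")
cfg.MODEL.ROI_HEADS.OBJECTNESS_SCORE_THRESH_TEST = 4.5  # objectness confidence cutoff
cfg.MODEL.ROI_HEADS.OBJECTNESS_NMS_THRESH_TEST = 0.5    # IoU cutoff for suppression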

 

Demo


We have provided a fork of the Detectron2 repo with an implementation of the generic object detection solution. The Detectron2 demo is described more in depth in the repo’s documentation.


Detectron2 comes with a set of predefined configs that allow model customisation. Here is the config we used to test the generic object detection:


_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
  WEIGHTS: "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_1x/137257794/model_final_b275ba.pkl"
  MASK_ON: False
  RESNETS:
    DEPTH: 50
  DEVICE: "cpu"
  ROI_HEADS:
    NAME: "GenericROIHeads"

We take the base RCNN with Feature Pyramid Network config and use a model with a ResNet-50 backbone pre-trained on the COCO dataset. The change that affects the objectness detections is ROI_HEADS.NAME, which points to our GenericROIHeads class.


To run the demo, we can use the following command:


python demo/demo.py --input list_of_image_paths --config-file configs/COCO-Generic-Detection/faster_rcnn_R_50_FPN_1x.yaml

The parameters to the demo that we have introduced are:

--objectness-score-threshold
  • which defaults to 4.5 (explained in the implementation section)

--objectness-nms
  • which defaults to 0.5 (explained in the implementation section)
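
If you would rather run the model from Python than through demo.py, a minimal sketch using Detectron2’s standard DefaultPredictor could look like the following (the image path is purely illustrative, and the fork is assumed to be installed so that GenericROIHeads can be built):

import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

# Build the generic-detection config and run inference on a single image.
cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Generic-Detection/faster_rcnn_R_50_FPN_1x.yaml")
predictor = DefaultPredictor(cfg)

image = cv2.imread("path/to/image.jpg")  # BGR image, as Detectron2 expects
outputs = predictor(image)
instances = outputs["instances"].to("cpu")
print(instances.pred_boxes)    # class and generic bounding boxes
print(instances.scores)        # class scores and objectness scores
print(instances.pred_classes)  # class ids, plus the generic object id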


Below, we present another example of the model’s predictions using our demo. None of the objects in the image are represented in the training set. The COCO dataset contains objects of various sizes, such as dining tables, which makes our model predict objects that may not be of interest to us (e.g. the largest bounding box in the image below).


Example predictions from our demo; none of these objects appear in the training set, and the largest bounding box covers a dining-table-sized region that may not be of interest.
 

Conclusion


Generic object predictions are a powerful tool: they make better use of the representational power of deep CNNs and let a model scale to objects it was never trained on. They reduce the time and resources required to define bounding boxes for new object instances.


We have shown that obtaining a generic object detector is straightforward using existing tools (FasterRCNN in Detectron2), without adding extra layers of complexity or increasing inference time. The latter is critical when deploying the model.

What other use cases in Computer Vision do you see where generic predictions can help? Leave a comment with your thoughts.


In our next posts, we will look at applications of generic object detectors and see how they can be turned into valuable class-specific annotations without any human intervention.

 

Written by Daniela Palcu and Flaviu Samarghitan, Computer Vision Engineers at Neurolabs


At Neurolabs, we believe that the lack of widespread adoption of machine learning is due to a lack of data. We use computer graphics, much like the special effects or video game industry, to produce realistic images at scale. In 5 minutes, you can have 10,000 images tailored for your problem. But we don’t stop there — we kick-start the machine learning training, and allow any industry to implement ready-made Computer Vision algorithms without an army of human annotators.
