Skip to content

Commit

Permalink
CARLA single modality object detection model (#1160)
Browse files Browse the repository at this point in the history
* rename to deconflict from carla multimodality object detection model

* remove duplicate file

Co-authored-by: Sterling Suggs <sterling.suggs@twosixtech.com>
  • Loading branch information
yusong-tan and swsuggs authored Oct 15, 2021
1 parent 8bd82b2 commit 5ac5046
Showing 1 changed file with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
PyTorch Faster-RCNN Resnet50-FPN object detection model
"""
import logging
from typing import Optional

from art.estimators.object_detection import PyTorchFasterRCNN
import torch
from torchvision import models

logger = logging.getLogger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# NOTE: PyTorchFasterRCNN expects numpy input, not torch.Tensor input
def get_art_model(
model_kwargs: dict, wrapper_kwargs: dict, weights_path: Optional[str] = None
) -> PyTorchFasterRCNN:

if weights_path:
assert (
model_kwargs.get("num_classes", None) == 4
), "model trained on CARLA data outputs predictions for 4 classes"
assert not model_kwargs.get(
"pretrained", False
), "model trained on CARLA data should not use COCO-pretrained weights"
else:
assert (
model_kwargs.get("num_classes", None) == 91
), "model without predefined weights should use COCO classes"
assert model_kwargs.get(
"pretrained", False
), "model without predefined weights should use COCO-pretrained weights"

model = models.detection.fasterrcnn_resnet50_fpn(**model_kwargs)
model.to(DEVICE)

if weights_path:
checkpoint = torch.load(weights_path, map_location=DEVICE)
model.load_state_dict(checkpoint)

wrapped_model = PyTorchFasterRCNN(
model, clip_values=(0.0, 1.0), channels_first=False, **wrapper_kwargs,
)
return wrapped_model

0 comments on commit 5ac5046

Please sign in to comment.