diff --git a/test/test_utils.py b/test/test_utils.py index bfbacfcf..d1d57542 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -92,27 +92,20 @@ def test_load_from_ultralytics_voc( # Define YOLOv5 model model_yolov5 = load_yolov5_model(checkpoint_path) - model_yolov5.conf = conf # confidence threshold (0-1) - model_yolov5.iou = iou # NMS IoU threshold (0-1) - model_yolov5.eval() with torch.no_grad(): outs = model_yolov5(img[None])[0] outs = non_max_suppression(outs, conf, iou, agnostic=True) - out_from_yolov5 = outs[0] + out_yolov5 = outs[0] # Define yolort model - model_yolort = YOLO.load_from_yolov5( - checkpoint_path, - score_thresh=conf, - version=version, - ) + model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version) model_yolort.eval() with torch.no_grad(): - out_from_yolort = model_yolort(img[None]) + out_yolort = model_yolort(img[None]) - torch.testing.assert_allclose(out_from_yolort[0]["boxes"], out_from_yolov5[:, :4]) - torch.testing.assert_allclose(out_from_yolort[0]["scores"], out_from_yolov5[:, 4]) - torch.testing.assert_allclose(out_from_yolort[0]["labels"], out_from_yolov5[:, 5].to(dtype=torch.int64)) + torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4]) + torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4]) + torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64)) def test_read_image_to_tensor(): diff --git a/test/test_v5.py b/test/test_v5.py index da2aaa43..e193a776 100644 --- a/test/test_v5.py +++ b/test/test_v5.py @@ -2,7 +2,7 @@ import hashlib from torch import Tensor -from yolort.v5 import load_yolov5_model, attempt_download +from yolort.v5 import AutoShape, attempt_download, load_yolov5_model def test_attempt_download(): @@ -15,16 +15,19 @@ def test_attempt_download(): assert readable_hash[:8] == "9ca9a642" -def test_load_yolov5_model(): +def test_load_yolov5_model_autoshape_attached(): img_path = "test/assets/zidane.jpg" - model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt" - checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642") + model_url = "https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt" + checkpoint_path = attempt_download(model_url, hash_prefix="8b3b748c") + + model = load_yolov5_model(checkpoint_path) + # Attach AutoShape + model = AutoShape(model) - model = load_yolov5_model(checkpoint_path, autoshape=True, verbose=False) results = model(img_path) assert isinstance(results.pred, list) assert len(results.pred) == 1 assert isinstance(results.pred[0], Tensor) - assert results.pred[0].shape == (3, 6) + assert results.pred[0].shape == (4, 6) diff --git a/yolort/v5/helper.py b/yolort/v5/helper.py index 2db16469..2677e6d7 100644 --- a/yolort/v5/helper.py +++ b/yolort/v5/helper.py @@ -4,10 +4,10 @@ from pathlib import Path import torch +from torch import nn -from .models import AutoShape -from .models.yolo import Model -from .utils import attempt_download, intersect_dicts, set_logging +from .models.yolo import Detect, Model +from .utils import attempt_download __all__ = ["add_yolov5_context", "load_yolov5_model", "get_yolov5_size"] @@ -46,32 +46,36 @@ def get_yolov5_size(depth_multiple, width_multiple): ) -def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bool = True): +def load_yolov5_model(checkpoint_path: str, fuse: bool = False): """ - Creates a specified YOLOv5 model + Creates a specified YOLOv5 model. + + Note: + Currently this tool is mainly used to load the checkpoints trained by yolov5 + with support for versions v3.1, v4.0 (v5.0) and v6.0 (v6.1). In addition it is + available for inference with AutoShape attached for versions v6.0 (v6.1). Args: checkpoint_path (str): path of the YOLOv5 model, i.e. 'yolov5s.pt' - autoshape (bool): apply YOLOv5 .autoshape() wrapper to model. Default: False. - verbose (bool): print all information to screen. Default: True. + fuse (bool): fuse model Conv2d() + BatchNorm2d() layers. Default: False Returns: YOLOv5 pytorch model """ - set_logging(verbose=verbose) with add_yolov5_context(): ckpt = torch.load(attempt_download(checkpoint_path), map_location=torch.device("cpu")) - - if isinstance(ckpt, dict): - model_ckpt = ckpt["model"] # load model - - model = Model(model_ckpt.yaml) # create model - ckpt_state_dict = model_ckpt.float().state_dict() # checkpoint state_dict as FP32 - ckpt_state_dict = intersect_dicts(ckpt_state_dict, model.state_dict(), exclude=["anchors"]) - model.load_state_dict(ckpt_state_dict, strict=False) - - if autoshape: - model = AutoShape(model) - - return model + if fuse: + model = ckpt["ema" if ckpt.get("ema") else "model"].float().fuse().eval() + else: # without layer fuse + model = ckpt["ema" if ckpt.get("ema") else "model"].float().eval() + + # Compatibility updates + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: + if isinstance(m, Detect): + if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility + delattr(m, "anchor_grid") + setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl) + + return model diff --git a/yolort/v5/models/common.py b/yolort/v5/models/common.py index 1d981b07..6449038f 100644 --- a/yolort/v5/models/common.py +++ b/yolort/v5/models/common.py @@ -68,7 +68,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r4.0"): def forward(self, x: Tensor) -> Tensor: return self.act(self.bn(self.conv(x))) - def fuseforward(self, x): + def forward_fuse(self, x: Tensor) -> Tensor: return self.act(self.conv(x))