Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix YOLOv5 Detect Layer compatibility #345

Merged
merged 5 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,20 @@ def test_load_from_ultralytics_voc(

# Define YOLOv5 model
model_yolov5 = load_yolov5_model(checkpoint_path)
model_yolov5.conf = conf # confidence threshold (0-1)
model_yolov5.iou = iou # NMS IoU threshold (0-1)
model_yolov5.eval()
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou, agnostic=True)
out_from_yolov5 = outs[0]
out_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLO.load_from_yolov5(
checkpoint_path,
score_thresh=conf,
version=version,
)
model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version)
model_yolort.eval()
with torch.no_grad():
out_from_yolort = model_yolort(img[None])
out_yolort = model_yolort(img[None])

torch.testing.assert_allclose(out_from_yolort[0]["boxes"], out_from_yolov5[:, :4])
torch.testing.assert_allclose(out_from_yolort[0]["scores"], out_from_yolov5[:, 4])
torch.testing.assert_allclose(out_from_yolort[0]["labels"], out_from_yolov5[:, 5].to(dtype=torch.int64))
torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4])
torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4])
torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64))


def test_read_image_to_tensor():
Expand Down
15 changes: 9 additions & 6 deletions test/test_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib

from torch import Tensor
from yolort.v5 import load_yolov5_model, attempt_download
from yolort.v5 import AutoShape, attempt_download, load_yolov5_model


def test_attempt_download():
Expand All @@ -15,16 +15,19 @@ def test_attempt_download():
assert readable_hash[:8] == "9ca9a642"


def test_load_yolov5_model():
def test_load_yolov5_model_autoshape_attached():
img_path = "test/assets/zidane.jpg"

model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")
model_url = "https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="8b3b748c")

model = load_yolov5_model(checkpoint_path)
# Attach AutoShape
model = AutoShape(model)

model = load_yolov5_model(checkpoint_path, autoshape=True, verbose=False)
results = model(img_path)

assert isinstance(results.pred, list)
assert len(results.pred) == 1
assert isinstance(results.pred[0], Tensor)
assert results.pred[0].shape == (3, 6)
assert results.pred[0].shape == (4, 6)
46 changes: 25 additions & 21 deletions yolort/v5/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from pathlib import Path

import torch
from torch import nn

from .models import AutoShape
from .models.yolo import Model
from .utils import attempt_download, intersect_dicts, set_logging
from .models.yolo import Detect, Model
from .utils import attempt_download

__all__ = ["add_yolov5_context", "load_yolov5_model", "get_yolov5_size"]

Expand Down Expand Up @@ -46,32 +46,36 @@ def get_yolov5_size(depth_multiple, width_multiple):
)


def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bool = True):
def load_yolov5_model(checkpoint_path: str, fuse: bool = False):
"""
Creates a specified YOLOv5 model
Creates a specified YOLOv5 model.

Note:
Currently this tool is mainly used to load the checkpoints trained by yolov5
with support for versions v3.1, v4.0 (v5.0) and v6.0 (v6.1). In addition it is
available for inference with AutoShape attached for versions v6.0 (v6.1).

Args:
checkpoint_path (str): path of the YOLOv5 model, i.e. 'yolov5s.pt'
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model. Default: False.
verbose (bool): print all information to screen. Default: True.
fuse (bool): fuse model Conv2d() + BatchNorm2d() layers. Default: False

Returns:
YOLOv5 pytorch model
"""
set_logging(verbose=verbose)

with add_yolov5_context():
ckpt = torch.load(attempt_download(checkpoint_path), map_location=torch.device("cpu"))

if isinstance(ckpt, dict):
model_ckpt = ckpt["model"] # load model

model = Model(model_ckpt.yaml) # create model
ckpt_state_dict = model_ckpt.float().state_dict() # checkpoint state_dict as FP32
ckpt_state_dict = intersect_dicts(ckpt_state_dict, model.state_dict(), exclude=["anchors"])
model.load_state_dict(ckpt_state_dict, strict=False)

if autoshape:
model = AutoShape(model)

return model
if fuse:
model = ckpt["ema" if ckpt.get("ema") else "model"].float().fuse().eval()
else: # without layer fuse
model = ckpt["ema" if ckpt.get("ema") else "model"].float().eval()

# Compatibility updates
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
if isinstance(m, Detect):
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
delattr(m, "anchor_grid")
setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl)

return model
2 changes: 1 addition & 1 deletion yolort/v5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r4.0"):
def forward(self, x: Tensor) -> Tensor:
return self.act(self.bn(self.conv(x)))

def fuseforward(self, x):
def forward_fuse(self, x: Tensor) -> Tensor:
return self.act(self.conv(x))


Expand Down