Skip to content

Commit

Permalink
add: Efficient LoFTR (#64)
Browse files Browse the repository at this point in the history
* add: Efficient LoFTR

* fix: eloftr
  • Loading branch information
Vincentqyw authored Aug 23, 2024
1 parent f791ada commit 7a8c642
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@
[submodule "third_party/pram"]
path = third_party/pram
url = https://github.com/agipro/pram.git
[submodule "third_party/EfficientLoFTR"]
path = third_party/EfficientLoFTR
url = https://github.com/zju3dv/EfficientLoFTR.git
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Here is a demo of the tool:
https://github.com/Vincentqyw/image-matching-webui/assets/18531182/263534692-c3484d1b-cc00-4fdc-9b31-e5b7af07ecd9

The tool currently supports various popular image matching algorithms, namely:
- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024
- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024
Expand Down
4 changes: 2 additions & 2 deletions hloc/extractors/sfd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def _init(self, conf):
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
model_fn = tp_path / "pram" / "weights" / self.conf["model_name"]
self.net = load_sfd2(weight_path=model_fn).eval()
model_path = tp_path / "pram" / "weights" / self.conf["model_name"]
self.net = load_sfd2(weight_path=model_path).eval()

logger.info("Load SFD2 model done.")

Expand Down
19 changes: 19 additions & 0 deletions hloc/match_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@
"max_error": 1, # max error for assigned keypoints (in px)
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
},
"eloftr": {
"output": "matches-eloftr",
"model": {
"name": "eloftr",
"weights": "weights/eloftr_outdoor.ckpt",
"max_keypoints": 2000,
"match_threshold": 0.2,
},
"preprocessing": {
"grayscale": True,
"resize_max": 1024,
"dfactor": 32,
"width": 640,
"height": 480,
"force_resize": True,
},
"max_error": 1, # max error for assigned keypoints (in px)
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
},
# "loftr_quadtree": {
# "output": "matches-loftr-quadtree",
# "model": {
Expand Down
115 changes: 115 additions & 0 deletions hloc/matchers/eloftr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import subprocess
import sys
import warnings
from copy import deepcopy
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))

from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
from EfficientLoFTR.src.loftr import (
full_default_cfg,
opt_default_cfg,
reparameter,
)

from hloc import logger

from ..utils.base_model import BaseModel


class ELoFTR(BaseModel):
default_conf = {
"weights": "weights/eloftr_outdoor.ckpt",
"match_threshold": 0.2,
# "sinkhorn_iterations": 20,
"max_keypoints": -1,
# You can choose model type in ['full', 'opt']
"model_type": "full", # 'full' for best quality, 'opt' for best efficiency
# You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency
"precision": "fp32",
}
required_inputs = ["image0", "image1"]

def _init(self, conf):

if self.conf["model_type"] == "full":
_default_cfg = deepcopy(full_default_cfg)
elif self.conf["model_type"] == "opt":
_default_cfg = deepcopy(opt_default_cfg)

if self.conf["precision"] == "mp":
_default_cfg["mp"] = True
elif self.conf["precision"] == "fp16":
_default_cfg["half"] = True

model_path = tp_path / "EfficientLoFTR" / self.conf["weights"]

# Download the model.
if not model_path.exists():
model_path.parent.mkdir(exist_ok=True)
cached_file = hf_hub_download(
repo_type="space",
repo_id="Realcat/image-matching-webui",
filename="third_party/EfficientLoFTR/{}".format(
conf["weights"]
),
)
logger.info("Downloaded EfficientLoFTR model succeeed!")
cmd = [
"cp",
str(cached_file),
str(model_path),
]
subprocess.run(cmd, check=True)
logger.info(f"Copy model file `{cmd}`.")

cfg = _default_cfg
cfg["match_coarse"]["thr"] = conf["match_threshold"]
# cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
matcher = ELoFTR_(config=cfg)
matcher.load_state_dict(state_dict)
self.net = reparameter(matcher)

if self.conf["precision"] == "fp16":
self.net = self.net.half()
logger.info(f"Loaded Efficient LoFTR with weights {conf['weights']}")

def _forward(self, data):
# For consistency with hloc pairs, we refine kpts in image0!
rename = {
"keypoints0": "keypoints1",
"keypoints1": "keypoints0",
"image0": "image1",
"image1": "image0",
"mask0": "mask1",
"mask1": "mask0",
}
data_ = {rename[k]: v for k, v in data.items()}
with warnings.catch_warnings():
warnings.simplefilter("ignore")
pred = self.net(data_)
pred = {
"keypoints0": data_["mkpts0_f"],
"keypoints1": data_["mkpts1_f"],
}
scores = data_["mconf"]

top_k = self.conf["max_keypoints"]
if top_k is not None and len(scores) > top_k:
keep = torch.argsort(scores, descending=True)[:top_k]
pred["keypoints0"], pred["keypoints1"] = (
pred["keypoints0"][keep],
pred["keypoints1"][keep],
)
scores = scores[keep]

# Switch back indices
pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
pred["scores"] = scores
return pred
1 change: 1 addition & 0 deletions third_party/EfficientLoFTR
Submodule EfficientLoFTR added at 68eb92
14 changes: 12 additions & 2 deletions ui/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ matcher_zoo:
DUSt3R:
# TODO: duster is under development
enable: true
skip_ci: true
# skip_ci: true
matcher: duster
dense: true
info:
Expand All @@ -53,7 +53,7 @@ matcher_zoo:
display: true
GIM(dkm):
enable: true
skip_ci: true
# skip_ci: true
matcher: gim(dkm)
dense: true
info:
Expand Down Expand Up @@ -95,6 +95,16 @@ matcher_zoo:
paper: https://arxiv.org/pdf/2104.00680
project: https://zju3dv.github.io/loftr
display: true
eloftr:
matcher: eloftr
dense: true
info:
name: Efficient LoFTR #dispaly name
source: "CVPR 2024"
github: https://github.com/zju3dv/efficientloftr
paper: https://zju3dv.github.io/efficientloftr/files/EfficientLoFTR.pdf
project: https://zju3dv.github.io/efficientloftr
display: true
cotr:
enable: false
skip_ci: true
Expand Down

0 comments on commit 7a8c642

Please sign in to comment.