forked from hkchengrex/XMem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interactive_demo.py
113 lines (91 loc) · 4.76 KB
/
interactive_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
A simple user interface for XMem
"""
import os
from os import path
# Windows workaround: make sure the Qt plugin-path variable is defined.
# An existing value is left untouched; otherwise it is set to an empty string.
os.environ.setdefault('QT_QPA_PLATFORM_PLUGIN_PATH', '')
import sys
from argparse import ArgumentParser
import torch
from model.network import XMem
from inference.interact.s2m_controller import S2MController
from inference.interact.fbrs_controller import FBRSController
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
from PyQt6.QtWidgets import QApplication
from inference.interact.gui import App
from inference.interact.resource_manager import ResourceManager
from contextlib import nullcontext
# This script only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# The hasattr guard keeps the CPU fallback reachable on PyTorch builds that do
# not ship the `torch.backends.mps` module (accessing it there would raise
# AttributeError before the `else` branch could run).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
def resolve_workspace_name(images, video):
    """Derive the default workspace sub-folder name from the input source.

    Args:
        images: Path of the image folder given via ``--images``, or ``None``.
        video: Path of the video file given via ``--video``, or ``None``.

    Returns:
        The last path component of ``images`` (a trailing separator is
        ignored), or otherwise the file name of ``video`` without its
        extension.

    Raises:
        NotImplementedError: If both ``images`` and ``video`` are ``None``.
    """
    if images is not None:
        # normpath drops a trailing separator; without it
        # basename('data/frames/') would be the empty string.
        return path.basename(path.normpath(images))
    if video is not None:
        # splitext copes with extensions of any length (.ts, .mpeg, ...),
        # unlike slicing off a fixed four characters.
        return path.splitext(path.basename(video))[0]
    raise NotImplementedError(
        'Either images, video, or workspace has to be specified')


if __name__ == '__main__':
    # Arguments parsing
    parser = ArgumentParser()
    parser.add_argument('--model', default='./saves/XMem.pth')
    parser.add_argument('--s2m_model', default='saves/s2m.pth')
    parser.add_argument('--fbrs_model', default='saves/fbrs.pth')

    # Priority 1: If an "images" folder exists in the workspace, we will read from that directory
    # Priority 2: If --images is specified, we will copy/resize those images to the workspace
    # Priority 3: If --video is specified, we will extract the frames to the workspace
    #             (in an "images" folder) and read from there
    #
    # In any case, if a "masks" folder exists in the workspace, we will use that to
    # initialize the mask. That way, you can continue annotation from an interrupted
    # run as long as the same workspace is used.
    parser.add_argument('--images', help='Folders containing input images.', default=None)
    parser.add_argument('--video', help='Video file readable by OpenCV.', default=None)
    parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None)

    parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100)

    parser.add_argument('--num_objects', type=int, default=1)

    # Long-memory options
    # Defaults. Some can be changed in the GUI.
    parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
    parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
    parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
                        type=int, default=10000)
    parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)

    parser.add_argument('--top_k', type=int, default=30)
    parser.add_argument('--mem_every', type=int, default=10)
    parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
    parser.add_argument('--no_amp', help='Turn off AMP', action='store_true')
    parser.add_argument('--size', default=480, type=int,
                        help='Resize the shorter side to this size. -1 to use original resolution. ')
    args = parser.parse_args()

    # The parsed arguments double as the configuration dictionary handed to
    # the network, the resource manager and the GUI.
    config = vars(args)
    config['enable_long_term'] = True
    config['enable_long_term_count_usage'] = True

    # Create a default workspace under ./workspace if none was specified,
    # named after the image folder or the video file.
    if config["workspace"] is None:
        basename = resolve_workspace_name(config["images"], config["video"])
        config["workspace"] = path.join('./workspace', basename)

    # Mixed precision is only entered on CUDA; other devices get a no-op context.
    # Everything below, including the Qt event loop, runs inside this context.
    with torch.cuda.amp.autocast(enabled=not args.no_amp) if device.type == 'cuda' else nullcontext():
        # Load our checkpoint
        network = XMem(config, args.model, map_location=device).to(device).eval()

        # Loads the S2M model
        if args.s2m_model is not None:
            # NOTE: torch.load unpickles the file, so only point --s2m_model
            # at checkpoints from a trusted source.
            s2m_saved = torch.load(args.s2m_model, map_location=device)
            s2m_model = S2M().to(device).eval()
            s2m_model.load_state_dict(s2m_saved)
        else:
            s2m_model = None

        s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255, device=device)
        if args.fbrs_model is not None:
            fbrs_controller = FBRSController(args.fbrs_model, device=device)
        else:
            fbrs_controller = None

        # Manages most IO
        resource_manager = ResourceManager(config)

        app = QApplication(sys.argv)
        # `ex` is never read again; presumably the reference is held so the
        # window object stays alive while the event loop runs.
        ex = App(network, resource_manager, s2m_controller, fbrs_controller, config, device)
        sys.exit(app.exec())