forked from duanzhiihao/RAPiD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_onnx.py
53 lines (41 loc) · 1.51 KB
/
export_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import torch
from models.rapid_export import RAPiD
def export(argv=None):
    """Load a RAPiD checkpoint and export it to an ONNX file.

    Command-line options (parsed from ``argv``, or from ``sys.argv`` when
    ``argv`` is None):

        --weights   checkpoint path; its ``'model'`` entry is loaded as the
                    model's state dict.
        --device    torch device used for tracing (default ``cuda:0``).
        --half      convert both the model and the dummy input to float16.
        --img-size  square input resolution H = W (default 1024).
        --output    destination ONNX file (default ``rapid.onnx``).
        --opset     ONNX opset version (default 11).

    Args:
        argv: Optional list of argument strings. ``None`` reads ``sys.argv``,
            which matches the original script behaviour.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='weights/pL1_MWHB1024_Mar11_4000.ckpt')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--half', action='store_true')
    parser.add_argument('--img-size', type=int, default=1024)
    parser.add_argument('--output', type=str, default='rapid.onnx')
    parser.add_argument('--opset', type=int, default=11)
    args = parser.parse_args(argv)

    device = torch.device(args.device)
    input_shape = (1, 3, args.img_size, args.img_size)

    model = RAPiD(input_hw=input_shape[2:4])
    # map_location makes a checkpoint saved on one device (e.g. a GPU)
    # loadable on whichever --device was requested (e.g. CPU-only machines).
    weights = torch.load(args.weights, map_location=device)
    model.load_state_dict(weights['model'])
    model = model.to(device=device)
    model.eval()

    # NOTE(review): float16 on a CPU device may not be supported by every op
    # in older torch builds -- confirm before combining --half with CPU.
    if args.half:
        model = model.half()

    # Dummy input used only to trace the graph; its values are irrelevant.
    # dtype=None keeps torch's default floating dtype when --half is off.
    dtype = torch.float16 if args.half else None
    x = torch.rand(*input_shape, dtype=dtype, device=device)

    # No autograd state is needed while tracing for export.
    with torch.no_grad():
        torch.onnx.export(model, x, args.output, verbose=True,
                          opset_version=args.opset)
def check(onnx_path='rapid.onnx', out_path='tmp.txt'):
    """Validate an exported ONNX model and dump its graph as readable text.

    Args:
        onnx_path: ONNX file to load and validate (default ``rapid.onnx``,
            the same default file name used for the export).
        out_path: Text file the printable graph is appended to
            (default ``tmp.txt``).

    Returns:
        The human-readable graph string that was written to ``out_path``.
    """
    # Local import: the rest of this module can be used without onnx installed.
    import onnx

    model = onnx.load(onnx_path)
    # Check that the IR is well formed (raises if validation fails).
    onnx.checker.check_model(model)
    graph_text = onnx.helper.printable_graph(model.graph)
    # Append mode is kept from the original: repeated runs accumulate dumps.
    with open(out_path, 'a', encoding='utf-8') as f:
        print(graph_text, file=f)
    return graph_text
# Module-level flag; it is not read anywhere in the visible code.
debug = 1


if __name__ == '__main__':
    # Script entry point: export the model. Enable the call below to
    # validate the written ONNX file and dump its graph.
    export()
    # check()