-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
114 lines (94 loc) · 4.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
import os
import hashlib
import zipfile
from six.moves import urllib
def readlines(filename):
    """Return the contents of the text file ``filename`` as a list of lines.

    Line terminators are stripped (the split is done with ``str.splitlines``),
    so an empty file yields an empty list.
    """
    with open(filename, 'r') as handle:
        contents = handle.read()
    return contents.splitlines()
def normalize_image(x):
    """Linearly rescale the values of tensor ``x`` so they span the range [0, 1].

    The minimum of ``x`` maps to 0 and the maximum maps to 1. When every value
    is equal, the shifted tensor is all zeros, so any non-zero divisor works;
    1e5 is used in that case to avoid dividing by zero.
    """
    lo = float(x.min().cpu().data)
    hi = float(x.max().cpu().data)
    span = hi - lo
    if hi == lo:
        span = 1e5
    return (x - lo) / span
def sec_to_hm(t):
    """Split a duration given in seconds into whole (hours, minutes, seconds).

    The input is truncated to an integer first.
    e.g. 10239 -> (2, 50, 39)
    """
    total_minutes, seconds = divmod(int(t), 60)
    hours, minutes = divmod(total_minutes, 60)
    return hours, minutes, seconds
def sec_to_hm_str(t):
    """Format a duration in seconds as a zero-padded 'HHhMMmSSs' string.

    e.g. 10239 -> '02h50m39s'
    """
    parts = sec_to_hm(t)
    return "%02dh%02dm%02ds" % parts
def download_model_if_doesnt_exist(model_name):
    """If pretrained kitti model doesn't exist, download and unzip it.

    The model is expected in ``models/<model_name>`` relative to the current
    working directory, and is considered present when that directory contains
    ``encoder.pth``. Otherwise, an existing ``models/<model_name>.zip`` is
    reused if its MD5 digest matches the expected checksum. If it does not
    match, the archive is downloaded again. It is then verified and extracted
    into ``models/<model_name>``.

    Args:
        model_name: key into the table of known pretrained models below,
            e.g. "mono_640x192".

    Raises:
        KeyError: if the model is not present locally and ``model_name`` is
            not one of the known pretrained models.
        SystemExit: with exit status 1 if the downloaded archive does not
            match the expected MD5 checksum.
    """
    # values are tuples of (<google cloud URL>, <md5 checksum>)
    download_paths = {
        "mono_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip",
             "a964b8356e08a02d009609d9e3928f7c"),
        "stereo_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip",
             "3dfb76bcff0786e4ec07ac00f658dd07"),
        "mono+stereo_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip",
             "c024d69012485ed05d7eaa9617a96b81"),
        "mono_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip",
             "9c2f071e35027c895a4728358ffc913a"),
        "stereo_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip",
             "41ec2de112905f85541ac33a854742d1"),
        "mono+stereo_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip",
             "46c3b824f541d143a45c37df65fbab0a"),
        "mono_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip",
             "0ab0766efdfeea89a0d9ea8ba90e1e63"),
        "stereo_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip",
             "afc2f2126d70cf3fdf26b550898b501a"),
        "mono+stereo_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip",
             "cdc5fc9b23513c07d5b19235d9ef08f7"),
    }

    if not os.path.exists("models"):
        os.makedirs("models")

    model_path = os.path.join("models", model_name)
    zip_path = model_path + ".zip"

    def check_file_matches_md5(checksum, fpath):
        """Return True if ``fpath`` exists and its MD5 hex digest equals ``checksum``."""
        if not os.path.exists(fpath):
            return False
        md5 = hashlib.md5()
        with open(fpath, 'rb') as f:
            # Hash in 1 MiB chunks so a large archive is never held in memory at once.
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                md5.update(chunk)
        return md5.hexdigest() == checksum

    # see if we have the model already downloaded...
    if not os.path.exists(os.path.join(model_path, "encoder.pth")):

        model_url, required_md5checksum = download_paths[model_name]

        # Reuse a previously downloaded archive only if it is intact.
        if not check_file_matches_md5(required_md5checksum, zip_path):
            print("-> Downloading pretrained model to {}".format(zip_path))
            urllib.request.urlretrieve(model_url, zip_path)

        if not check_file_matches_md5(required_md5checksum, zip_path):
            print(" Failed to download a file which matches the checksum - quitting")
            # quit() is only defined when the `site` module is loaded, and it exits
            # with status 0. Raise SystemExit(1) so the failure is visible to callers/shells.
            raise SystemExit(1)

        print(" Unzipping model...")
        with zipfile.ZipFile(zip_path, 'r') as f:
            f.extractall(model_path)

        print(" Model unzipped to {}".format(model_path))