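"""Predict Marigold monocular depth maps for the Places365 dataset.

Each category directory in PLACES_PATH is processed independently. On a SLURM
cluster the categories are submitted as a checkpointable job array via
submitit; otherwise they are processed sequentially on the local machine.
"""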
import os
from enum import Enum
from typing import Optional

import submitit
from submitit.helpers import Checkpointable, DelayedSubmission

PLACES_PATH = "REPLACE_ME"  # path to the places365 dataset


class SlurmJobType(Enum):
    CPU = 0
    GPU = 1


def is_slurm_available() -> bool:
    return submitit.AutoExecutor(".").cluster == "slurm"


def setup_slurm(
    name: str,
    job_type: SlurmJobType,
    submitit_folder: str = "submitit",
    depend_on: Optional[str] = None,
    timeout: int = 180,
    high_compute_memory: bool = False,
) -> submitit.AutoExecutor:
    os.makedirs(submitit_folder, exist_ok=True)
    executor = submitit.AutoExecutor(folder=submitit_folder, slurm_max_num_timeout=10)

    ################################################
    ##                                            ##
    ##   ADAPT THESE PARAMETERS TO YOUR CLUSTER   ##
    ##                                            ##
    ################################################
    # You may choose low-priority partitions where job preemption is enabled,
    # as any preempted job will automatically resume/restart when rescheduled.
    if job_type == SlurmJobType.CPU:
        kwargs = {
            "slurm_partition": "compute",
            "gpus_per_node": 0,
            "slurm_cpus_per_task": 14,
            "slurm_mem": "32GB" if not high_compute_memory else "64GB",
        }
    elif job_type == SlurmJobType.GPU:
        kwargs = {
            "slurm_partition": "low-prio-gpu",
            "gpus_per_node": 1,
            "slurm_cpus_per_task": 4,
            "slurm_mem": "32GB",
            # If your cluster supports choosing specific GPUs via constraints,
            # you can uncomment this line to select low-memory GPUs.
            # "slurm_constraint": "p40",
        }
    ###################
    ##               ##
    ##   ALL DONE!   ##
    ##               ##
    ###################

    kwargs = {
        **kwargs,
        "slurm_job_name": name,
        "timeout_min": timeout,
        "tasks_per_node": 1,
        "slurm_additional_parameters": {"depend": f"afterany:{depend_on}"}
        if depend_on is not None
        else {},
    }
    executor.update_parameters(**kwargs)
    return executor
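
# Minimal usage sketch (hedged: the partition names above are cluster-specific,
# and `my_fn`/`arg` are placeholders for any picklable callable and argument):
#
#   executor = setup_slurm("demo", SlurmJobType.CPU, timeout=10)
#   job = executor.submit(my_fn, arg)
#   result = job.result()  # blocks until the SLURM job finishes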


def get_marigold_model():
    import sys

    sys.path.append("PATH_TO_MARIGOLD_REPOSITORY_CLONE")
    from marigold import MarigoldPipeline

    marigold = MarigoldPipeline.from_pretrained("PATH_TO_MARIGOLD_CHECKPOINT")
    try:
        # Optional: use memory-efficient attention if xformers is installed.
        import xformers  # noqa: F401

        marigold.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers unavailable or incompatible; fall back to default attention.
        pass
    return marigold
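
# Single-image sketch (same call signature as the loop below; "example.jpg"
# is a placeholder and PIL.Image is assumed to be imported):
#
#   pipe = get_marigold_model().to("cuda")
#   out = pipe(Image.open("example.jpg").convert("RGB"),
#              denoising_steps=10, ensemble_size=10, match_input_res=True,
#              batch_size=0, color_map="Spectral", show_progress_bar=False)
#   depth = out["depth_np"]  # per-pixel depth prediction as a NumPy array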


def run_inference_for_category(category_id, out_path):
    import numpy as np
    from PIL import Image
    from torch.utils.data import DataLoader, Dataset
    from tqdm.auto import tqdm

    class CategoryDataset(Dataset):
        def __init__(self, category_id, out_path):
            self.category_id = category_id
            self.category_path = os.path.join(PLACES_PATH, str(category_id))
            # Resume support: skip images whose depth maps are already on disk.
            images_processed = len(os.listdir(os.path.join(out_path, str(category_id))))
            print(f"Found {images_processed} images that have already been processed")
            self.images = sorted(os.listdir(self.category_path))[images_processed:]

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image_name = self.images[idx]
            image_path = os.path.join(self.category_path, image_name)
            image = Image.open(image_path).convert("RGB")
            return image_name, image

    print(f"This runner is for category {category_id}")
    os.makedirs(os.path.join(out_path, str(category_id)), exist_ok=True)
    dataset = CategoryDataset(category_id, out_path)
    # batch_size=1 with an identity collate_fn: each batch is a list holding a
    # single (image_name, PIL.Image) tuple, so workers only prefetch images.
    dataloader = DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=lambda x: x
    )
    marigold = get_marigold_model().to("cuda")
    print("Initialized Marigold model")
    for batch in tqdm(dataloader):
        image_name, image = batch[0]
        marigold_out = marigold(
            image,
            denoising_steps=10,
            ensemble_size=10,
            match_input_res=True,
            batch_size=0,
            color_map="Spectral",
            show_progress_bar=False,
        )
        out_image_arr = marigold_out["depth_np"].squeeze()
        np.save(
            os.path.join(out_path, str(category_id), image_name.replace(".jpg", ".npy")),
            out_image_arr,
        )
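
# Output layout mirrors the dataset: <out_path>/<category_id>/<image_stem>.npy,
# one NumPy depth map per input image.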


class CategoryInference(Checkpointable):
    def __call__(self, *args, **kwargs):
        return run_inference_for_category(*args, **kwargs)

    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
        """Resubmit the same callable with the same arguments."""
        return DelayedSubmission(self, *args, **kwargs)  # type: ignore
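
# Note (based on submitit's documented behavior): when a job is preempted or
# times out, submitit calls `checkpoint()` and requeues the returned
# DelayedSubmission. Combined with the resume logic in CategoryDataset, a
# requeued job picks up at the first unprocessed image of its category.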


def run_inference_for_all_categories(out_path):
    os.makedirs(out_path, exist_ok=True)
    category_ids = sorted(os.listdir(PLACES_PATH))
    if is_slurm_available():
        print("SLURM is available")
        executor = setup_slurm(
            "places365",
            SlurmJobType.GPU,
            timeout=48 * 60,
        )
        # Cap the number of array tasks running concurrently.
        executor.update_parameters(slurm_array_parallelism=20)
        with executor.batch():
            for category_id in category_ids:
                executor.submit(CategoryInference(), category_id, out_path)
        print(f"Submitted {len(category_ids)} jobs to SLURM")
    else:
        from tqdm.auto import tqdm

        for category_id in tqdm(category_ids):
            run_inference_for_category(category_id, out_path)


def main(out_path):
    run_inference_for_all_categories(out_path)


if __name__ == "__main__":
    import fire

    fire.Fire(main)
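
# Usage: python-fire exposes `main`'s arguments on the command line, e.g.
#   python predict_places_marigold.py --out_path /path/to/depth_output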