-
Notifications
You must be signed in to change notification settings - Fork 3
/
render.py
315 lines (267 loc) · 11.3 KB
/
render.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import argparse
import os
import time
import jax
import jax.numpy as jnp
from jax import random
import flax
from flax import jax_utils
from flax import optim
from flax.training import checkpoints
import cv2
import functools
from absl import logging
import numpy as np
import mediapy
from pathlib import Path
import gin
from IPython.display import display, Markdown
from tqdm import tqdm
from hypernerf import evaluation
from hypernerf import schedules
from hypernerf import training
from hypernerf import models
from hypernerf import configs
from hypernerf import datasets
from hypernerf import image_utils
from hypernerf import visualization as viz
from hypernerf import model_utils
from hypernerf import utils
def render_scene(exp_dir, data_dir, camera_path_name: str = 'vrig_camera', interval: int = 1, chunk_size: int = 4096):
    """Render a camera path with the checkpoint stored in an experiment directory.

    Steps:
      1. Parse `<exp_dir>/config.gin` and build the experiment configs.
      2. Build the model and restore `<exp_dir>/checkpoints`.
      3. Render every `interval`-th camera found under
         `<data_dir>/<camera_path_name>`, attaching the masks from
         `<data_dir>/resized_mask/<int(image_scale)>x`.
      4. Write the outputs:
         * `<exp_dir>/render_result_<name>`: an `np.save` dump of a list with
           one dict of raw render outputs (only `relevant_keys`) per frame;
         * videos titled `result_<name>_rgb` (RGB only) and `result_<name>`
           (2x3 debug grid), shown via mediapy with `exp_dir` set as the
           mediapy save directory.
      `<name>` is `camera_path_name`, with "_full" appended when
      `interval == 1`.

    Side effects beyond the outputs:
      * absl `logging.info`, `logging.warn` and `logging.error` are replaced
        with a no-op. The patch is process-wide and is never undone.
      * `<exp_dir>/checkpoints` is created if it does not exist.
      * The gin config is displayed through IPython `display`.
      * Each frame is previewed in a cv2 window named 'rgb'.

    Args:
      exp_dir: experiment directory holding `config.gin` and `checkpoints/`.
      data_dir: dataset root passed to `exp_config.datasource_cls`. It also
        contains the camera directory and `resized_mask/`.
      camera_path_name: sub-directory of `data_dir` with the cameras to render.
        Names starting with 'vrig' use a different camera-id scheme (see the
        render loop).
      interval: stride over the sorted camera list.
      chunk_size: passed as `chunk` to `evaluation.render_image`.

    Note: depends on `sort_camera_paths`, defined later in this module.
    """
    # print('Detected Devices:', jax.devices())
    # @title Define imports and utility functions.
    # Monkey patch logging: silence absl info/warn/error by swapping in a no-op.
    # NOTE(review): this mutates the absl `logging` module itself, so the
    # silencing outlives this call and affects every other caller.
    def myprint(msg, *args, **kwargs):
        """No-op replacement for absl logging functions."""
        pass
        # print(msg % args)
    logging.info = myprint
    logging.warn = myprint
    logging.error = myprint

    # Experiment layout: <exp_dir>/config.gin and <exp_dir>/checkpoints.
    checkpoint_dir = Path(exp_dir, 'checkpoints')
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    config_path = Path(exp_dir, 'config.gin')
    with open(config_path, 'r') as f:
        # Silenced by the logging patch above.
        logging.info('Loading config from %s', config_path)
        config_str = f.read()
    # Bind the experiment's gin config before any gin-configurable config
    # object below is instantiated.
    gin.parse_config(config_str)
    exp_config = configs.ExperimentConfig()
    train_config = configs.TrainConfig()
    eval_config = configs.EvalConfig()
    spec_config = configs.SpecularConfig()
    display(Markdown(
        gin.config.markdown(gin.config_str())))

    # @title Create datasource and show an example.
    # A throwaway model instance, built only to read its embedding / warp
    # attributes. Those attributes choose which use_* metadata flags are
    # passed to the datasource.
    if spec_config.use_hyper_spec_model:
        dummy_model = models.HyperSpecModel({}, 0, 0)
    else:
        dummy_model = models.NerfModel({}, 0, 0)
    datasource = exp_config.datasource_cls(
        data_dir=data_dir,
        image_scale=exp_config.image_scale,
        random_seed=exp_config.random_seed,
        # Enable metadata based on model needs.
        use_warp_id=dummy_model.use_warp,
        use_appearance_id=(
            dummy_model.nerf_embed_key == 'appearance'
            or dummy_model.hyper_embed_key == 'appearance'),
        use_camera_id=dummy_model.nerf_embed_key == 'camera',
        use_time=dummy_model.warp_embed_key == 'time',
    )

    # @title Load model
    # @markdown Defines the model and initializes its parameters.
    rng = random.PRNGKey(exp_config.random_seed)
    np.random.seed(exp_config.random_seed + jax.process_index())
    devices_to_use = jax.devices()

    # Schedules are only evaluated at step 0 below. Their values populate the
    # TrainState that restore_checkpoint then uses as its target.
    learning_rate_sched = schedules.from_config(train_config.lr_schedule)
    nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)
    warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
    elastic_loss_weight_sched = schedules.from_config(
        train_config.elastic_loss_weight_schedule)
    hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)
    hyper_sheet_alpha_sched = schedules.from_config(
        train_config.hyper_sheet_alpha_schedule)
    norm_loss_weight_sched = schedules.from_config(spec_config.norm_loss_weight_schedule)
    norm_input_alpha_sched = schedules.from_config(spec_config.norm_input_alpha_schedule)
    rng, key = random.split(rng)
    params = {}
    model, params['model'] = models.construct_nerf(
        key,
        use_hyper_spec_model=spec_config.use_hyper_spec_model,
        batch_size=train_config.batch_size,
        embeddings_dict=datasource.embeddings_dict,
        near=datasource.near,
        far=datasource.far,
        use_sigma_gradient=spec_config.use_sigma_gradient,
        use_predicted_norm=spec_config.use_predicted_norm,
    )
    optimizer_def = optim.Adam(learning_rate_sched(0))
    optimizer = optimizer_def.create(params)
    # state = model_utils.TrainState(optimizer=optimizer)
    state = model_utils.TrainState(
        optimizer=optimizer,
        nerf_alpha=nerf_alpha_sched(0),
        warp_alpha=warp_alpha_sched(0),
        hyper_alpha=hyper_alpha_sched(0),
        hyper_sheet_alpha=hyper_sheet_alpha_sched(0),
        norm_loss_weight=norm_loss_weight_sched(0),
        norm_input_alpha=norm_input_alpha_sched(0),
    )
    logging.info('Restoring checkpoint from %s', checkpoint_dir)
    # NOTE(review): per the flax docs, restore_checkpoint returns the passed-in
    # target unchanged when no checkpoint file exists. A missing checkpoint
    # would therefore render with freshly-initialized parameters and give no
    # error (logging is silenced above).
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    # NOTE(review): `step` is not used anywhere below.
    step = state.optimizer.state.step + 1
    # Copy the state onto every local device for the pmapped model function.
    state = jax_utils.replicate(state, devices=devices_to_use)
    # param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
    # print("Total number of params:", param_count)
    del params

    # @title Define pmapped render function.
    devices = jax.devices()
    def _model_fn(key_0, key_1, key_2, params, rays_dict, extra_params):
        """Run the model on one device's shard of rays with inference settings.

        The output is all-gathered over the pmap axis 'batch', so each device
        returns the concatenated outputs of every device.
        """
        out = model.apply({'params': params},
                          rays_dict,
                          extra_params=extra_params,
                          rngs={
                              'coarse': key_0,
                              'fine': key_1,
                              'voxel': key_2
                          },
                          mutable=False,
                          use_predicted_norm=spec_config.use_predicted_norm,
                          return_points=False,
                          return_nv_details=False,
                          mask_ratio=1,  # inference ratio is always 1
                          sharp_weights_std=0.1
                          )
        return jax.lax.all_gather(out, axis_name='batch')
    pmodel_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        _model_fn,
        # Every argument (the three keys, params, rays_dict, extra_params) is
        # mapped over its leading device axis.
        in_axes=(0, 0, 0, 0, 0, 0),
        devices=devices_to_use,
        axis_name='batch',
    )
    # Disabled vmap alternative, kept for reference:
    # pmodel_fn = jax.vmap(
    #     # Note rng_keys are useless in eval mode since there's no randomness.
    #     _model_fn,
    #     in_axes=(0, 0, 0, 0, 0, 0),  # Only distribute the data input.
    #     # devices=devices_to_use,
    #     axis_name='batch',
    # )
    # Presumably render_image splits the rays into `chunk_size` chunks and
    # feeds them through pmodel_fn across all devices. Confirm in
    # hypernerf.evaluation.
    render_fn = functools.partial(evaluation.render_image,
                                  model_fn=pmodel_fn,
                                  device_count=len(devices),
                                  chunk=chunk_size)

    # @title Load cameras.
    camera_dir = Path(data_dir, camera_path_name)
    print(f'Loading cameras from {camera_dir}')
    test_camera_paths = datasource.glob_cameras(camera_dir)
    # Order cameras by the id parsed from their file stems (see
    # sort_camera_paths), so that frames are rendered in sequence.
    test_camera_paths = sort_camera_paths(test_camera_paths)
    test_cameras = utils.parallel_map(datasource.load_camera, test_camera_paths, show_pbar=True)
    # Masks come from the pre-resized directory that matches the image scale.
    # The trailing 1 is presumably a further scale factor (no extra resizing).
    mask_dir = Path(data_dir, 'resized_mask', f"{int(exp_config.image_scale)}x")
    print(f"Loading masks from {mask_dir}")
    mask_list = datasets.load_camera_masks(mask_dir, test_camera_paths, 1)  # already resized

    # @title Render video frames.
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    # NOTE(review): `keys` is not used below; render_fn receives `rng` directly.
    keys = random.split(rng, len(devices))
    results = []
    # Only these render outputs are kept in the raw-results dump.
    relevant_keys = ['rgb', 'med_depth', 'ray_norm', 'ray_delta_x', 'med_points',
                     'ray_predicted_mask', 'ray_rotation_field']
    raw_result_list = []
    # The "_full" suffix only affects output file names and titles. camera_dir
    # was already resolved above, and the 'vrig' prefix check is unaffected.
    if interval == 1:
        camera_path_name += "_full"
    for i in tqdm(range(0, len(test_cameras), interval)):
        # print(f'Rendering frame {i + 1}/{len(test_cameras)}')
        camera = test_cameras[i]
        batch = datasets.camera_to_rays(camera)
        # Per-ray uint32 metadata ids, shaped like batch['origins'][..., :1].
        # `i` is the camera index, not the output frame count, so ids skip
        # values when interval > 1.
        if not camera_path_name.startswith('vrig'):
            # Appearance and warp ids follow the camera index; camera id is 0.
            batch['metadata'] = {
                'appearance': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * i,
                'warp': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * i,
                'camera': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * 0
            }
        else:
            # Camera id alternates 2, 1, 2, 1, ... starting at i == 0.
            # NOTE(review): this presumably switches between the two rig
            # cameras. Confirm against the dataset's camera ids. With an even
            # `interval`, every rendered frame gets camera id 2.
            batch['metadata'] = {
                # 'appearance': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * ((i + 1) % 2 + 1),
                'appearance': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * i,
                'warp': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * i,
                'camera': jnp.ones_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32) * ((i + 1) % 2 + 1)
            }
        mask = mask_list[i]
        batch['mask'] = mask
        render = render_fn(state, batch, rng=rng)
        # save raw results for future use
        raw_result = {}
        # value_size = 0
        for key in render:
            if not key in relevant_keys:
                continue
            raw_result[key] = np.array(render[key])
            # print(key, raw_result[key].size * raw_result[key].itemsize)
            # value_size += raw_result[key].size * raw_result[key].itemsize
        raw_result_list.append(raw_result)
        # Convert outputs to numpy and map each one into a displayable range.
        # NOTE(review): 'ray_norm', 'ray_delta_x' and 'med_points' are read
        # unconditionally (KeyError if the model does not return them).
        # Only 'ray_predicted_mask' is optional.
        rgb = np.array(render['rgb'])
        depth_med = np.array(render['med_depth'])
        # All-zero placeholder panel with rgb's shape.
        dummy_image = np.zeros_like(rgb)
        # After normalize_vector (presumably to unit length), map [-1, 1] -> [0, 1].
        ray_norm = np.array(render['ray_norm'])
        ray_norm = model_utils.normalize_vector(ray_norm)
        ray_norm = ray_norm / 2.0 + 0.5
        # Magnitude of delta_x, scaled by 10 for visibility.
        ray_delta_x = np.array(render['ray_delta_x'])
        ray_delta_x = np.abs(ray_delta_x)
        ray_delta_x = ray_delta_x * 10
        med_points = np.array(render['med_points'])
        med_points = (med_points + 1.5) / 3  # -1.5 ~ 1.5 --> 0 ~ 1
        if 'ray_predicted_mask' in render:
            ray_predicted_mask = np.array(render['ray_predicted_mask'])
            ray_predicted_mask = np.broadcast_to(ray_predicted_mask, dummy_image.shape)  # grayscale to color
        else:
            ray_predicted_mask = dummy_image
        # Live preview; cv2 expects BGR channel order.
        # NOTE(review): presumably fails on a headless machine with no display.
        cv2.imshow('rgb', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
        results.append((rgb, depth_med, ray_norm, ray_predicted_mask, ray_delta_x, med_points, dummy_image))

    # save raw render results
    # np.save stores this list of dicts as a pickled object array, so reading
    # it back needs np.load(..., allow_pickle=True). A file handle is passed,
    # so no ".npy" suffix is appended to the name.
    raw_result_save_path = os.path.join(exp_dir, "render_result_{}".format(camera_path_name))
    with open(raw_result_save_path, "wb+") as f:
        np.save(f, raw_result_list)

    # @title Show rendered video.
    fps = 30  # @param {type:'number'}
    rgb_frames = []
    debug_frames = []
    for rgb, depth_med, ray_norm, ray_predicted_mask, ray_delta_x, med_points, dummy_image in results:
        # Depth colorized between the datasource near/far planes, inverted.
        depth_viz = viz.colorize(depth_med.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)
        med_points = med_points[..., :3].squeeze()
        # 2x3 debug grid:
        #   rgb            | depth        | normals
        #   predicted mask | |delta_x|*10 | median points
        row1 = np.concatenate([rgb, depth_viz, ray_norm], axis=1)
        row2 = np.concatenate([ray_predicted_mask, ray_delta_x, med_points], axis=1)
        debug_frame = np.concatenate([row1, row2], axis=0)
        debug_frames.append(image_utils.image_to_uint8(debug_frame))
        rgb_frames.append(image_utils.image_to_uint8(rgb))
    # Titled videos are shown and also saved under exp_dir via mediapy's save dir.
    mediapy.set_show_save_dir(exp_dir)
    mediapy.show_video(rgb_frames, fps=fps, title="result_{}_rgb".format(camera_path_name))
    mediapy.show_video(debug_frames, fps=fps, title="result_{}".format(camera_path_name))
def sort_camera_paths(camera_paths):
    """Return ``camera_paths`` ordered by the integer camera id in each file stem.

    Finding the id:
      * Use the second '_'-separated field of the stem when it is an integer
        (``left1_000012`` -> 12).
      * Otherwise use the first field (``000012_left`` or ``12`` -> 12).

    Ids are compared numerically, so names without zero padding such as
    ``cam_2`` / ``cam_10`` sort in numeric order. Ties are broken by path.

    Args:
      camera_paths: iterable of path objects exposing ``.stem``
        (e.g. ``pathlib.Path``).

    Returns:
      A new list containing the same paths, sorted by camera id.

    Raises:
      ValueError: if neither candidate field of a stem parses as an integer.
    """
    def _camera_id(path):
        # Prefer the second field; fall back to the first when the stem has no
        # '_' (IndexError) or the second field is not numeric (ValueError).
        fields = path.stem.split('_')
        try:
            return int(fields[1])
        except (IndexError, ValueError):
            return int(fields[0])

    # Sort on the parsed integer, not its string form, so "10" sorts after "2".
    return sorted(camera_paths, key=lambda path: (_camera_id(path), path))
if __name__ == "__main__":
    # Command-line entry point: render one camera path of one experiment.
    parser = argparse.ArgumentParser(
        description="Render a camera path with a trained checkpoint.")
    # Both directories are required; render_scene builds filesystem paths from
    # them and would otherwise fail later with a TypeError on None.
    parser.add_argument("--base_folder", type=str, required=True,
                        help="Experiment directory containing config.gin and checkpoints/.")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Dataset root directory.")
    parser.add_argument("--interval", type=int, default=1,
                        help="Render every N-th camera (1 renders all of them).")
    parser.add_argument("--chunk_size", type=int, default=4096,
                        help="Rays per chunk passed to the renderer.")
    parser.add_argument("--camera_path_name", type=str, default='vrig_camera',
                        help="Camera sub-directory of data_dir to render.")
    args = parser.parse_args()
    # Reject values that would crash later (range step 0) or render nothing.
    if args.interval < 1:
        parser.error("--interval must be >= 1")
    if args.chunk_size < 1:
        parser.error("--chunk_size must be >= 1")
    render_scene(args.base_folder, args.data_dir,
                 camera_path_name=args.camera_path_name,
                 interval=args.interval,
                 chunk_size=args.chunk_size)