# rendering.py
import torch
from einops import rearrange, reduce, repeat

__all__ = ['rendering']

def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
    """
    Importance-sample points on the ray from the distribution defined by @weights
    (inverse transform sampling of the piecewise-constant PDF).
    Inputs:
        bins: (N_rays, N_samples_+1) where N_samples_ is the number of coarse samples per ray - 2
        weights: (N_rays, N_samples_)
        N_importance: number of samples to draw from the distribution
        det: whether to sample deterministically (evenly spaced in the CDF) or randomly
        eps: a small number to prevent division by zero
    Outputs:
        samples: (N_rays, N_importance) the sampled depth values
    """
    N_rays, N_samples_ = weights.shape
    weights = weights + eps # prevent division by zero (don't do inplace op!)
    pdf = weights / reduce(weights, 'n1 n2 -> n1 1', 'sum') # (N_rays, N_samples_)
    cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples_), cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1)
                                                             # padded to 0~1 inclusive

    if det:
        u = torch.linspace(0, 1, N_importance, device=bins.device)
        u = u.expand(N_rays, N_importance)
    else:
        u = torch.rand(N_rays, N_importance, device=bins.device)
    u = u.contiguous()

    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp_min(inds-1, 0)
    above = torch.clamp_max(inds, N_samples_)

    inds_sampled = rearrange(torch.stack([below, above], -1), 'n1 n2 c -> n1 (n2 c)', c=2)
    cdf_g = rearrange(torch.gather(cdf, 1, inds_sampled), 'n1 (n2 c) -> n1 n2 c', c=2)
    bins_g = rearrange(torch.gather(bins, 1, inds_sampled), 'n1 (n2 c) -> n1 n2 c', c=2)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom[denom < eps] = 1 # denom equal to 0 means the bin has weight 0, in which case
                           # it will not be sampled anyway, so any value is fine (set to 1 here)

    samples = bins_g[..., 0] + (u-cdf_g[..., 0]) / denom * (bins_g[..., 1]-bins_g[..., 0])
    return samples
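
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hypothetical example of the inverse-CDF sampling performed by
# sample_pdf: one ray whose weights peak in the middle bins, so the drawn
# samples cluster there. All values below are made up for illustration.
def _example_sample_pdf():
    bins = torch.linspace(2.0, 6.0, steps=7).unsqueeze(0)     # (1, N_samples_+1)
    weights = torch.tensor([[0.1, 0.2, 0.9, 0.9, 0.2, 0.1]])  # (1, N_samples_)
    samples = sample_pdf(bins, weights, N_importance=16, det=True)
    # samples: (1, 16), concentrated in the depth range where the weights are large
    return samples
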
def sample_points(rays, N_samples, perturb, use_disp):
    """
    Sample depths along each ray between its near and far bounds.
    Samples are spaced linearly in depth, or linearly in disparity (inverse depth)
    if @use_disp is True, which is useful when the scene extends far from the camera.
    Inputs:
        rays: (N_rays, 8) ray origins, directions and near, far depth bounds
        N_samples: number of samples per ray
        perturb: factor to perturb the sampling depths (0 for deterministic sampling)
        use_disp: whether to sample in disparity space
    Outputs:
        z_vals: (N_rays, N_samples) sampled depths along each ray
    """
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)

    z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
    if not use_disp: # use linear sampling in depth space
        z_vals = near * (1-z_steps) + far * z_steps
    else: # use linear sampling in disparity space
        z_vals = 1/(1/near * (1-z_steps) + 1/far * z_steps)
    z_vals = z_vals.expand(N_rays, N_samples)

    if perturb > 0: # perturb sampling depths (z_vals)
        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
        # get intervals between samples
        upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)

        perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
        z_vals = lower + (upper - lower) * perturb_rand
    return z_vals
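
# --- Illustrative usage sketch (not part of the original module) ---
# A toy call to sample_points with two rays. The 8-column ray layout
# (origin, direction, near, far) follows the convention used by rendering()
# below; the values themselves are hypothetical.
def _example_sample_points():
    rays_o = torch.zeros(2, 3)                            # origins at the world origin
    rays_d = torch.tensor([[0., 0., 1.], [0., 1., 0.]])   # unit directions
    near = torch.full((2, 1), 2.0)
    far = torch.full((2, 1), 6.0)
    rays = torch.cat([rays_o, rays_d, near, far], dim=-1) # (2, 8)
    z_vals = sample_points(rays, N_samples=8, perturb=0, use_disp=False)
    # z_vals: (2, 8), linearly spaced in depth from 2.0 to 6.0
    return z_vals
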
def volumetric_rendering(z_vals, dir_, sigmas, noise_std):
    """
    Volumetric rendering weights from the sampled densities, following Eq. (3)
    of the NeRF paper. deltas are the distances between consecutive samples;
    alphas are the per-sample opacities, which are composited front-to-back
    into the weight each sample contributes to the pixel color.
    Inputs:
        z_vals: (N_rays, N_samples_) depths of the sampled positions
        dir_: (N_rays, 3) ray directions
        sigmas: (N_rays, N_samples_) predicted densities
        noise_std: std of the noise added to sigmas (regularization)
    Outputs:
        weights: (N_rays, N_samples_) compositing weight of each sample
    """
    deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1)
    delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity
    deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_)

    # Multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions).
    deltas = deltas * torch.norm(dir_.unsqueeze(1), dim=-1)

    noise = torch.randn(sigmas.shape, device=sigmas.device) * noise_std

    # compute alpha by formula (3)
    alphas = 1 - torch.exp(-deltas*torch.relu(sigmas+noise)) # (N_rays, N_samples_)
    alphas_shifted = \
        torch.cat([torch.ones_like(alphas[:, :1]), 1-alphas+1e-10], -1) # [1, 1-a1, 1-a2, ...]
    weights = \
        alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_)
    return weights
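
# --- Illustrative usage sketch (not part of the original module) ---
# A worked example of the compositing weights w_i = alpha_i * prod_{j<i} (1 - alpha_j):
# one ray with four samples and a large density at the third sample, so almost
# all of the weight lands there. noise_std=0 keeps the example deterministic;
# all numbers are hypothetical.
def _example_volumetric_rendering():
    z_vals = torch.tensor([[2.0, 3.0, 4.0, 5.0]]) # (1, 4) sample depths
    dir_ = torch.tensor([[0.0, 0.0, 1.0]])        # (1, 3) unit ray direction
    sigmas = torch.tensor([[0.0, 0.0, 5.0, 0.0]]) # (1, 4) raw densities
    weights = volumetric_rendering(z_vals, dir_, sigmas, noise_std=0)
    # weights: (1, 4); sums to at most 1 per ray and is dominated by the third sample
    return weights
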
def inference(model, embedding_xyz, xyz_, dir_, dir_embedded, z_vals, chunk=1024*32,
              noise_std=1, white_back=True, weights_only=False):
    """
    Helper function that performs model inference.
    Inputs:
        model: NeRF model (coarse or fine)
        embedding_xyz: embedding module for xyz
        xyz_: (N_rays, N_samples_, 3) sampled positions
              N_samples_ is the number of sampled points on each ray;
                         = N_samples for the coarse model
                         = N_samples+N_importance for the fine model
        dir_: (N_rays, 3) ray directions
        dir_embedded: (N_rays, embed_dir_channels) embedded directions
        z_vals: (N_rays, N_samples_) depths of the sampled positions
        chunk: the chunk size in batched inference
        noise_std: std of the noise added to the predicted sigmas (regularization)
        white_back: whether the background is white (dataset dependent)
        weights_only: do inference on sigma only or not
    Outputs:
        if weights_only:
            weights: (N_rays, N_samples_) weights of each sample
        else:
            rgb_final: (N_rays, 3) the final rgb image
            depth_final: (N_rays) depth map
            weights: (N_rays, N_samples_) weights of each sample
    """
    N_rays, N_samples_, _ = xyz_.shape
    # Flatten the positions and, if needed, repeat the embedded directions
    # so that every sampled point on a ray shares that ray's direction
    xyz_ = xyz_.view(-1, 3) # (N_rays*N_samples_, 3)
    if not weights_only:
        dir_embedded = torch.repeat_interleave(dir_embedded, repeats=N_samples_, dim=0)
                       # (N_rays*N_samples_, embed_dir_channels)

    # Perform model inference to get rgb and raw sigma
    B = xyz_.shape[0]
    out_chunks = []
    for i in range(0, B, chunk):
        # Embed positions by chunk
        xyz_embedded = embedding_xyz(xyz_[i:i+chunk])
        if not weights_only:
            xyzdir_embedded = torch.cat([xyz_embedded,
                                         dir_embedded[i:i+chunk]], 1)
        else:
            xyzdir_embedded = xyz_embedded
        out_chunks += [model(xyzdir_embedded, sigma_only=weights_only)]

    out = torch.cat(out_chunks, 0)
    if weights_only:
        sigmas = out.view(N_rays, N_samples_)
    else:
        rgbsigma = out.view(N_rays, N_samples_, 4)
        rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3)
        sigmas = rgbsigma[..., 3] # (N_rays, N_samples_)

    # Convert these values using volume rendering (Section 4)
    weights = volumetric_rendering(z_vals, dir_, sigmas, noise_std)
    weights_sum = weights.sum(1) # (N_rays), the accumulated opacity along the rays
                                 # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically
    if weights_only:
        return weights

    # compute final weighted outputs
    rgb_final = torch.sum(weights.unsqueeze(-1)*rgbs, -2) # (N_rays, 3)
    depth_final = torch.sum(weights*z_vals, -1) # (N_rays)

    if white_back:
        rgb_final = rgb_final + 1 - weights_sum.unsqueeze(-1)

    return rgb_final, depth_final, weights
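
# --- Illustrative usage sketch (not part of the original module) ---
# Shows the tensor shapes flowing through inference() using stand-in callables:
# an identity "embedding" and a toy "model" that returns constant rgb+sigma.
# These stand-ins are hypothetical; real usage passes the NeRF modules from nerf.py.
def _example_inference():
    N_rays, N_samples_ = 2, 4
    embedding_xyz = lambda x: x                              # identity embedding, (B, 3) -> (B, 3)
    toy_model = lambda x, sigma_only=False: (
        torch.ones(x.shape[0], 1) if sigma_only
        else torch.cat([torch.full((x.shape[0], 3), 0.5),
                        torch.ones(x.shape[0], 1)], dim=-1)) # (B, 4) = rgb + sigma
    xyz_ = torch.rand(N_rays, N_samples_, 3)                 # sampled positions
    dir_ = torch.tensor([[0., 0., 1.], [0., 1., 0.]])        # ray directions
    dir_embedded = dir_                                      # identity embedding of directions
    z_vals = torch.linspace(2., 6., N_samples_).expand(N_rays, N_samples_)
    rgb, depth, weights = inference(toy_model, embedding_xyz, xyz_, dir_,
                                    dir_embedded, z_vals, chunk=1024,
                                    noise_std=0, white_back=False)
    # rgb: (2, 3), depth: (2,), weights: (2, 4)
    return rgb, depth, weights
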
def rendering(models,
              embeddings,
              rays,
              N_samples=64,
              use_disp=False,
              perturb=0,
              noise_std=1,
              N_importance=0,
              chunk=1024*32,
              white_back=False,
              test_time=False
              ):
    """
    Render rays by computing the output of @models applied on @rays.
    Inputs:
        models: dict of NeRF models ('coarse' and, if N_importance>0, 'fine') defined in nerf.py
        embeddings: dict of embedding models for position ('xyz') and direction ('dir') defined in nerf.py
        rays: (N_rays, 3+3+2), ray origins, directions and near, far depth bounds
        N_samples: number of coarse samples per ray
        use_disp: whether to sample in disparity space (inverse depth)
        perturb: factor to perturb the sampling position on the ray (for coarse model only)
        noise_std: factor to perturb the model's prediction of sigma
        N_importance: number of fine samples per ray
        chunk: the chunk size in batched inference
        white_back: whether the background is white (dataset dependent)
        test_time: whether it is test (inference only) or not. If True, it will not do inference
                   on coarse rgb to save time
    Outputs:
        result: dictionary containing final rgb and depth maps for coarse and fine models
    """
    # Extract models from the dicts
    model_coarse = models['coarse']
    embedding_xyz = embeddings['xyz']
    embedding_dir = embeddings['dir']

    # Decompose the inputs
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)

    # Embed direction
    dir_embedded = embedding_dir(rays_d) # (N_rays, embed_dir_channels)

    # Sample depth points
    z_vals = sample_points(rays, N_samples, perturb, use_disp)
    xyz_coarse_sampled = rays_o.unsqueeze(1) + \
                         rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)

    if test_time:
        weights_coarse = inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
                                   dir_embedded, z_vals, chunk, noise_std, white_back, weights_only=True)
        result = {'opacity_coarse': weights_coarse.sum(1)}
    else:
        rgb_coarse, depth_coarse, weights_coarse = \
            inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
                      dir_embedded, z_vals, chunk, noise_std, white_back, weights_only=False)
        result = {'rgb_coarse': rgb_coarse,
                  'depth_coarse': depth_coarse,
                  'opacity_coarse': weights_coarse.sum(1)
                  }

    # Once we get the weights from the coarse model, we use them as a prior
    # for sampling the points fed to the fine network
    if N_importance > 0: # sample points for fine model
        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
        z_vals_ = sample_pdf(z_vals_mid, weights_coarse[:, 1:-1],
                             N_importance, det=(perturb==0)).detach()
                  # detach so that grad doesn't propagate to weights_coarse from here

        z_vals, _ = torch.sort(torch.cat([z_vals, z_vals_], -1), -1)

        xyz_fine_sampled = rays_o.unsqueeze(1) + \
                           rays_d.unsqueeze(1) * z_vals.unsqueeze(2)
                           # (N_rays, N_samples+N_importance, 3)

        model_fine = models['fine']
        rgb_fine, depth_fine, weights_fine = \
            inference(model_fine, embedding_xyz, xyz_fine_sampled, rays_d,
                      dir_embedded, z_vals, chunk, noise_std, white_back, weights_only=False)

        result['rgb_fine'] = rgb_fine
        result['depth_fine'] = depth_fine
        result['opacity_fine'] = weights_fine.sum(1)

    return result
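
# --- Illustrative usage sketch (not part of the original module) ---
# End-to-end toy call to rendering() with stand-in models and embeddings
# (identity embeddings, a model that predicts constant rgb and sigma). Real
# usage would pass the NeRF and embedding modules defined in nerf.py; every
# name and value here is hypothetical.
def _example_rendering():
    toy_model = lambda x, sigma_only=False: (
        torch.ones(x.shape[0], 1) if sigma_only
        else torch.cat([torch.full((x.shape[0], 3), 0.5),
                        torch.ones(x.shape[0], 1)], dim=-1))
    models = {'coarse': toy_model, 'fine': toy_model}
    embeddings = {'xyz': lambda x: x, 'dir': lambda x: x}
    rays_o = torch.zeros(2, 3)
    rays_d = torch.tensor([[0., 0., 1.], [0., 1., 0.]])
    near = torch.full((2, 1), 2.0)
    far = torch.full((2, 1), 6.0)
    rays = torch.cat([rays_o, rays_d, near, far], dim=-1) # (2, 8)
    result = rendering(models, embeddings, rays,
                       N_samples=8, N_importance=16, perturb=0, noise_std=0)
    # result contains rgb/depth/opacity maps for both the coarse and fine passes
    return result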