Skip to content

Commit

Permalink
move density threshold to train
Browse files Browse the repository at this point in the history
  • Loading branch information
kwea123 committed Jul 4, 2022
1 parent d97d663 commit cf7b4db
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
7 changes: 2 additions & 5 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tinycudann as tcnn
import vren
from .custom_functions import TruncExp
from .rendering import MAX_SAMPLES
import numpy as np


Expand Down Expand Up @@ -146,8 +145,7 @@ def sample_uniform_and_occupied_cells(self, M):
return cells

@torch.no_grad()
def update_density_grid(self, warmup=False, decay=0.95,
density_threshold=0.05*MAX_SAMPLES/(2*3**0.5)):
def update_density_grid(self, density_threshold, warmup=False, decay=0.95):
# create temporary grid
tmp_grid = -torch.ones_like(self.density_grid)
if warmup: # during the first 256 steps
Expand All @@ -174,6 +172,5 @@ def update_density_grid(self, warmup=False, decay=0.95,
self.mean_density = self.density_grid.clamp(min=0).mean().item()

# pack to bitfield
vren.packbits(self.density_grid,
min(self.mean_density, density_threshold),
vren.packbits(self.density_grid, min(self.mean_density, density_threshold),
self.density_bitfield)
5 changes: 2 additions & 3 deletions models/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __render_rays_test(model, rays_o, rays_d, hits_t, **kwargs):
of each marching (the variable @N_samples)
"""
results = {}
T_threshold = kwargs.get('T_threshold', 1e-4)

# output tensors to be filled in
N_rays = len(rays_o)
Expand Down Expand Up @@ -90,8 +89,8 @@ def __render_rays_test(model, rays_o, rays_d, hits_t, **kwargs):

vren.composite_test_fw(
sigmas, rgbs, deltas, ts,
hits_t[:, 0], alive_indices, T_threshold, N_eff_samples,
opacity, depth, rgb)
hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4),
N_eff_samples, opacity, depth, rgb)
alive_indices = alive_indices[alive_indices>=0] # remove converged rays

rgb_bg = torch.ones(3, device=device) # TODO: infer env map from network
Expand Down

0 comments on commit cf7b4db

Please sign in to comment.