diff --git a/models/networks.py b/models/networks.py index ee855f17..12d45fab 100644 --- a/models/networks.py +++ b/models/networks.py @@ -3,7 +3,6 @@ import tinycudann as tcnn import vren from .custom_functions import TruncExp -from .rendering import MAX_SAMPLES import numpy as np @@ -146,8 +145,7 @@ def sample_uniform_and_occupied_cells(self, M): return cells @torch.no_grad() - def update_density_grid(self, warmup=False, decay=0.95, - density_threshold=0.05*MAX_SAMPLES/(2*3**0.5)): + def update_density_grid(self, density_threshold, warmup=False, decay=0.95): # create temporary grid tmp_grid = -torch.ones_like(self.density_grid) if warmup: # during the first 256 steps @@ -174,6 +172,5 @@ def update_density_grid(self, warmup=False, decay=0.95, self.mean_density = self.density_grid.clamp(min=0).mean().item() # pack to bitfield - vren.packbits(self.density_grid, - min(self.mean_density, density_threshold), + vren.packbits(self.density_grid, min(self.mean_density, density_threshold), self.density_bitfield) \ No newline at end of file diff --git a/models/rendering.py b/models/rendering.py index 352153d1..b6213450 100644 --- a/models/rendering.py +++ b/models/rendering.py @@ -50,7 +50,6 @@ def __render_rays_test(model, rays_o, rays_d, hits_t, **kwargs): of each marching (the variable @N_samples) """ results = {} - T_threshold = kwargs.get('T_threshold', 1e-4) # output tensors to be filled in N_rays = len(rays_o) @@ -90,8 +89,8 @@ def __render_rays_test(model, rays_o, rays_d, hits_t, **kwargs): vren.composite_test_fw( sigmas, rgbs, deltas, ts, - hits_t[:, 0], alive_indices, T_threshold, N_eff_samples, - opacity, depth, rgb) + hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4), + N_eff_samples, opacity, depth, rgb) alive_indices = alive_indices[alive_indices>=0] # remove converged rays rgb_bg = torch.ones(3, device=device) # TODO: infer env map from network