-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added GaussianSplats renderable, GPU sort and gaussian_splatting.py e…
…xample
- Loading branch information
1 parent
72f1cb4
commit 043fb43
Showing
8 changed files
with
1,528 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
import moderngl | ||
import numpy as np | ||
from moderngl_window.opengl.vao import VAO | ||
|
||
from aitviewer.scene.camera import CameraInterface | ||
from aitviewer.scene.node import Node | ||
from aitviewer.shaders import ( | ||
get_gaussian_splat_draw_program, | ||
get_gaussian_splat_prepare_program, | ||
) | ||
from aitviewer.utils.decorators import hooked | ||
from aitviewer.utils.gpu_sort import GpuSort | ||
|
||
|
||
class GaussianSplats(Node): | ||
PREPARE_GROUP_SIZE = 128 | ||
|
||
def __init__(self, splat_positions, splat_shs, splat_opacities, splat_scales, splat_rotations, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.num_splats = splat_positions.shape[0] | ||
self.splat_positions: np.ndarray = splat_positions | ||
self.splat_shs: np.ndarray = splat_shs | ||
self.splat_opacities: np.ndarray = splat_opacities | ||
self.splat_scales: np.ndarray = splat_scales | ||
self.splat_rotations: np.ndarray = splat_rotations | ||
|
||
self.splat_opacity_scale = 1.0 | ||
self.splat_size_scale = 1.0 | ||
|
||
self.backface_culling = False | ||
|
||
self._debug_gui = False | ||
|
||
def is_transparent(self): | ||
return True | ||
|
||
@property | ||
def bounds(self): | ||
return self.current_bounds | ||
|
||
@property | ||
def current_bounds(self): | ||
return self.get_bounds(self.splat_positions) | ||
|
||
@Node.once | ||
def make_renderable(self, ctx: moderngl.Context): | ||
self.prog_prepare = get_gaussian_splat_prepare_program(self.PREPARE_GROUP_SIZE) | ||
self.prog_draw = get_gaussian_splat_draw_program() | ||
|
||
# Buffer for splat positions. | ||
self.splat_positions_buf = ctx.buffer(self.splat_positions.astype(np.float32).tobytes()) | ||
|
||
# Buffer for other splat data: SHs, opacity, scale, rotation. | ||
# | ||
# TODO: In theory we could pre-process rotations and scales and store them | ||
# as a 6 element covariance matrix directly here. | ||
# | ||
# TODO: Currently this only renders with base colors (first spherical harmonics coefficient) | ||
# We need to unswizzle the other coefficients and evaluate them for rendering here. | ||
splat_data = np.hstack( | ||
( | ||
self.splat_shs[:, :3] * 0.2820948 + 0.5, | ||
self.splat_opacities[:, np.newaxis], | ||
self.splat_scales, | ||
np.zeros((self.num_splats, 1), np.float32), | ||
self.splat_rotations, | ||
) | ||
) | ||
self.splat_data_buf = ctx.buffer(splat_data.astype(np.float32).tobytes()) | ||
|
||
# Buffer for splat views. | ||
self.splat_views_buf = ctx.buffer(None, reserve=self.num_splats * 48) | ||
|
||
# Buffers for distances and sorted indices. | ||
self.splat_distances_buf = ctx.buffer(None, reserve=self.num_splats * 4) | ||
self.splat_sorted_indices_buf = ctx.buffer(None, reserve=self.num_splats * 4) | ||
|
||
# Create a vao for rendering a single quad. | ||
indices = np.array((0, 1, 2, 1, 3, 2), np.uint32) | ||
self.vbo_indices = ctx.buffer(indices.tobytes()) | ||
self.vao = VAO() | ||
self.vao.index_buffer(self.vbo_indices) | ||
|
||
self.gpu_sort = GpuSort(ctx, self.num_splats) | ||
|
||
# Time queries for profiling. | ||
self.time_queries = { | ||
"prepare": ctx.query(time=True), | ||
"sort": ctx.query(time=True), | ||
"draw": ctx.query(time=True), | ||
} | ||
self.ctx = ctx | ||
|
||
def render(self, camera: CameraInterface, **kwargs): | ||
# Convert gaussians from 3D to 2D quads. | ||
with self.time_queries["prepare"]: | ||
self.splat_positions_buf.bind_to_storage_buffer(0) | ||
self.splat_data_buf.bind_to_storage_buffer(1) | ||
self.splat_views_buf.bind_to_storage_buffer(2) | ||
self.splat_distances_buf.bind_to_storage_buffer(3) | ||
self.splat_sorted_indices_buf.bind_to_storage_buffer(4) | ||
|
||
self.prog_prepare["u_opacity_scale"] = self.splat_opacity_scale | ||
self.prog_prepare["u_scale2"] = np.square(self.splat_size_scale) | ||
|
||
self.prog_prepare["u_num_splats"] = self.num_splats | ||
V = camera.get_view_matrix() | ||
P = camera.get_projection_matrix() | ||
self.prog_prepare["u_limit"] = 1.3 / P[0, 0] | ||
self.prog_prepare["u_focal"] = kwargs["window_size"][0] * P[0, 0] * 0.5 | ||
|
||
self.prog_prepare["u_world_from_object"].write(self.model_matrix.T.astype("f4").tobytes()) | ||
self.prog_prepare["u_view_from_world"].write(V.T.astype("f4").tobytes()) | ||
self.prog_prepare["u_clip_from_world"].write((P @ V).astype("f4").T.tobytes()) | ||
|
||
num_groups = (self.num_splats + self.PREPARE_GROUP_SIZE - 1) // self.PREPARE_GROUP_SIZE | ||
self.prog_prepare.run(num_groups, 1, 1) | ||
self.ctx.memory_barrier() | ||
|
||
# Sort splats based on distance to camera plane. | ||
with self.time_queries["sort"]: | ||
self.gpu_sort.run(self.ctx, self.splat_distances_buf, self.splat_sorted_indices_buf) | ||
|
||
# Render each splat as a 2D quad with instancing. | ||
with self.time_queries["draw"]: | ||
self.splat_views_buf.bind_to_storage_buffer(2) | ||
self.splat_sorted_indices_buf.bind_to_storage_buffer(3) | ||
|
||
self.prog_draw["u_screen_size"].write(np.array(kwargs["window_size"], np.float32).tobytes()) | ||
|
||
kwargs["fbo"].depth_mask = False | ||
self.vao.render(self.prog_draw, moderngl.TRIANGLES, 6, 0, self.num_splats) | ||
kwargs["fbo"].depth_mask = True | ||
|
||
def gui(self, imgui): | ||
if self._debug_gui: | ||
# Draw debugging stats about marching cubes and rendering time. | ||
total = 0 | ||
for k, v in self.time_queries.items(): | ||
imgui.text(f"{k}: {v.elapsed * 1e-6:5.3f}ms") | ||
total += v.elapsed * 1e-6 | ||
imgui.text(f"Total: {total: 5.3f}ms") | ||
|
||
_, self.splat_size_scale = imgui.drag_float( | ||
"Size", | ||
self.splat_size_scale, | ||
1e-2, | ||
min_value=0.001, | ||
max_value=10.0, | ||
format="%.3f", | ||
) | ||
|
||
_, self.splat_opacity_scale = imgui.drag_float( | ||
"Opacity", | ||
self.splat_opacity_scale, | ||
1e-2, | ||
min_value=0.001, | ||
max_value=10.0, | ||
format="%.3f", | ||
) | ||
|
||
super().gui(imgui) | ||
|
||
@classmethod | ||
def from_ply(cls, path, sh_degree=3, **kwargs): | ||
with open(path, "rb") as f: | ||
# Read header. | ||
head = f.readline().decode("utf-8").strip().lower() | ||
if head != "ply": | ||
print(head) | ||
raise ValueError(f"Not a ply file: {head}") | ||
|
||
encoding = f.readline().decode("utf-8").strip().lower() | ||
if "binary_little_endian" not in encoding: | ||
raise ValueError(f"Invalid encoding: {encoding}") | ||
|
||
elements = f.readline().decode("utf-8").strip().lower() | ||
count = int(elements.split()[2]) | ||
|
||
# Read until end of header. | ||
while f.readline().decode("utf-8").strip().lower() != "end_header": | ||
pass | ||
|
||
# Number of 32 bit floats used to encode Spherical Harmonics coefficients. | ||
# The last multiplication by 3 is because we have 3 components (RGB) for each coefficient. | ||
sh_coeffs = (sh_degree + 1) * (sh_degree + 1) * 3 | ||
|
||
# Position (vec3), normal (vec3), spherical harmonics (sh_coeffs), opacity (float), | ||
# scale (vec3) and rotation (quaternion). All values are float32 (4 bytes). | ||
size = count * (3 + 3 + sh_coeffs + 1 + 3 + 4) * 4 | ||
|
||
data = f.read(size) | ||
arr = np.frombuffer(data, dtype=np.float32).reshape((count, -1)) | ||
|
||
# Positions. | ||
position = arr[:, :3].copy() | ||
|
||
# Currently we don't need normals for rendering. | ||
# normal = arr[:, 3:6].copy() | ||
|
||
# Spherical harmonic coefficients. | ||
sh = arr[:, 6 : 6 + sh_coeffs].copy() | ||
|
||
# Activate alpha: sigmoid(alpha). | ||
opacity = 1.0 / (1.0 + np.exp(-arr[:, 6 + sh_coeffs])) | ||
|
||
# Exponentiate scale. | ||
scale = np.exp(arr[:, 7 + sh_coeffs : 10 + sh_coeffs]) | ||
|
||
# Normalize quaternions. | ||
rotation = arr[:, 10 + sh_coeffs : 14 + sh_coeffs].copy() | ||
rotation /= np.linalg.norm(rotation, ord=2, axis=1)[..., np.newaxis] | ||
|
||
# Convert from wxyz to xyzw. | ||
rotation = np.roll(rotation, -1, axis=1) | ||
|
||
return cls(position, sh, opacity, scale, rotation, **kwargs) | ||
|
||
@hooked | ||
def release(self): | ||
# vao and vbos are released by Meshes release method. | ||
if self.is_renderable: | ||
self.splat_positions_buf.release() | ||
self.splat_data_buf.release() | ||
self.splat_views_buf.release() | ||
self.splat_distances_buf.release() | ||
self.splat_sorted_indices_buf.release() | ||
self.vao.release() | ||
self.gpu_sort.release() | ||
self.splat_positions = None | ||
self.splat_shs = None | ||
self.splat_opacities = None | ||
self.splat_scales = None | ||
self.splat_rotations = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos | ||
|
||
struct Splat { | ||
vec3 position; | ||
vec3 color; | ||
float opacity; | ||
vec3 scale; | ||
vec4 rotation; | ||
}; | ||
|
||
struct SplatData { | ||
vec3 color; | ||
float opacity; | ||
vec3 scale; | ||
float _padding; | ||
vec4 rotation; | ||
}; | ||
|
||
struct SplatView { | ||
vec4 position; | ||
vec2 axis1; | ||
vec2 axis2; | ||
vec4 color; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#version 450 | ||
|
||
// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos | ||
|
||
#include gaussian_splatting/common.glsl | ||
|
||
#if defined VERTEX_SHADER | ||
|
||
layout(std430, binding=2) buffer in_splat_views | ||
{ | ||
SplatView views[]; | ||
} InSplatViews; | ||
|
||
layout(std430, binding=3) buffer in_splat_sorted_indices | ||
{ | ||
uint indices[]; | ||
} InSplatSortedIndices; | ||
|
||
uniform vec2 u_screen_size; | ||
|
||
out vec2 position; | ||
out vec4 color; | ||
|
||
void main() { | ||
uint splat_index = gl_InstanceID; | ||
uint vertex_index = gl_VertexID; | ||
|
||
uint sorted_index = InSplatSortedIndices.indices[splat_index]; | ||
SplatView view = InSplatViews.views[sorted_index]; | ||
if(view.position.w <= 0.0) { | ||
gl_Position.x = uintBitsToFloat(0x7fc00000); // NaN discards the primitive | ||
} else { | ||
vec2 quad = vec2(vertex_index & 1, (vertex_index >> 1) & 1) * 2.0 - 1.0; | ||
quad *= 2.0; | ||
|
||
vec2 delta = (quad.x * view.axis1 + quad.y * view.axis2) * 2.0 / u_screen_size; | ||
vec4 p = view.position; | ||
p.xy += delta * view.position.w; | ||
|
||
color = view.color; | ||
position = quad; | ||
gl_Position = p; | ||
} | ||
} | ||
|
||
#elif defined FRAGMENT_SHADER | ||
|
||
in vec2 position; | ||
in vec4 color; | ||
|
||
out vec4 out_color; | ||
|
||
void main() { | ||
float alpha = clamp(exp(-dot(position, position)) * color.a, 0.0, 1.0); | ||
|
||
if(alpha < 1.0/255.0) { | ||
discard; | ||
} | ||
out_color = vec4(color.rgb, alpha); | ||
} | ||
|
||
|
||
#endif |
Oops, something went wrong.