Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR to integrate Gaussian Splatting #50

Merged
merged 11 commits into from
Mar 19, 2024
10 changes: 5 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ jobs:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: psf/black@stable
- uses: actions/checkout@v4
- uses: psf/black@23.3.0
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
- uses: isort/isort-action@master
- uses: isort/isort-action@master
235 changes: 235 additions & 0 deletions aitviewer/renderables/gaussian_splats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import moderngl
import numpy as np
from moderngl_window.opengl.vao import VAO

from aitviewer.scene.camera import CameraInterface
from aitviewer.scene.node import Node
from aitviewer.shaders import (
get_gaussian_splat_draw_program,
get_gaussian_splat_prepare_program,
)
from aitviewer.utils.decorators import hooked
from aitviewer.utils.gpu_sort import GpuSort


class GaussianSplats(Node):
PREPARE_GROUP_SIZE = 128

def __init__(self, splat_positions, splat_shs, splat_opacities, splat_scales, splat_rotations, **kwargs):
super().__init__(**kwargs)

self.num_splats = splat_positions.shape[0]
self.splat_positions: np.ndarray = splat_positions
self.splat_shs: np.ndarray = splat_shs
self.splat_opacities: np.ndarray = splat_opacities
self.splat_scales: np.ndarray = splat_scales
self.splat_rotations: np.ndarray = splat_rotations

self.splat_opacity_scale = 1.0
self.splat_size_scale = 1.0

self.backface_culling = False

self._debug_gui = False

def is_transparent(self):
return True

@property
def bounds(self):
return self.current_bounds

@property
def current_bounds(self):
return self.get_bounds(self.splat_positions)

@Node.once
def make_renderable(self, ctx: moderngl.Context):
self.prog_prepare = get_gaussian_splat_prepare_program(self.PREPARE_GROUP_SIZE)
self.prog_draw = get_gaussian_splat_draw_program()

# Buffer for splat positions.
self.splat_positions_buf = ctx.buffer(self.splat_positions.astype(np.float32).tobytes())

# Buffer for other splat data: SHs, opacity, scale, rotation.
#
# TODO: In theory we could pre-process rotations and scales and store them
# as a 6 element covariance matrix directly here.
#
# TODO: Currently this only renders with base colors (first spherical harmonics coefficient)
# We need to unswizzle the other coefficients and evaluate them for rendering here.
splat_data = np.hstack(
(
self.splat_shs[:, :3] * 0.2820948 + 0.5,
self.splat_opacities[:, np.newaxis],
self.splat_scales,
np.zeros((self.num_splats, 1), np.float32),
self.splat_rotations,
)
)
self.splat_data_buf = ctx.buffer(splat_data.astype(np.float32).tobytes())

# Buffer for splat views.
self.splat_views_buf = ctx.buffer(None, reserve=self.num_splats * 48)

# Buffers for distances and sorted indices.
self.splat_distances_buf = ctx.buffer(None, reserve=self.num_splats * 4)
self.splat_sorted_indices_buf = ctx.buffer(None, reserve=self.num_splats * 4)

# Create a vao for rendering a single quad.
indices = np.array((0, 1, 2, 1, 3, 2), np.uint32)
self.vbo_indices = ctx.buffer(indices.tobytes())
self.vao = VAO()
self.vao.index_buffer(self.vbo_indices)

self.gpu_sort = GpuSort(ctx, self.num_splats)

# Time queries for profiling.
self.time_queries = {
"prepare": ctx.query(time=True),
"sort": ctx.query(time=True),
"draw": ctx.query(time=True),
}
self.ctx = ctx

def render(self, camera: CameraInterface, **kwargs):
# Convert gaussians from 3D to 2D quads.
with self.time_queries["prepare"]:
self.splat_positions_buf.bind_to_storage_buffer(0)
self.splat_data_buf.bind_to_storage_buffer(1)
self.splat_views_buf.bind_to_storage_buffer(2)
self.splat_distances_buf.bind_to_storage_buffer(3)
self.splat_sorted_indices_buf.bind_to_storage_buffer(4)

self.prog_prepare["u_opacity_scale"] = self.splat_opacity_scale
self.prog_prepare["u_scale2"] = np.square(self.splat_size_scale)

self.prog_prepare["u_num_splats"] = self.num_splats
V = camera.get_view_matrix()
P = camera.get_projection_matrix()
self.prog_prepare["u_limit"] = 1.3 / P[0, 0]
self.prog_prepare["u_focal"] = kwargs["window_size"][0] * P[0, 0] * 0.5

self.prog_prepare["u_world_from_object"].write(self.model_matrix.T.astype("f4").tobytes())
self.prog_prepare["u_view_from_world"].write(V.T.astype("f4").tobytes())
self.prog_prepare["u_clip_from_world"].write((P @ V).astype("f4").T.tobytes())

num_groups = (self.num_splats + self.PREPARE_GROUP_SIZE - 1) // self.PREPARE_GROUP_SIZE
self.prog_prepare.run(num_groups, 1, 1)
self.ctx.memory_barrier()

# Sort splats based on distance to camera plane.
with self.time_queries["sort"]:
self.gpu_sort.run(self.ctx, self.splat_distances_buf, self.splat_sorted_indices_buf)

# Render each splat as a 2D quad with instancing.
with self.time_queries["draw"]:
self.splat_views_buf.bind_to_storage_buffer(2)
self.splat_sorted_indices_buf.bind_to_storage_buffer(3)

self.prog_draw["u_screen_size"].write(np.array(kwargs["window_size"], np.float32).tobytes())

kwargs["fbo"].depth_mask = False
self.vao.render(self.prog_draw, moderngl.TRIANGLES, 6, 0, self.num_splats)
kwargs["fbo"].depth_mask = True

def gui(self, imgui):
if self._debug_gui:
# Draw debugging stats about marching cubes and rendering time.
total = 0
for k, v in self.time_queries.items():
imgui.text(f"{k}: {v.elapsed * 1e-6:5.3f}ms")
total += v.elapsed * 1e-6
imgui.text(f"Total: {total: 5.3f}ms")

_, self.splat_size_scale = imgui.drag_float(
"Size",
self.splat_size_scale,
1e-2,
min_value=0.001,
max_value=10.0,
format="%.3f",
)

_, self.splat_opacity_scale = imgui.drag_float(
"Opacity",
self.splat_opacity_scale,
1e-2,
min_value=0.001,
max_value=10.0,
format="%.3f",
)

super().gui(imgui)

@classmethod
def from_ply(cls, path, sh_degree=3, **kwargs):
with open(path, "rb") as f:
# Read header.
head = f.readline().decode("utf-8").strip().lower()
if head != "ply":
print(head)
raise ValueError(f"Not a ply file: {head}")

encoding = f.readline().decode("utf-8").strip().lower()
if "binary_little_endian" not in encoding:
raise ValueError(f"Invalid encoding: {encoding}")

elements = f.readline().decode("utf-8").strip().lower()
count = int(elements.split()[2])

# Read until end of header.
while f.readline().decode("utf-8").strip().lower() != "end_header":
pass

# Number of 32 bit floats used to encode Spherical Harmonics coefficients.
# The last multiplication by 3 is because we have 3 components (RGB) for each coefficient.
sh_coeffs = (sh_degree + 1) * (sh_degree + 1) * 3

# Position (vec3), normal (vec3), spherical harmonics (sh_coeffs), opacity (float),
# scale (vec3) and rotation (quaternion). All values are float32 (4 bytes).
size = count * (3 + 3 + sh_coeffs + 1 + 3 + 4) * 4

data = f.read(size)
arr = np.frombuffer(data, dtype=np.float32).reshape((count, -1))

# Positions.
position = arr[:, :3].copy()

# Currently we don't need normals for rendering.
# normal = arr[:, 3:6].copy()

# Spherical harmonic coefficients.
sh = arr[:, 6 : 6 + sh_coeffs].copy()

# Activate alpha: sigmoid(alpha).
opacity = 1.0 / (1.0 + np.exp(-arr[:, 6 + sh_coeffs]))

# Exponentiate scale.
scale = np.exp(arr[:, 7 + sh_coeffs : 10 + sh_coeffs])

# Normalize quaternions.
rotation = arr[:, 10 + sh_coeffs : 14 + sh_coeffs].copy()
rotation /= np.linalg.norm(rotation, ord=2, axis=1)[..., np.newaxis]

# Convert from wxyz to xyzw.
rotation = np.roll(rotation, -1, axis=1)

return cls(position, sh, opacity, scale, rotation, **kwargs)

@hooked
def release(self):
# vao and vbos are released by Meshes release method.
if self.is_renderable:
self.splat_positions_buf.release()
self.splat_data_buf.release()
self.splat_views_buf.release()
self.splat_distances_buf.release()
self.splat_sorted_indices_buf.release()
self.vao.release()
self.gpu_sort.release()
self.splat_positions = None
self.splat_shs = None
self.splat_opacities = None
self.splat_scales = None
self.splat_rotations = None
13 changes: 13 additions & 0 deletions aitviewer/renderables/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,19 @@ def gui(self, imgui):
if imgui.button("Show Normals ##show_normals{}".format(self.unique_name)):
self._show_normals()

@hooked
def gui_stats(self, imgui):
for k, v in {
"Vertices": self.vertices.shape[1],
"Faces": self.faces.shape[0],
"Vertex colors": "yes" if self._vertex_colors is not None else "no",
"UVs": "yes" if self.uv_coords is not None else "no",
"Face normals": "yes" if self._face_normals is not None else "no",
"Face colors": "yes" if self._face_colors is not None else "no",
"Texture": "yes" if self.has_texture else "no",
}.items():
imgui.text(f"{k}: {v}")

def gui_context_menu(self, imgui, x: int, y: int):
_, self.flat_shading = imgui.menu_item("Flat shading", "F", selected=self.flat_shading, enabled=True)
_, self.draw_edges = imgui.menu_item("Draw edges", "E", selected=self.draw_edges, enabled=True)
Expand Down
60 changes: 36 additions & 24 deletions aitviewer/scene/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,36 +327,39 @@ def current_frame_id(self):
def current_frame_id(self, frame_id):
# Check if the frame changed.
last_frame_id = self._current_frame_id if self._enabled_frames is None else self._internal_frame_id
if self.n_frames == 1 or frame_id == last_frame_id:
return

self.on_before_frame_update()
if self._enabled_frames is None:
if frame_id < 0:
self._current_frame_id = 0
elif frame_id >= len(self):
self._current_frame_id = len(self) - 1
updated = self.n_frames != 1 and frame_id != last_frame_id
if updated:
self.on_before_frame_update()
if self._enabled_frames is None:
if frame_id < 0:
self._current_frame_id = 0
elif frame_id >= len(self):
self._current_frame_id = len(self) - 1
else:
self._current_frame_id = frame_id
else:
self._current_frame_id = frame_id
else:
# If an enabled_frames is present use it to get the current frame.
if frame_id < 0:
self._internal_frame_id = 0
elif frame_id >= self._enabled_frames.shape[0]:
self._internal_frame_id = self._enabled_frames.shape[0] - 1
else:
self._internal_frame_id = frame_id
self._current_frame_id = self._enabled_frame_id[self._internal_frame_id]
# Update enabled using the mask.
self.enabled = self._enabled_frames[self._internal_frame_id]
# If an enabled_frames is present use it to get the current frame.
if frame_id < 0:
self._internal_frame_id = 0
elif frame_id >= self._enabled_frames.shape[0]:
self._internal_frame_id = self._enabled_frames.shape[0] - 1
else:
self._internal_frame_id = frame_id
self._current_frame_id = self._enabled_frame_id[self._internal_frame_id]
# Update enabled using the mask.
self.enabled = self._enabled_frames[self._internal_frame_id]

# Update frame id of all children nodes.
for n in self.nodes:
n.current_frame_id = self._current_frame_id
n.current_frame_id = frame_id

self.on_frame_update()
if self.parent and (self._positions.shape[0] > 1 or self._rotations.shape[0] > 1 or self._scales.shape[0] > 1):
self.update_transform(self.parent.model_matrix)
if updated:
self.on_frame_update()
if self.parent and (
self._positions.shape[0] > 1 or self._rotations.shape[0] > 1 or self._scales.shape[0] > 1
):
self.update_transform(self.parent.model_matrix)

def next_frame(self):
self.current_frame_id = self.current_frame_id + 1 if self.current_frame_id < len(self) - 1 else 0
Expand Down Expand Up @@ -440,6 +443,15 @@ def gui(self, imgui):
"""
pass

def gui_stats(self, imgui):
"""
Render GUI for stats about the node, rendered at the bottom of the scene hierarchy.
Should be implemented with @hooked decorator to print stats from parent classes.
:param imgui: imgui context.
See https://pyimgui.readthedocs.io/en/latest/reference/imgui.core.html for available elements to render
"""
imgui.text(f"Frames: {self.n_frames}")

def gui_modes(self, imgui):
"""Render GUI with toolbar (tools) for this particular node"""

Expand Down
9 changes: 9 additions & 0 deletions aitviewer/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ def gui_selected(self, imgui):
imgui.spacing()
imgui.spacing()

imgui.spacing()
imgui.spacing()
imgui.spacing()
imgui.same_line(spacing=8)
imgui.begin_group()
if imgui.collapsing_header("Stats")[0]:
s.gui_stats(imgui)
imgui.end_group()

def gui(self, imgui):
imgui.text(f"FPS: {self.fps:.1f}")
# Background color
Expand Down
Loading
Loading