-
Notifications
You must be signed in to change notification settings - Fork 3
/
GaussianSplattingCUDAExt.jl
73 lines (59 loc) · 2.38 KB
/
GaussianSplattingCUDAExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
module GaussianSplattingCUDAExt
# using Adapt
using CUDA
# using cuDNN
# using KernelAbstractions
using GaussianSplatting
# using PrecompileTools
# using Statistics
# using Zygote
"""
    GaussianSplatting.allocate_pinned(::CUDABackend, ::Type{T}, shape) where T

Allocate an uninitialized host `Array{T}` of the given `shape`, register its
memory with CUDA, and return it together with a `CuArray` that wraps the same
buffer.

# Arguments
- `::CUDABackend`: backend selector; used for dispatch only.
- `::Type{T}`: element type of both returned arrays.
- `shape`: dimensions forwarded to `Array{T}(undef, shape)`.

# Returns
A tuple `(x, xd)` where `x` is the host `Array{T}` and `xd` is a `CuArray` with
the same element type and size, created with `unsafe_wrap` over the registered
memory.

# Notes
- `xd` is created with `unsafe_wrap`, so it aliases `x` rather than owning a
  copy. NOTE(review): callers presumably must keep `x` alive while `xd` is in
  use and release the registration via `GaussianSplatting.unpin_memory(xd)` —
  confirm against call sites.
- The contents of `x` (and therefore `xd`) are uninitialized.
"""
function GaussianSplatting.allocate_pinned(::CUDABackend, ::Type{T}, shape) where T
    x = Array{T}(undef, shape)
    # Register the host allocation with the CUDA driver; the DEVICEMAP flag
    # requests that the registered range also be addressable from the device.
    buf = CUDA.register(
        CUDA.HostMemory, pointer(x), sizeof(x), CUDA.MEMHOSTREGISTER_DEVICEMAP)
    # The device pointer must carry the element type `T`. Hard-coding `Float32`
    # here would give the wrapped array the wrong eltype and byte extent for
    # any other `T`.
    xptr = convert(CuPtr{T}, buf)
    xd = unsafe_wrap(CuArray, xptr, size(x))
    return x, xd
end
"""
    GaussianSplatting.unpin_memory(x::CuArray)

Unregister the memory object backing `x` by passing it to `CUDA.unregister`.
Returns `nothing`.

NOTE(review): the memory handle is reached through the field chain
`x.data.rc.obj.mem`, which looks like `CuArray` internals rather than public
API — confirm it still holds when upgrading CUDA.jl. Presumably `x` is the
device-side array returned by `GaussianSplatting.allocate_pinned`; verify
against callers.
"""
function GaussianSplatting.unpin_memory(x::CuArray)
    # Memory object referenced by the array's underlying data handle.
    registered_mem = x.data.rc.obj.mem
    CUDA.unregister(registered_mem)
    return nothing
end
# @setup_workload let
# kab = GaussianSplatting.gpu_backend()
# # TODO KernelAbstractions.functional(kab)
# (kab isa CUDABackend && CUDA.functional()) || return
# @info "Precompiling for `$kab` GPU backend."
# points = adapt(kab, rand(Float32, 3, 128))
# colors = adapt(kab, rand(Float32, 3, 128))
# scales = adapt(kab, rand(Float32, 3, 128))
# camera = GaussianSplatting.Camera(; fx=100f0, fy=100f0, width=256, height=256)
# opt_params = GaussianSplatting.OptimizationParams()
# gaussians = GaussianSplatting.GaussianModel(points, colors, scales; max_sh_degree=0)
# rasterizer = GaussianSplatting.GaussianRasterizer(kab, camera; auxiliary=false)
# ssim = GaussianSplatting.SSIM(kab)
# θ = (
# gaussians.points, gaussians.features_dc, gaussians.features_rest,
# gaussians.opacities, gaussians.scales, gaussians.rotations)
# target_image = adapt(kab, rand(Float32, 256, 256, 3, 1))
# @compile_workload begin
# Zygote.gradient(
# θ...,
# ) do means_3d, features_dc, features_rest, opacities, scales, rotations
# shs = isempty(features_rest) ?
# features_dc : hcat(features_dc, features_rest)
# img = rasterizer(
# means_3d, opacities, scales, rotations, shs;
# camera, sh_degree=gaussians.sh_degree)
# # From (c, w, h) to (w, h, c, 1) for SSIM.
# img_tmp = permutedims(img, (2, 3, 1))
# img_eval = reshape(img_tmp, size(img_tmp)..., 1)
# l1 = mean(abs.(img_eval .- target_image))
# s = 1f0 - ssim(img_eval, target_image)
# (1f0 - opt_params.λ_dssim) * l1 + opt_params.λ_dssim * s
# end
# end
# @info "Done precompiling!"
# return
# end
end