-
Notifications
You must be signed in to change notification settings - Fork 341
/
Copy path__init__.py
197 lines (173 loc) · 5.98 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#
from typing import NamedTuple
import torch.nn as nn
import torch
from . import _C
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Rasterize Gaussians through the custom autograd op ``_RasterizeGaussians``.

    Functional entry point: every argument is forwarded positionally, in the
    order listed, to ``_RasterizeGaussians.apply``. That op's ``backward``
    returns one gradient per argument in this same order (``None`` for
    ``raster_settings``).

    Returns:
        Whatever ``_RasterizeGaussians.forward`` returns, i.e. the tuple
        ``(color, radii, depth)``.
    """
    # Positional packing: Function.apply is called with positional arguments
    # only, so the order here is significant.
    op_inputs = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*op_inputs)
class _RasterizeGaussians(torch.autograd.Function):
    """Custom autograd op wrapping the C++/CUDA Gaussian rasterizer.

    ``forward`` calls ``_C.rasterize_gaussians`` and ``backward`` calls
    ``_C.rasterize_gaussians_backward``. The intermediate buffers produced by
    the forward kernel (``geomBuffer``, ``binningBuffer``, ``imgBuffer``) are
    saved on the context and handed back to the backward kernel.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Rasterize the Gaussians and stash what ``backward`` needs.

        ``means2D`` is not passed to the forward kernel. Its input slot is
        where ``backward`` routes the screen-space gradient
        (``grad_means2D``).

        Returns:
            ``(color, radii, depth)`` as produced by ``_C.rasterize_gaussians``.
        """
        # Restructure arguments the way that the C++ lib expects them.
        # NOTE: this tuple is positional; its order must match the C++ binding.
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
        )

        # Invoke C++/CUDA rasterizer.
        (
            num_rendered,
            color,
            radii,
            geomBuffer,
            binningBuffer,
            imgBuffer,
            depth,
        ) = _C.rasterize_gaussians(*args)

        # Keep relevant values for backward. The settings object and
        # num_rendered are stored on ctx directly; tensors go through
        # save_for_backward.
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        )
        return color, radii, depth

    @staticmethod
    def backward(ctx, grad_out_color, grad_out_radii, grad_out_depth):
        """Back-propagate the image-colour gradient to the Gaussian inputs.

        Autograd supplies one incoming gradient per ``forward`` output
        (``color``, ``radii``, ``depth``). Only ``grad_out_color`` is passed
        to the backward kernel. ``grad_out_radii`` and ``grad_out_depth`` are
        accepted to match the number of outputs but are ignored, so a loss
        placed on the returned ``depth`` contributes no gradient through this
        op.

        Returns:
            A 9-tuple of gradients matching ``forward``'s inputs in order,
            with ``None`` for ``raster_settings``.
        """
        # grad_out_radii / grad_out_depth are intentionally unused: the C++
        # backward entry point called below takes no depth gradient.

        # Restore necessary values from context.
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        ) = ctx.saved_tensors

        # Restructure args as the C++ method expects them (positional order).
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
        )

        # Compute gradients by invoking the backward kernel. Its output order
        # differs from forward's input order, so it is reshuffled below.
        (
            grad_means2D,
            grad_colors_precomp,
            grad_opacities,
            grad_means3D,
            grad_cov3Ds_precomp,
            grad_sh,
            grad_scales,
            grad_rotations,
        ) = _C.rasterize_gaussians_backward(*args)

        # Order must mirror forward(means3D, means2D, sh, colors_precomp,
        # opacities, scales, rotations, cov3Ds_precomp, raster_settings).
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-camera settings consumed by the rasterizer.

    Every field is read by ``_RasterizeGaussians.forward`` and passed to the
    C++/CUDA kernel. ``viewmatrix`` and ``projmatrix`` are also used by
    ``GaussianRasterizer.markVisible``.
    """

    # Output image dimensions passed to the kernel.
    image_height: int
    image_width: int

    # Presumably tan(fov / 2) per axis — TODO confirm with the caller.
    tanfovx: float
    tanfovy: float

    # Background tensor handed to both forward and backward kernels
    # (presumably a colour; verify against the caller).
    bg: torch.Tensor

    # Scalar forwarded to the kernel alongside the Gaussian scales.
    scale_modifier: float

    # Camera matrices; also used for frustum culling in markVisible.
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor

    # Spherical-harmonics degree forwarded with the `sh` coefficients.
    sh_degree: int

    # Camera position tensor forwarded to the kernel.
    campos: torch.Tensor

    # Flag forwarded to the forward kernel only; semantics live in C++.
    prefiltered: bool
class GaussianRasterizer(nn.Module):
    """``nn.Module`` that rasterizes Gaussians with a fixed settings object.

    Args:
        raster_settings: settings forwarded to the rasterizer on every call.
            Expected to be a ``GaussianRasterizationSettings``; its
            ``viewmatrix`` and ``projmatrix`` are read by ``markVisible``.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Mark visible points with a boolean.

        Visibility comes from ``_C.mark_visible``, which uses the camera's
        view and projection matrices (frustum culling). Runs under
        ``torch.no_grad()``.
        """
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix,
            )
        return visible

    def forward(self, means3D, means2D, opacities, shs=None, colors_precomp=None,
                scales=None, rotations=None, cov3D_precomp=None):
        """Rasterize the given Gaussians using ``self.raster_settings``.

        Exactly one colour source (``shs`` or ``colors_precomp``) and exactly
        one covariance source (the ``scales``/``rotations`` pair, or
        ``cov3D_precomp``) must be supplied. Omitted optional inputs are
        replaced by empty tensors before calling the rasterizer.

        Returns:
            The result of ``rasterize_gaussians``: ``(color, radii, depth)``.

        Raises:
            ValueError: if the colour or covariance inputs are missing or
                over-specified. ``ValueError`` subclasses ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        raster_settings = self.raster_settings

        # Colour: true when both are missing or both are given.
        if (shs is None) == (colors_precomp is None):
            raise ValueError('Please provide exactly one of either SHs or precomputed colors!')

        # Covariance: need the full scale/rotation pair XOR a precomputed cov.
        has_scale_rot_pair = scales is not None and rotations is not None
        has_any_scale_rot = scales is not None or rotations is not None
        has_cov3D = cov3D_precomp is not None
        if (not has_scale_rot_pair and not has_cov3D) or (has_any_scale_rot and has_cov3D):
            raise ValueError('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Substitute empty tensors for absent inputs (presumably the extension
        # expects a tensor in every slot — TODO confirm against the C++ side).
        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])
        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])

        # Invoke C++/CUDA rasterization routine.
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )