
Commit 07a5a68

gkioxari authored and facebook-github-bot committed
refactor laplacian matrices
Summary: Refactor of all functions to compute laplacian matrices in one file. Support for:
* Standard Laplacian
* Cotangent Laplacian
* Norm Laplacian

Reviewed By: nikhilaravi
Differential Revision: D29297466
fbshipit-source-id: b96b88915ce8ef0c2f5693ec9b179fd27b70abf9
1 parent da9974b commit 07a5a68

8 files changed (+298 -198 lines)
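As a quick orientation (a minimal sketch, not part of the diff below): after this commit all three Laplacian constructors share one import surface, and the touched call sites reduce to the calls listed in the comments.

from pytorch3d.ops import laplacian, cot_laplacian, norm_laplacian

# Call sites updated in this commit (see the per-file diffs below):
#   pytorch3d/loss/mesh_laplacian_smoothing.py -> cot_laplacian(verts_packed, faces_packed)
#   pytorch3d/ops/mesh_filtering.py            -> norm_laplacian(verts, edges)
#   pytorch3d/structures/meshes.py             -> laplacian(verts_packed, edges_packed)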

pytorch3d/loss/mesh_laplacian_smoothing.py

+3 -71
@@ -6,6 +6,7 @@


 import torch
+from pytorch3d.ops import cot_laplacian


 def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
@@ -94,6 +95,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):

     N = len(meshes)
     verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
+    faces_packed = meshes.faces_packed()  # (sum(F_n), 3)
     num_verts_per_mesh = meshes.num_verts_per_mesh()  # (N,)
     verts_packed_idx = meshes.verts_packed_to_mesh_idx()  # (sum(V_n),)
     weights = num_verts_per_mesh.gather(0, verts_packed_idx)  # (sum(V_n),)
@@ -106,7 +108,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
         if method == "uniform":
             L = meshes.laplacian_packed()
         elif method in ["cot", "cotcurv"]:
-            L, inv_areas = laplacian_cot(meshes)
+            L, inv_areas = cot_laplacian(verts_packed, faces_packed)
             if method == "cot":
                 norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
                 idx = norm_w > 0
@@ -127,73 +129,3 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):

     loss = loss * weights
     return loss.sum() / N
-
-
-def laplacian_cot(meshes):
-    """
-    Returns the Laplacian matrix with cotangent weights and the inverse of the
-    face areas.
-
-    Args:
-        meshes: Meshes object with a batch of meshes.
-    Returns:
-        2-element tuple containing
-        - **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
-          Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
-          See the description above for more clarity.
-        - **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
-          face areas containing each vertex
-    """
-    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
-    faces_packed = meshes.faces_packed()  # (sum(F_n), 3)
-    # V = sum(V_n), F = sum(F_n)
-    V, F = verts_packed.shape[0], faces_packed.shape[0]
-
-    face_verts = verts_packed[faces_packed]
-    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
-
-    # Side lengths of each triangle, of shape (sum(F_n),)
-    # A is the side opposite v1, B is opposite v2, and C is opposite v3
-    A = (v1 - v2).norm(dim=1)
-    B = (v0 - v2).norm(dim=1)
-    C = (v0 - v1).norm(dim=1)
-
-    # Area of each triangle (with Heron's formula); shape is (sum(F_n),)
-    s = 0.5 * (A + B + C)
-    # note that the area can be negative (close to 0) causing nans after sqrt()
-    # we clip it to a small positive value
-    area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
-
-    # Compute cotangents of angles, of shape (sum(F_n), 3)
-    A2, B2, C2 = A * A, B * B, C * C
-    cota = (B2 + C2 - A2) / area
-    cotb = (A2 + C2 - B2) / area
-    cotc = (A2 + B2 - C2) / area
-    cot = torch.stack([cota, cotb, cotc], dim=1)
-    cot /= 4.0
-
-    # Construct a sparse matrix by basically doing:
-    # L[v1, v2] = cota
-    # L[v2, v0] = cotb
-    # L[v0, v1] = cotc
-    ii = faces_packed[:, [1, 2, 0]]
-    jj = faces_packed[:, [2, 0, 1]]
-    idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
-    L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
-
-    # Make it symmetric; this means we are also setting
-    # L[v2, v1] = cota
-    # L[v0, v2] = cotb
-    # L[v1, v0] = cotc
-    L += L.t()
-
-    # For each vertex, compute the sum of areas for triangles containing it.
-    idx = faces_packed.view(-1)
-    inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
-    val = torch.stack([area] * 3, dim=1).view(-1)
-    inv_areas.scatter_add_(0, idx, val)
-    idx = inv_areas > 0
-    inv_areas[idx] = 1.0 / inv_areas[idx]
-    inv_areas = inv_areas.view(-1, 1)
-
-    return L, inv_areas
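For reference, a minimal usage sketch of the public loss after this change; ico_sphere is assumed here only as a convenient test mesh, and the call signature is unchanged by the refactor.

from pytorch3d.loss import mesh_laplacian_smoothing
from pytorch3d.utils import ico_sphere

meshes = ico_sphere(level=1)  # small test Meshes batch; any Meshes object works
loss_uniform = mesh_laplacian_smoothing(meshes, method="uniform")  # uses meshes.laplacian_packed()
loss_cot = mesh_laplacian_smoothing(meshes, method="cot")          # now routed through ops.cot_laplacian
loss_cotcurv = mesh_laplacian_smoothing(meshes, method="cotcurv")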

pytorch3d/ops/__init__.py

+1
@@ -9,6 +9,7 @@
 from .graph_conv import GraphConv
 from .interp_face_attrs import interpolate_face_attributes
 from .knn import knn_gather, knn_points
+from .laplacian_matrices import laplacian, cot_laplacian, norm_laplacian
 from .mesh_face_areas_normals import mesh_face_areas_normals
 from .mesh_filtering import taubin_smoothing
 from .packed_to_padded import packed_to_padded, padded_to_packed

pytorch3d/ops/laplacian_matrices.py

+170
@@ -0,0 +1,170 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+# ------------------------ Laplacian Matrices ------------------------ #
+# This file contains implementations of differentiable laplacian matrices.
+# These include
+# 1) Standard Laplacian matrix
+# 2) Cotangent Laplacian matrix
+# 3) Norm Laplacian matrix
+# -------------------------------------------------------------------- #
+
+
+def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the laplacian matrix.
+    The definition of the laplacian is
+    L[i, j] =    -1       , if i == j
+    L[i, j] = 1 / deg(i)  , if (i, j) is an edge
+    L[i, j] =    0        , otherwise
+    where deg(i) is the degree of the i-th vertex in the graph.
+
+    Args:
+        verts: tensor of shape (V, 3) containing the vertices of the graph
+        edges: tensor of shape (E, 2) containing the vertex indices of each edge
+    Returns:
+        L: Sparse FloatTensor of shape (V, V)
+    """
+    V = verts.shape[0]
+
+    e0, e1 = edges.unbind(1)
+
+    idx01 = torch.stack([e0, e1], dim=1)  # (E, 2)
+    idx10 = torch.stack([e1, e0], dim=1)  # (E, 2)
+    idx = torch.cat([idx01, idx10], dim=0).t()  # (2, 2*E)
+
+    # First, we construct the adjacency matrix,
+    # i.e. A[i, j] = 1 if (i,j) is an edge, or
+    # A[e0, e1] = 1 & A[e1, e0] = 1
+    ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
+    A = torch.sparse.FloatTensor(idx, ones, (V, V))
+
+    # the sum of i-th row of A gives the degree of the i-th vertex
+    deg = torch.sparse.sum(A, dim=1).to_dense()
+
+    # We construct the Laplacian matrix by adding the non diagonal values
+    # i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
+    deg0 = deg[e0]
+    deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
+    deg1 = deg[e1]
+    deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
+    val = torch.cat([deg0, deg1])
+    L = torch.sparse.FloatTensor(idx, val, (V, V))
+
+    # Then we add the diagonal values L[i, i] = -1.
+    idx = torch.arange(V, device=verts.device)
+    idx = torch.stack([idx, idx], dim=0)
+    ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
+    L -= torch.sparse.FloatTensor(idx, ones, (V, V))
+
+    return L
+
+
+def cot_laplacian(
+    verts: torch.Tensor, faces: torch.Tensor, eps: float = 1e-12
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Returns the Laplacian matrix with cotangent weights and the inverse of the
+    face areas.
+
+    Args:
+        verts: tensor of shape (V, 3) containing the vertices of the graph
+        faces: tensor of shape (F, 3) containing the vertex indices of each face
+    Returns:
+        2-element tuple containing
+        - **L**: Sparse FloatTensor of shape (V,V) for the Laplacian matrix.
+          Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
+          See the description above for more clarity.
+        - **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
+          face areas containing each vertex
+    """
+    V, F = verts.shape[0], faces.shape[0]
+
+    face_verts = verts[faces]
+    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
+
+    # Side lengths of each triangle, of shape (sum(F_n),)
+    # A is the side opposite v1, B is opposite v2, and C is opposite v3
+    A = (v1 - v2).norm(dim=1)
+    B = (v0 - v2).norm(dim=1)
+    C = (v0 - v1).norm(dim=1)
+
+    # Area of each triangle (with Heron's formula); shape is (sum(F_n),)
+    s = 0.5 * (A + B + C)
+    # note that the area can be negative (close to 0) causing nans after sqrt()
+    # we clip it to a small positive value
+    area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=eps).sqrt()
+
+    # Compute cotangents of angles, of shape (sum(F_n), 3)
+    A2, B2, C2 = A * A, B * B, C * C
+    cota = (B2 + C2 - A2) / area
+    cotb = (A2 + C2 - B2) / area
+    cotc = (A2 + B2 - C2) / area
+    cot = torch.stack([cota, cotb, cotc], dim=1)
+    cot /= 4.0
+
+    # Construct a sparse matrix by basically doing:
+    # L[v1, v2] = cota
+    # L[v2, v0] = cotb
+    # L[v0, v1] = cotc
+    ii = faces[:, [1, 2, 0]]
+    jj = faces[:, [2, 0, 1]]
+    idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
+    L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
+
+    # Make it symmetric; this means we are also setting
+    # L[v2, v1] = cota
+    # L[v0, v2] = cotb
+    # L[v1, v0] = cotc
+    L += L.t()
+
+    # For each vertex, compute the sum of areas for triangles containing it.
+    idx = faces.view(-1)
+    inv_areas = torch.zeros(V, dtype=torch.float32, device=verts.device)
+    val = torch.stack([area] * 3, dim=1).view(-1)
+    inv_areas.scatter_add_(0, idx, val)
+    idx = inv_areas > 0
+    inv_areas[idx] = 1.0 / inv_areas[idx]
+    inv_areas = inv_areas.view(-1, 1)
+
+    return L, inv_areas
+
+
+def norm_laplacian(
+    verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12
+) -> torch.Tensor:
+    """
+    Norm laplacian computes a variant of the laplacian matrix which weights each
+    affinity with the normalized distance of the neighboring nodes.
+    More concretely,
+    L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
+
+    Args:
+        verts: tensor of shape (V, 3) containing the vertices of the graph
+        edges: tensor of shape (E, 2) containing the vertex indices of each edge
+    Returns:
+        L: Sparse FloatTensor of shape (V, V)
+    """
+    edge_verts = verts[edges]  # (E, 2, 3)
+    v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
+
+    # Side lengths of each edge, of shape (E,)
+    w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
+
+    # Construct a sparse matrix by basically doing:
+    # L[v0, v1] = w01
+    # L[v1, v0] = w01
+    e01 = edges.t()  # (2, E)
+
+    V = verts.shape[0]
+    L = torch.sparse.FloatTensor(e01, w01, (V, V))
+    L = L + L.t()
+
+    return L
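A minimal sketch exercising the three new functions on a single triangle (toy inputs of my own, not from the commit); each call returns a sparse (V, V) tensor as documented above.

import torch
from pytorch3d.ops import laplacian, cot_laplacian, norm_laplacian

# One right triangle: 3 vertices, 1 face, 3 undirected edges.
verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
edges = torch.tensor([[0, 1], [1, 2], [0, 2]])

L_std = laplacian(verts, edges)                 # diagonal -1, off-diagonal 1 / deg(i)
L_cot, inv_areas = cot_laplacian(verts, faces)  # cotangent weights, (V, 1) inverse areas
L_norm = norm_laplacian(verts, edges)           # weights 1 / (edge length + eps)

print(L_std.to_dense())
print(L_cot.to_dense(), inv_areas.view(-1))
print(L_norm.to_dense())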

pytorch3d/ops/mesh_filtering.py

+1 -29
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from pytorch3d.ops import norm_laplacian
 from pytorch3d.structures import Meshes, utils as struct_utils


@@ -19,35 +20,6 @@
 # ----------------------- Taubin Smoothing ----------------------- #


-def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12):
-    """
-    Norm laplacian computes a variant of the laplacian matrix which weights each
-    affinity with the normalized distance of the neighboring nodes.
-    More concretely,
-    L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
-
-    Args:
-        verts: tensor of shape (V, 3) containing the vertices of the graph
-        edges: tensor of shape (E, 2) containing the vertex indices of each edge
-    """
-    edge_verts = verts[edges]  # (E, 2, 3)
-    v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
-
-    # Side lengths of each edge, of shape (E,)
-    w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
-
-    # Construct a sparse matrix by basically doing:
-    # L[v0, v1] = w01
-    # L[v1, v0] = w01
-    e01 = edges.t()  # (2, E)
-
-    V = verts.shape[0]
-    L = torch.sparse.FloatTensor(e01, w01, (V, V))
-    L = L + L.t()
-
-    return L
-
-
 def taubin_smoothing(
     meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10
 ) -> Meshes:
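For completeness, a hedged usage sketch of taubin_smoothing, which now pulls norm_laplacian from pytorch3d.ops; ico_sphere is assumed here only as a convenient input mesh.

from pytorch3d.ops import taubin_smoothing
from pytorch3d.utils import ico_sphere

meshes = ico_sphere(level=2)
# Defaults match the signature shown above; norm_laplacian is used internally.
smoothed = taubin_smoothing(meshes, lambd=0.53, mu=-0.53, num_iter=10)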

pytorch3d/structures/meshes.py

+4 -33
@@ -1142,6 +1142,8 @@ def _compute_laplacian_packed(self, refresh: bool = False):
             Sparse FloatTensor of shape (V, V) where V = sum(V_n)

         """
+        from ..ops import laplacian
+
         if not (refresh or self._laplacian_packed is None):
             return

@@ -1153,39 +1155,8 @@ def _compute_laplacian_packed(self, refresh: bool = False):

         verts_packed = self.verts_packed()  # (sum(V_n), 3)
         edges_packed = self.edges_packed()  # (sum(E_n), 3)
-        V = verts_packed.shape[0]  # sum(V_n)
-
-        e0, e1 = edges_packed.unbind(1)
-
-        idx01 = torch.stack([e0, e1], dim=1)  # (sum(E_n), 2)
-        idx10 = torch.stack([e1, e0], dim=1)  # (sum(E_n), 2)
-        idx = torch.cat([idx01, idx10], dim=0).t()  # (2, 2*sum(E_n))
-
-        # First, we construct the adjacency matrix,
-        # i.e. A[i, j] = 1 if (i,j) is an edge, or
-        # A[e0, e1] = 1 & A[e1, e0] = 1
-        ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
-        A = torch.sparse.FloatTensor(idx, ones, (V, V))
-
-        # the sum of i-th row of A gives the degree of the i-th vertex
-        deg = torch.sparse.sum(A, dim=1).to_dense()
-
-        # We construct the Laplacian matrix by adding the non diagonal values
-        # i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
-        deg0 = deg[e0]
-        deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
-        deg1 = deg[e1]
-        deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
-        val = torch.cat([deg0, deg1])
-        L = torch.sparse.FloatTensor(idx, val, (V, V))
-
-        # Then we add the diagonal values L[i, i] = -1.
-        idx = torch.arange(V, device=self.device)
-        idx = torch.stack([idx, idx], dim=0)
-        ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
-        L -= torch.sparse.FloatTensor(idx, ones, (V, V))
-
-        self._laplacian_packed = L
+
+        self._laplacian_packed = laplacian(verts_packed, edges_packed)

     def clone(self):
         """
