# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

# ------------------------ Laplacian Matrices ------------------------ #
# This file contains implementations of differentiable laplacian matrices.
# These include
# 1) Standard Laplacian matrix
# 2) Cotangent Laplacian matrix
# 3) Norm Laplacian matrix
# -------------------------------------------------------------------- #


def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """
    Computes the laplacian matrix.
    The definition of the laplacian is
    L[i, j] =    -1       , if i == j
    L[i, j] = 1 / deg(i)  , if (i, j) is an edge
    L[i, j] =    0        , otherwise
    where deg(i) is the degree of the i-th vertex in the graph.

    Args:
        verts: tensor of shape (V, 3) containing the vertices of the graph
        edges: tensor of shape (E, 2) containing the vertex indices of each edge
    Returns:
        L: Sparse FloatTensor of shape (V, V)
    """
    V = verts.shape[0]

    e0, e1 = edges.unbind(1)

    idx01 = torch.stack([e0, e1], dim=1)  # (E, 2)
    idx10 = torch.stack([e1, e0], dim=1)  # (E, 2)
    idx = torch.cat([idx01, idx10], dim=0).t()  # (2, 2*E)

    # First, we construct the adjacency matrix A,
    # i.e. A[i, j] = 1 if (i, j) is an edge, which means
    # A[e0, e1] = 1 and A[e1, e0] = 1 for every edge (e0, e1)
    ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
    A = torch.sparse.FloatTensor(idx, ones, (V, V))

    # the sum of the i-th row of A gives the degree of the i-th vertex
    deg = torch.sparse.sum(A, dim=1).to_dense()

    # We construct the Laplacian matrix by adding the off-diagonal values,
    # i.e. L[i, j] = 1 / deg(i) if (i, j) is an edge
    deg0 = deg[e0]
    deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
    deg1 = deg[e1]
    deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
    val = torch.cat([deg0, deg1])
    L = torch.sparse.FloatTensor(idx, val, (V, V))

    # Then we add the diagonal values L[i, i] = -1.
    idx = torch.arange(V, device=verts.device)
    idx = torch.stack([idx, idx], dim=0)
    ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
    L -= torch.sparse.FloatTensor(idx, ones, (V, V))

    return L
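

# Illustrative usage sketch (not part of the original module): the helper name
# and the toy tensors below are assumptions made purely for demonstration.
# It builds the graph Laplacian of a single triangle, where every vertex has
# degree 2, so L has -1 on the diagonal and 1 / deg(i) = 0.5 at each edge entry.
def _example_laplacian_usage() -> torch.Tensor:
    verts = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    )
    edges = torch.tensor([[0, 1], [1, 2], [2, 0]])
    # The result stays sparse; inspect it with L.to_dense() if needed.
    return laplacian(verts, edges)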


def cot_laplacian(
    verts: torch.Tensor, faces: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the Laplacian matrix with cotangent weights and, for each vertex,
    the inverse of the sum of the areas of the faces containing it.

    Args:
        verts: tensor of shape (V, 3) containing the vertices of the graph
        faces: tensor of shape (F, 3) containing the vertex indices of each face
        eps: small constant used to clamp the face areas away from zero
    Returns:
        2-element tuple containing
        - **L**: Sparse FloatTensor of shape (V, V) for the Laplacian matrix.
          Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge, where
          a_ij and b_ij are the two angles opposite edge (i, j).
        - **inv_areas**: FloatTensor of shape (V, 1) containing the inverse of
          the sum of the areas of the faces containing each vertex.
    """
    V, F = verts.shape[0], faces.shape[0]

    face_verts = verts[faces]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Side lengths of each triangle, of shape (F,).
    # A is the side opposite v0, B is opposite v1, and C is opposite v2.
    A = (v1 - v2).norm(dim=1)
    B = (v0 - v2).norm(dim=1)
    C = (v0 - v1).norm(dim=1)

    # Area of each triangle (with Heron's formula); shape is (F,)
    s = 0.5 * (A + B + C)
    # note that the argument of the sqrt can become slightly negative
    # (close to 0) due to floating point error, causing nans after sqrt();
    # we clamp it to a small positive value
    area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=eps).sqrt()

    # Compute cotangents of angles, of shape (F, 3)
    A2, B2, C2 = A * A, B * B, C * C
    cota = (B2 + C2 - A2) / area
    cotb = (A2 + C2 - B2) / area
    cotc = (A2 + B2 - C2) / area
    cot = torch.stack([cota, cotb, cotc], dim=1)
    cot /= 4.0

    # Construct a sparse matrix by basically doing:
    # L[v1, v2] = cota
    # L[v2, v0] = cotb
    # L[v0, v1] = cotc
    ii = faces[:, [1, 2, 0]]
    jj = faces[:, [2, 0, 1]]
    idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
    L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))

    # Make it symmetric; this means we are also setting
    # L[v2, v1] = cota
    # L[v0, v2] = cotb
    # L[v1, v0] = cotc
    L += L.t()

    # For each vertex, compute the sum of the areas of the faces containing it.
    idx = faces.view(-1)
    inv_areas = torch.zeros(V, dtype=torch.float32, device=verts.device)
    val = torch.stack([area] * 3, dim=1).view(-1)
    inv_areas.scatter_add_(0, idx, val)
    idx = inv_areas > 0
    inv_areas[idx] = 1.0 / inv_areas[idx]
    inv_areas = inv_areas.view(-1, 1)

    return L, inv_areas
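

# Illustrative usage sketch (not part of the original module): the helper name
# and the toy mesh below are assumptions for demonstration only. A single
# right triangle with unit legs has area 0.5, so every vertex gets
# inv_areas = 1 / 0.5 = 2, and the off-diagonal entries of L hold the
# cotangents of the angles opposite each edge.
def _example_cot_laplacian_usage() -> Tuple[torch.Tensor, torch.Tensor]:
    verts = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    )
    faces = torch.tensor([[0, 1, 2]])
    return cot_laplacian(verts, faces)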


def norm_laplacian(
    verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    """
    Norm laplacian computes a variant of the laplacian matrix which weights each
    affinity with the inverse of the distance between the neighboring nodes.
    More concretely,
    L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes

    Args:
        verts: tensor of shape (V, 3) containing the vertices of the graph
        edges: tensor of shape (E, 2) containing the vertex indices of each edge
        eps: small constant added to the edge lengths to avoid division by zero
    Returns:
        L: Sparse FloatTensor of shape (V, V)
    """
    edge_verts = verts[edges]  # (E, 2, 3)
    v0, v1 = edge_verts[:, 0], edge_verts[:, 1]

    # Inverse of the side length of each edge, of shape (E,)
    w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)

    # Construct a sparse matrix by basically doing:
    # L[v0, v1] = w01
    # L[v1, v0] = w01
    e01 = edges.t()  # (2, E)

    V = verts.shape[0]
    L = torch.sparse.FloatTensor(e01, w01, (V, V))
    L = L + L.t()

    return L
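

# Illustrative usage sketch (not part of the original module): the helper name
# and the toy path graph below are assumptions for demonstration only. The two
# edges have lengths 1 and 2, so the corresponding entries of L are roughly
# 1.0 and 0.5 (up to the eps used to avoid division by zero).
def _example_norm_laplacian_usage() -> torch.Tensor:
    verts = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 2.0, 0.0]]
    )
    edges = torch.tensor([[0, 1], [1, 2]])
    return norm_laplacian(verts, edges)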