Skip to content

Commit 365945b

Browse files
davnov134facebook-github-bot
authored andcommitted
Pointcloud normals estimation.
Summary: Estimates normals of a point cloud. Reviewed By: gkioxari Differential Revision: D20860182 fbshipit-source-id: 652ec2743fa645e02c01ffa37c2971bf27b89cef
1 parent 8abbe22 commit 365945b

File tree

7 files changed

+482
-10
lines changed

7 files changed

+482
-10
lines changed

pytorch3d/io/ply_io.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -700,24 +700,38 @@ def load_ply(f):
700700
return verts, faces
701701

702702

703-
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
703+
def _save_ply(
704+
f,
705+
verts: torch.Tensor,
706+
faces: torch.LongTensor,
707+
verts_normals: torch.Tensor,
708+
decimal_places: Optional[int] = None,
709+
) -> None:
704710
"""
705-
Internal implementation for saving a mesh to a .ply file.
711+
Internal implementation for saving 3D data to a .ply file.
706712
707713
Args:
708-
f: File object to which the mesh should be written.
714+
f: File object to which the 3D data should be written.
709715
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
710-
faces: LongTensor of shape (F, 3) giving faces.
716+
faces: LongTensor of shsape (F, 3) giving faces.
717+
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
711718
decimal_places: Number of decimal places for saving.
712719
"""
713720
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
714721
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
722+
assert not len(verts_normals) or (
723+
verts_normals.dim() == 2 and verts_normals.size(1) == 3
724+
)
715725

716726
print("ply\nformat ascii 1.0", file=f)
717727
print(f"element vertex {verts.shape[0]}", file=f)
718728
print("property float x", file=f)
719729
print("property float y", file=f)
720730
print("property float z", file=f)
731+
if verts_normals.numel() > 0:
732+
print("property float nx", file=f)
733+
print("property float ny", file=f)
734+
print("property float nz", file=f)
721735
print(f"element face {faces.shape[0]}", file=f)
722736
print("property list uchar int vertex_index", file=f)
723737
print("end_header", file=f)
@@ -731,8 +745,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
731745
else:
732746
float_str = "%" + ".%df" % decimal_places
733747

734-
verts_array = verts.detach().numpy()
735-
np.savetxt(f, verts_array, float_str)
748+
vert_data = torch.cat((verts, verts_normals), dim=1)
749+
np.savetxt(f, vert_data.detach().numpy(), float_str)
736750

737751
faces_array = faces.detach().numpy()
738752

@@ -743,16 +757,27 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
743757
np.savetxt(f, faces_array, "3 %d %d %d")
744758

745759

746-
def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
760+
def save_ply(
761+
f,
762+
verts: torch.Tensor,
763+
faces: Optional[torch.LongTensor] = None,
764+
verts_normals: Optional[torch.Tensor] = None,
765+
decimal_places: Optional[int] = None,
766+
) -> None:
747767
"""
748768
Save a mesh to a .ply file.
749769
750770
Args:
751771
f: File (or path) to which the mesh should be written.
752772
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
753773
faces: LongTensor of shape (F, 3) giving faces.
774+
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
754775
decimal_places: Number of decimal places for saving.
755776
"""
777+
778+
verts_normals = torch.FloatTensor([]) if verts_normals is None else verts_normals
779+
faces = torch.LongTensor([]) if faces is None else faces
780+
756781
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
757782
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
758783
raise ValueError(message)
@@ -761,6 +786,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
761786
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
762787
raise ValueError(message)
763788

789+
if len(verts_normals) and not (
790+
verts_normals.dim() == 2
791+
and verts_normals.size(1) == 3
792+
and verts_normals.size(0) == verts.size(0)
793+
):
794+
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
795+
raise ValueError(message)
796+
764797
new_f = False
765798
if isinstance(f, str):
766799
new_f = True
@@ -769,7 +802,7 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
769802
new_f = True
770803
f = f.open("w")
771804
try:
772-
_save_ply(f, verts, faces, decimal_places)
805+
_save_ply(f, verts, faces, verts_normals, decimal_places)
773806
finally:
774807
if new_f:
775808
f.close()

pytorch3d/ops/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,19 @@
77
from .mesh_face_areas_normals import mesh_face_areas_normals
88
from .packed_to_padded import packed_to_padded, padded_to_packed
99
from .points_alignment import corresponding_points_alignment, iterative_closest_point
10+
from .points_normals import (
11+
estimate_pointcloud_local_coord_frames,
12+
estimate_pointcloud_normals,
13+
)
1014
from .sample_points_from_meshes import sample_points_from_meshes
1115
from .subdivide_meshes import SubdivideMeshes
12-
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
16+
from .utils import (
17+
convert_pointclouds_to_tensor,
18+
eyes,
19+
get_point_covariances,
20+
is_pointclouds,
21+
wmean,
22+
)
1323
from .vert_align import vert_align
1424

1525

pytorch3d/ops/points_normals.py

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from typing import TYPE_CHECKING, Tuple, Union
4+
5+
import torch
6+
7+
from .utils import convert_pointclouds_to_tensor, get_point_covariances
8+
9+
10+
if TYPE_CHECKING:
11+
from ..structures import Pointclouds
12+
13+
14+
def estimate_pointcloud_normals(
15+
pointclouds: Union[torch.Tensor, "Pointclouds"],
16+
neighborhood_size: int = 50,
17+
disambiguate_directions: bool = True,
18+
) -> torch.Tensor:
19+
"""
20+
Estimates the normals of a batch of `pointclouds`.
21+
22+
The function uses `estimate_pointcloud_local_coord_frames` to estimate
23+
the normals. Please refer to this function for more detailed information.
24+
25+
Args:
26+
**pointclouds**: Batch of 3-dimensional points of shape
27+
`(minibatch, num_point, 3)` or a `Pointclouds` object.
28+
**neighborhood_size**: The size of the neighborhood used to estimate the
29+
geometry around each point.
30+
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
31+
ensure sign consistency of the normals of neigboring points.
32+
33+
Returns:
34+
**normals**: A tensor of normals for each input point
35+
of shape `(minibatch, num_point, 3)`.
36+
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
37+
38+
References:
39+
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
40+
Local Surface Description, ECCV 2010.
41+
"""
42+
43+
curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
44+
pointclouds,
45+
neighborhood_size=neighborhood_size,
46+
disambiguate_directions=disambiguate_directions,
47+
)
48+
49+
# the normals correspond to the first vector of each local coord frame
50+
normals = local_coord_frames[:, :, :, 0]
51+
52+
return normals
53+
54+
55+
def estimate_pointcloud_local_coord_frames(
56+
pointclouds: Union[torch.Tensor, "Pointclouds"],
57+
neighborhood_size: int = 50,
58+
disambiguate_directions: bool = True,
59+
) -> Tuple[torch.Tensor, torch.Tensor]:
60+
"""
61+
Estimates the principal directions of curvature (which includes normals)
62+
of a batch of `pointclouds`.
63+
64+
The algorithm first finds `neighborhood_size` nearest neighbors for each
65+
point of the point clouds, followed by obtaining principal vectors of
66+
covariance matrices of each of the point neighborhoods.
67+
The main principal vector corresponds to the normals, while the
68+
other 2 are the direction of the highest curvature and the 2nd highest
69+
curvature.
70+
71+
Note that each principal direction is given up to a sign. Hence,
72+
the function implements `disambiguate_directions` switch that allows
73+
to ensure consistency of the sign of neighboring normals. The implementation
74+
follows the sign disabiguation from SHOT descriptors [1].
75+
76+
The algorithm also returns the curvature values themselves.
77+
These are the eigenvalues of the estimated covariance matrices
78+
of each point neighborhood.
79+
80+
Args:
81+
**pointclouds**: Batch of 3-dimensional points of shape
82+
`(minibatch, num_point, 3)` or a `Pointclouds` object.
83+
**neighborhood_size**: The size of the neighborhood used to estimate the
84+
geometry around each point.
85+
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
86+
ensure sign consistency of the normals of neigboring points.
87+
88+
Returns:
89+
**curvatures**: The three principal curvatures of each point
90+
of shape `(minibatch, num_point, 3)`.
91+
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
92+
**local_coord_frames**: The three principal directions of the curvature
93+
around each point of shape `(minibatch, num_point, 3, 3)`.
94+
The principal directions are stored in columns of the output.
95+
E.g. `local_coord_frames[i, j, :, 0]` is the normal of
96+
`j`-th point in the `i`-th pointcloud.
97+
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
98+
99+
References:
100+
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
101+
Local Surface Description, ECCV 2010.
102+
"""
103+
104+
points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)
105+
106+
ba, N, dim = points_padded.shape
107+
if dim != 3:
108+
raise ValueError(
109+
"The pointclouds argument has to be of shape (minibatch, N, 3)"
110+
)
111+
112+
if (num_points <= neighborhood_size).any():
113+
raise ValueError(
114+
"The neighborhood_size argument has to be"
115+
+ " >= size of each of the point clouds."
116+
)
117+
118+
# undo global mean for stability
119+
# TODO: replace with tutil.wmean once landed
120+
pcl_mean = points_padded.sum(1) / num_points[:, None]
121+
points_centered = points_padded - pcl_mean[:, None, :]
122+
123+
# get the per-point covariance and nearest neighbors used to compute it
124+
cov, knns = get_point_covariances(points_centered, num_points, neighborhood_size)
125+
126+
# get the local coord frames as principal directions of
127+
# the per-point covariance
128+
# this is done with torch.symeig, which returns the
129+
# eigenvectors (=principal directions) in an ascending order of their
130+
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
131+
# corresponds to the normal direction
132+
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
133+
134+
# disambiguate the directions of individual principal vectors
135+
if disambiguate_directions:
136+
# disambiguate normal
137+
n = _disambiguate_vector_directions(
138+
points_centered, knns, local_coord_frames[:, :, :, 0]
139+
)
140+
# disambiguate the main curvature
141+
z = _disambiguate_vector_directions(
142+
points_centered, knns, local_coord_frames[:, :, :, 2]
143+
)
144+
# the secondary curvature is just a cross between n and z
145+
y = torch.cross(n, z, dim=2)
146+
# cat to form the set of principal directions
147+
local_coord_frames = torch.stack((n, y, z), dim=3)
148+
149+
return curvatures, local_coord_frames
150+
151+
152+
def _disambiguate_vector_directions(pcl, knns, vecs):
153+
"""
154+
Disambiguates normal directions according to [1].
155+
156+
References:
157+
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
158+
Local Surface Description, ECCV 2010.
159+
"""
160+
# parse out K from the shape of knns
161+
K = knns.shape[2]
162+
# the difference between the mean of each neighborhood and
163+
# each element of the neighborhood
164+
df = knns - pcl[:, :, None]
165+
# projection of the difference on the principal direction
166+
proj = (vecs[:, :, None] * df).sum(3)
167+
# check how many projections are positive
168+
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)
169+
# flip the principal directions where number of positive correlations
170+
flip = (n_pos < (0.5 * K)).type_as(knns)
171+
vecs = (1.0 - 2.0 * flip) * vecs
172+
return vecs

pytorch3d/ops/utils.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import torch
55

6+
from .knn import knn_points
7+
68

79
if TYPE_CHECKING:
810
from pytorch3d.structures import Pointclouds
@@ -92,8 +94,53 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
9294

9395

9496
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
95-
""" Checks whether the input `pcl` is an instance `Pointclouds` of
97+
""" Checks whether the input `pcl` is an instance of `Pointclouds`
9698
by checking the existence of `points_padded` and `num_points_per_cloud`
9799
functions.
98100
"""
99101
return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")
102+
103+
104+
def get_point_covariances(
105+
points_padded: torch.Tensor,
106+
num_points_per_cloud: torch.Tensor,
107+
neighborhood_size: int,
108+
) -> Tuple[torch.Tensor, torch.Tensor]:
109+
"""
110+
Computes the per-point covariance matrices by of the 3D locations of
111+
K-nearest neighbors of each point.
112+
113+
Args:
114+
**points_padded**: Input point clouds as a padded tensor
115+
of shape `(minibatch, num_points, dim)`.
116+
**num_points_per_cloud**: Number of points per cloud
117+
of shape `(minibatch,)`.
118+
**neighborhood_size**: Number of nearest neighbors for each point
119+
used to estimate the covariance matrices.
120+
121+
Returns:
122+
**covariances**: A batch of per-point covariance matrices
123+
of shape `(minibatch, dim, dim)`.
124+
**k_nearest_neighbors**: A batch of `neighborhood_size` nearest
125+
neighbors for each of the point cloud points
126+
of shape `(minibatch, num_points, neighborhood_size, dim)`.
127+
"""
128+
# get K nearest neighbor idx for each point in the point cloud
129+
_, _, k_nearest_neighbors = knn_points(
130+
points_padded,
131+
points_padded,
132+
lengths1=num_points_per_cloud,
133+
lengths2=num_points_per_cloud,
134+
K=neighborhood_size,
135+
return_nn=True,
136+
)
137+
# obtain the mean of the neighborhood
138+
pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
139+
# compute the diff of the neighborhood and the mean of the neighborhood
140+
central_diff = k_nearest_neighbors - pt_mean
141+
# per-nn-point covariances
142+
per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3)
143+
# per-point covariances
144+
covariances = per_pt_cov.mean(2)
145+
146+
return covariances, k_nearest_neighbors

0 commit comments

Comments
 (0)