Skip to content

Commit 3b7d78c

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Farthest point sampling python naive
Summary: This is a naive python implementation of the iterative farthest point sampling algorithm along with associated simple tests. The C++/CUDA implementations will follow in subsequent diffs. The algorithm is used to subsample a pointcloud with better coverage of the space of the pointcloud. The function has not been added to `__init__.py`. I will add this after the full C++/CUDA implementations. Reviewed By: jcjohnson Differential Revision: D30285716 fbshipit-source-id: 33f4181041fc652776406bcfd67800a6f0c3dd58
1 parent a0d76a7 commit 3b7d78c

File tree

5 files changed

+295
-12
lines changed

5 files changed

+295
-12
lines changed

pytorch3d/ops/ball_query.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch.autograd.function import once_differentiable
1313

1414
from .knn import _KNN
15+
from .utils import masked_gather
1516

1617

1718
class _ball_query(Function):
@@ -123,7 +124,6 @@ def ball_query(
123124
p2 = p2.contiguous()
124125
P1 = p1.shape[1]
125126
P2 = p2.shape[1]
126-
D = p2.shape[2]
127127
N = p1.shape[0]
128128

129129
if lengths1 is None:
@@ -135,16 +135,6 @@ def ball_query(
135135
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
136136

137137
# Gather the neighbors if needed
138-
points_nn = None
139-
if return_nn:
140-
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
141-
idx_mask = idx_expanded.eq(-1)
142-
idx_new = idx_expanded.clone()
143-
# Replace -1 values with 0 for gather
144-
idx_new[idx_mask] = 0
145-
# Gather points from p2
146-
points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
147-
# Replace padded values
148-
points_nn[idx_mask] = 0.0
138+
points_nn = masked_gather(p2, idx) if return_nn else None
149139

150140
return _KNN(dists=dists, idx=idx, knn=points_nn)
+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from random import randint
8+
from typing import Optional, Tuple, Union, List
9+
10+
import torch
11+
12+
from .utils import masked_gather
13+
14+
15+
def sample_farthest_points_naive(
16+
points: torch.Tensor,
17+
lengths: Optional[torch.Tensor] = None,
18+
K: Union[int, List, torch.Tensor] = 50,
19+
random_start_point: bool = False,
20+
) -> Tuple[torch.Tensor, torch.Tensor]:
21+
"""
22+
Iterative farthest point sampling algorithm [1] to subsample a set of
23+
K points from a given pointcloud. At each iteration, a point is selected
24+
which has the largest nearest neighbor distance to any of the
25+
already selected points.
26+
27+
Farthest point sampling provides more uniform coverage of the input
28+
point cloud compared to uniform random sampling.
29+
30+
[1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
31+
on Point Sets in a Metric Space", NeurIPS 2017.
32+
33+
Args:
34+
points: (N, P, D) array containing the batch of pointclouds
35+
lengths: (N,) number of points in each pointcloud (to support heterogeneous
36+
batches of pointclouds)
37+
K: samples you want in each sampled point cloud (this is typically << P). If
38+
K is an int then the same number of samples are selected for each
39+
pointcloud in the batch. If K is a tensor is should be length (N,)
40+
giving the number of samples to select for each element in the batch
41+
random_start_point: bool, if True, a random point is selected as the starting
42+
point for iterative sampling.
43+
44+
Returns:
45+
selected_points: (N, K, D), array of selected values from points. If the input
46+
K is a tensor, then the shape will be (N, max(K), D), and padded with
47+
0.0 for batch elements where k_i < max(K).
48+
selected_indices: (N, K) array of selected indices. If the input
49+
K is a tensor, then the shape will be (N, max(K), D), and padded with
50+
-1 for batch elements where k_i < max(K).
51+
"""
52+
N, P, D = points.shape
53+
device = points.device
54+
55+
# Validate inputs
56+
if lengths is None:
57+
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
58+
59+
if lengths.shape[0] != N:
60+
raise ValueError("points and lengths must have same batch dimension.")
61+
62+
# TODO: support providing K as a ratio of the total number of points instead of as an int
63+
if isinstance(K, int):
64+
K = torch.full((N,), K, dtype=torch.int64, device=device)
65+
elif isinstance(K, list):
66+
K = torch.tensor(K, dtype=torch.int64, device=device)
67+
68+
if K.shape[0] != N:
69+
raise ValueError("K and points must have the same batch dimension")
70+
71+
# Find max value of K
72+
max_K = torch.max(K)
73+
74+
# List of selected indices from each batch element
75+
all_sampled_indices = []
76+
77+
for n in range(N):
78+
# Initialize an array for the sampled indices, shape: (max_K,)
79+
sample_idx_batch = torch.full(
80+
(max_K,), fill_value=-1, dtype=torch.int64, device=device
81+
)
82+
83+
# Initialize closest distances to inf, shape: (P,)
84+
# This will be updated at each iteration to track the closest distance of the
85+
# remaining points to any of the selected points
86+
# pyre-fixme[16]: `torch.Tensor` has no attribute new_full.
87+
closest_dists = points.new_full(
88+
(lengths[n],), float("inf"), dtype=torch.float32
89+
)
90+
91+
# Select a random point index and save it as the starting point
92+
selected_idx = randint(0, lengths[n] - 1) if random_start_point else 0
93+
sample_idx_batch[0] = selected_idx
94+
95+
# If the pointcloud has fewer than K points then only iterate over the min
96+
k_n = min(lengths[n], K[n])
97+
98+
# Iteratively select points for a maximum of k_n
99+
for i in range(1, k_n):
100+
# Find the distance between the last selected point
101+
# and all the other points. If a point has already been selected
102+
# it's distance will be 0.0 so it will not be selected again as the max.
103+
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
104+
dist_to_last_selected = (dist ** 2).sum(-1) # (P - i)
105+
106+
# If closer than currently saved distance to one of the selected
107+
# points, then updated closest_dists
108+
closest_dists = torch.min(dist_to_last_selected, closest_dists) # (P - i)
109+
110+
# The aim is to pick the point that has the largest
111+
# nearest neighbour distance to any of the already selected points
112+
selected_idx = torch.argmax(closest_dists)
113+
sample_idx_batch[i] = selected_idx
114+
115+
# Add the list of points for this batch to the final list
116+
all_sampled_indices.append(sample_idx_batch)
117+
118+
all_sampled_indices = torch.stack(all_sampled_indices, dim=0)
119+
120+
# Gather the points
121+
all_sampled_points = masked_gather(points, all_sampled_indices)
122+
123+
# Return (N, max_K, D) subsampled points and indices
124+
return all_sampled_points, all_sampled_indices

pytorch3d/ops/utils.py

+48
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,54 @@
1515
from pytorch3d.structures import Pointclouds
1616

1717

18+
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
19+
"""
20+
Helper function for torch.gather to collect the points at
21+
the given indices in idx where some of the indices might be -1 to
22+
indicate padding. These indices are first replaced with 0.
23+
Then the points are gathered after which the padded values
24+
are set to 0.0.
25+
26+
Args:
27+
points: (N, P, D) float32 tensor of points
28+
idx: (N, K) or (N, P, K) long tensor of indices into points, where
29+
some indices are -1 to indicate padding
30+
31+
Returns:
32+
selected_points: (N, K, D) float32 tensor of points
33+
at the given indices
34+
"""
35+
36+
if len(idx) != len(points):
37+
raise ValueError("points and idx must have the same batch dimension")
38+
39+
N, P, D = points.shape
40+
41+
if idx.ndim == 3:
42+
# Case: KNN, Ball Query where idx is of shape (N, P', K)
43+
# where P' is not necessarily the same as P as the
44+
# points may be gathered from a different pointcloud.
45+
K = idx.shape[2]
46+
# Match dimensions for points and indices
47+
idx_expanded = idx[..., None].expand(-1, -1, -1, D)
48+
points = points[:, :, None, :].expand(-1, -1, K, -1)
49+
elif idx.ndim == 2:
50+
# Farthest point sampling where idx is of shape (N, K)
51+
idx_expanded = idx[..., None].expand(-1, -1, D)
52+
else:
53+
raise ValueError("idx format is not supported %s" % repr(idx.shape))
54+
55+
idx_expanded_mask = idx_expanded.eq(-1)
56+
idx_expanded = idx_expanded.clone()
57+
# Replace -1 values with 0 for gather
58+
idx_expanded[idx_expanded_mask] = 0
59+
# Gather points
60+
selected_points = points.gather(dim=1, index=idx_expanded)
61+
# Replace padded values
62+
selected_points[idx_expanded_mask] = 0.0
63+
return selected_points
64+
65+
1866
def wmean(
1967
x: torch.Tensor,
2068
weight: Optional[torch.Tensor] = None,

tests/test_ops_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,13 @@ def test_wmean(self):
7676
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
7777
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
7878
self.assertClose(mean.cpu().data.numpy(), mean_gt)
79+
80+
def test_masked_gather_errors(self):
81+
idx = torch.randint(0, 10, size=(5, 10, 4, 2))
82+
points = torch.randn(size=(5, 10, 3))
83+
with self.assertRaisesRegex(ValueError, "format is not supported"):
84+
oputil.masked_gather(points, idx)
85+
86+
points = torch.randn(size=(2, 10, 3))
87+
with self.assertRaisesRegex(ValueError, "same batch dimension"):
88+
oputil.masked_gather(points, idx)

tests/test_sample_farthest_points.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from common_testing import TestCaseMixin, get_random_cuda_device
11+
from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
12+
from pytorch3d.ops.utils import masked_gather
13+
14+
15+
class TestFPS(TestCaseMixin, unittest.TestCase):
16+
def test_simple(self):
17+
device = get_random_cuda_device()
18+
# fmt: off
19+
points = torch.tensor(
20+
[
21+
[
22+
[-1.0, -1.0], # noqa: E241, E201
23+
[-1.3, 1.1], # noqa: E241, E201
24+
[ 0.2, -1.1], # noqa: E241, E201
25+
[ 0.0, 0.0], # noqa: E241, E201
26+
[ 1.3, 1.3], # noqa: E241, E201
27+
[ 1.0, 0.5], # noqa: E241, E201
28+
[-1.3, 0.2], # noqa: E241, E201
29+
[ 1.5, -0.5], # noqa: E241, E201
30+
],
31+
[
32+
[-2.2, -2.4], # noqa: E241, E201
33+
[-2.1, 2.0], # noqa: E241, E201
34+
[ 2.2, 2.1], # noqa: E241, E201
35+
[ 2.1, -2.4], # noqa: E241, E201
36+
[ 0.4, -1.0], # noqa: E241, E201
37+
[ 0.3, 0.3], # noqa: E241, E201
38+
[ 1.2, 0.5], # noqa: E241, E201
39+
[ 4.5, 4.5], # noqa: E241, E201
40+
],
41+
],
42+
dtype=torch.float32,
43+
device=device,
44+
)
45+
# fmt: on
46+
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
47+
out_points, out_inds = sample_farthest_points_naive(points, K=2)
48+
self.assertClose(out_inds, expected_inds)
49+
50+
# Gather the points
51+
expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
52+
self.assertClose(out_points, points.gather(dim=1, index=expected_inds))
53+
54+
# Different number of points sampled for each pointcloud in the batch
55+
expected_inds = torch.tensor(
56+
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
57+
)
58+
out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
59+
self.assertClose(out_inds, expected_inds)
60+
61+
# Gather the points
62+
expected_points = masked_gather(points, expected_inds)
63+
self.assertClose(out_points, expected_points)
64+
65+
def test_random_heterogeneous(self):
66+
device = get_random_cuda_device()
67+
N, P, D, K = 5, 40, 5, 8
68+
points = torch.randn((N, P, D), device=device)
69+
out_points, out_idxs = sample_farthest_points_naive(points, K=K)
70+
self.assertTrue(out_idxs.min() >= 0)
71+
for n in range(N):
72+
self.assertEqual(out_idxs[n].ne(-1).sum(), K)
73+
74+
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
75+
out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)
76+
77+
for n in range(N):
78+
# Check that for heterogeneous batches, the max number of
79+
# selected points is less than the length
80+
self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
81+
self.assertTrue(out_idxs[n].max() <= lengths[n])
82+
83+
# Check there are no duplicate indices
84+
val_mask = out_idxs[n].ne(-1)
85+
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
86+
self.assertTrue(counts.le(1).all())
87+
88+
def test_errors(self):
89+
device = get_random_cuda_device()
90+
N, P, D, K = 5, 40, 5, 8
91+
points = torch.randn((N, P, D), device=device)
92+
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
93+
94+
# K has diferent batch dimension to points
95+
with self.assertRaisesRegex(ValueError, "K and points must have"):
96+
sample_farthest_points_naive(points, K=wrong_batch_dim)
97+
98+
# lengths has diferent batch dimension to points
99+
with self.assertRaisesRegex(ValueError, "points and lengths must have"):
100+
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
101+
102+
def test_random_start(self):
103+
device = get_random_cuda_device()
104+
N, P, D, K = 5, 40, 5, 8
105+
points = torch.randn((N, P, D), device=device)
106+
out_points, out_idxs = sample_farthest_points_naive(
107+
points, K=K, random_start_point=True
108+
)
109+
# Check the first index is not 0 for all batch elements
110+
# when random_start_point = True
111+
self.assertTrue(out_idxs[:, 0].sum() > 0)

0 commit comments

Comments
 (0)