|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | + |
| 11 | +def sample_pdf_python( |
| 12 | + bins: torch.Tensor, |
| 13 | + weights: torch.Tensor, |
| 14 | + N_samples: int, |
| 15 | + det: bool = False, |
| 16 | + eps: float = 1e-5, |
| 17 | +) -> torch.Tensor: |
| 18 | + """ |
| 19 | + Samples probability density functions defined by bin edges `bins` and |
| 20 | + the non-negative per-bin probabilities `weights`. |
| 21 | +
|
| 22 | + Note: This is a direct conversion of the TensorFlow function from the original |
| 23 | + release [1] to PyTorch. |
| 24 | +
|
| 25 | + Args: |
| 26 | + bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins. |
| 27 | + weights: Tensor of shape `(..., n_bins)` containing non-negative numbers |
| 28 | + representing the probability of sampling the corresponding bin. |
| 29 | + N_samples: The number of samples to draw from each set of bins. |
| 30 | + det: If `False`, the sampling is random. `True` yields deterministic |
| 31 | + uniformly-spaced sampling from the inverse cumulative density function. |
| 32 | + eps: A constant preventing division by zero in case empty bins are present. |
| 33 | +
|
| 34 | + Returns: |
| 35 | + samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples |
| 36 | + drawn from each probability distribution. |
| 37 | +
|
| 38 | + Refs: |
| 39 | + [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501 |
| 40 | + """ |
| 41 | + |
| 42 | + # Get pdf |
| 43 | + weights = weights + eps # prevent nans |
| 44 | + if weights.min() <= 0: |
| 45 | + raise ValueError("Negative weights provided.") |
| 46 | + pdf = weights / weights.sum(dim=-1, keepdim=True) |
| 47 | + cdf = torch.cumsum(pdf, -1) |
| 48 | + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) |
| 49 | + |
| 50 | + # Take uniform samples u of shape (..., N_samples) |
| 51 | + if det: |
| 52 | + u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype) |
| 53 | + u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous() |
| 54 | + else: |
| 55 | + u = torch.rand( |
| 56 | + list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype |
| 57 | + ) |
| 58 | + |
| 59 | + # Invert CDF |
| 60 | + inds = torch.searchsorted(cdf, u, right=True) |
| 61 | + # inds has shape (..., N_samples) identifying the bin of each sample. |
| 62 | + below = (inds - 1).clamp(0) |
| 63 | + above = inds.clamp(max=cdf.shape[-1] - 1) |
| 64 | + # Below and above are of shape (..., N_samples), identifying the bin |
| 65 | + # edges surrounding each sample. |
| 66 | + |
| 67 | + inds_g = torch.stack([below, above], -1).view( |
| 68 | + *below.shape[:-1], below.shape[-1] * 2 |
| 69 | + ) |
| 70 | + cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2) |
| 71 | + bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2) |
| 72 | + # cdf_g and bins_g are of shape (..., N_samples, 2) and identify |
| 73 | + # the cdf and the index of the two bin edges surrounding each sample. |
| 74 | + |
| 75 | + denom = cdf_g[..., 1] - cdf_g[..., 0] |
| 76 | + denom = torch.where(denom < eps, torch.ones_like(denom), denom) |
| 77 | + t = (u - cdf_g[..., 0]) / denom |
| 78 | + # t is of shape (..., N_samples) and identifies how far through |
| 79 | + # each sample is in its bin. |
| 80 | + |
| 81 | + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) |
| 82 | + |
| 83 | + return samples |
0 commit comments