|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from .so3 import hat, _so3_exp_map, so3_log_map |
| 6 | + |
| 7 | + |
| 8 | +def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor: |
| 9 | + """ |
| 10 | + Convert a batch of logarithmic representations of SE(3) matrices `log_transform` |
| 11 | + to a batch of 4x4 SE(3) matrices using the exponential map. |
| 12 | + See e.g. [1], Sec 9.4.2. for more detailed description. |
| 13 | +
|
| 14 | + A SE(3) matrix has the following form: |
| 15 | + ``` |
| 16 | + [ R 0 ] |
| 17 | + [ T 1 ] , |
| 18 | + ``` |
| 19 | + where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector. |
| 20 | + SE(3) matrices are commonly used to represent rigid motions or camera extrinsics. |
| 21 | +
|
| 22 | + In the SE(3) logarithmic representation SE(3) matrices are |
| 23 | + represented as 6-dimensional vectors `[log_translation | log_rotation]`, |
| 24 | + i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`. |
| 25 | +
|
| 26 | + The conversion from the 6D representation to a 4x4 SE(3) matrix `transform` |
| 27 | + is done as follows: |
| 28 | + ``` |
| 29 | + transform = exp( [ hat(log_rotation) 0 ] |
| 30 | + [ log_translation 1 ] ) , |
| 31 | + ``` |
| 32 | + where `exp` is the matrix exponential and `hat` is the Hat operator [2]. |
| 33 | +
|
| 34 | + Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi` |
| 35 | + (i.e. the rotation angle is between 0 and 2pi), the following identity holds: |
| 36 | + ``` |
| 37 | + se3_log_map(se3_exponential_map(log_transform)) == log_transform |
| 38 | + ``` |
| 39 | +
|
| 40 | + The conversion has a singularity around `||log(transform)|| = 0` |
| 41 | + which is handled by clamping controlled with the `eps` argument. |
| 42 | +
|
| 43 | + Args: |
| 44 | + log_transform: Batch of vectors of shape `(minibatch, 6)`. |
| 45 | + eps: A threshold for clipping the squared norm of the rotation logarithm |
| 46 | + to avoid unstable gradients in the singular case. |
| 47 | +
|
| 48 | + Returns: |
| 49 | + Batch of transformation matrices of shape `(minibatch, 4, 4)`. |
| 50 | +
|
| 51 | + Raises: |
| 52 | + ValueError if `log_transform` is of incorrect shape. |
| 53 | +
|
| 54 | + [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf |
| 55 | + [2] https://en.wikipedia.org/wiki/Hat_operator |
| 56 | + """ |
| 57 | + |
| 58 | + if log_transform.ndim != 2 or log_transform.shape[1] != 6: |
| 59 | + raise ValueError("Expected input to be of shape (N, 6).") |
| 60 | + |
| 61 | + N, _ = log_transform.shape |
| 62 | + |
| 63 | + log_translation = log_transform[..., :3] |
| 64 | + log_rotation = log_transform[..., 3:] |
| 65 | + |
| 66 | + # rotation is an exponential map of log_rotation |
| 67 | + ( |
| 68 | + R, |
| 69 | + rotation_angles, |
| 70 | + log_rotation_hat, |
| 71 | + log_rotation_hat_square, |
| 72 | + ) = _so3_exp_map(log_rotation, eps=eps) |
| 73 | + |
| 74 | + # translation is V @ T |
| 75 | + V = _se3_V_matrix( |
| 76 | + log_rotation, |
| 77 | + log_rotation_hat, |
| 78 | + log_rotation_hat_square, |
| 79 | + rotation_angles, |
| 80 | + eps=eps, |
| 81 | + ) |
| 82 | + T = torch.bmm(V, log_translation[:, :, None])[:, :, 0] |
| 83 | + |
| 84 | + transform = torch.zeros( |
| 85 | + N, 4, 4, dtype=log_transform.dtype, device=log_transform.device |
| 86 | + ) |
| 87 | + |
| 88 | + transform[:, :3, :3] = R |
| 89 | + transform[:, :3, 3] = T |
| 90 | + transform[:, 3, 3] = 1.0 |
| 91 | + |
| 92 | + return transform.permute(0, 2, 1) |
| 93 | + |
| 94 | + |
| 95 | +def se3_log_map( |
| 96 | + transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4 |
| 97 | +) -> torch.Tensor: |
| 98 | + """ |
| 99 | + Convert a batch of 4x4 transformation matrices `transform` |
| 100 | + to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices. |
| 101 | + See e.g. [1], Sec 9.4.2. for more detailed description. |
| 102 | +
|
| 103 | + A SE(3) matrix has the following form: |
| 104 | + ``` |
| 105 | + [ R 0 ] |
| 106 | + [ T 1 ] , |
| 107 | + ``` |
| 108 | + where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector. |
| 109 | + SE(3) matrices are commonly used to represent rigid motions or camera extrinsics. |
| 110 | +
|
| 111 | + In the SE(3) logarithmic representation SE(3) matrices are |
| 112 | + represented as 6-dimensional vectors `[log_translation | log_rotation]`, |
| 113 | + i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`. |
| 114 | +
|
| 115 | + The conversion from the 4x4 SE(3) matrix `transform` to the |
| 116 | + 6D representation `log_transform = [log_translation | log_rotation]` |
| 117 | + is done as follows: |
| 118 | + ``` |
| 119 | + log_transform = log(transform) |
| 120 | + log_translation = log_transform[3, :3] |
| 121 | + log_rotation = inv_hat(log_transform[:3, :3]) |
| 122 | + ``` |
| 123 | + where `log` is the matrix logarithm |
| 124 | + and `inv_hat` is the inverse of the Hat operator [2]. |
| 125 | +
|
| 126 | + Note that for any valid 4x4 `transform` matrix, the following identity holds: |
| 127 | + ``` |
| 128 | + se3_exp_map(se3_log_map(transform)) == transform |
| 129 | + ``` |
| 130 | +
|
| 131 | + The conversion has a singularity around `(transform=I)` which is handled |
| 132 | + by clamping controlled with the `eps` and `cos_bound` arguments. |
| 133 | +
|
| 134 | + Args: |
| 135 | + transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`. |
| 136 | + eps: A threshold for clipping the squared norm of the rotation logarithm |
| 137 | + to avoid division by zero in the singular case. |
| 138 | + cos_bound: Clamps the cosine of the rotation angle to |
| 139 | + [-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs. |
| 140 | + The non-finite outputs can be caused by passing small rotation angles |
| 141 | + to the `acos` function in `so3_rotation_angle` of `so3_log_map`. |
| 142 | +
|
| 143 | + Returns: |
| 144 | + Batch of logarithms of input SE(3) matrices |
| 145 | + of shape `(minibatch, 6)`. |
| 146 | +
|
| 147 | + Raises: |
| 148 | + ValueError if `transform` is of incorrect shape. |
| 149 | + ValueError if `R` has an unexpected trace. |
| 150 | +
|
| 151 | + [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf |
| 152 | + [2] https://en.wikipedia.org/wiki/Hat_operator |
| 153 | + """ |
| 154 | + |
| 155 | + if transform.ndim != 3: |
| 156 | + raise ValueError("Input tensor shape has to be (N, 4, 4).") |
| 157 | + |
| 158 | + N, dim1, dim2 = transform.shape |
| 159 | + if dim1 != 4 or dim2 != 4: |
| 160 | + raise ValueError("Input tensor shape has to be (N, 4, 4).") |
| 161 | + |
| 162 | + if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])): |
| 163 | + raise ValueError("All elements of `transform[:, :3, 3]` should be 0.") |
| 164 | + |
| 165 | + # log_rot is just so3_log_map of the upper left 3x3 block |
| 166 | + R = transform[:, :3, :3].permute(0, 2, 1) |
| 167 | + log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound) |
| 168 | + |
| 169 | + # log_translation is V^-1 @ T |
| 170 | + T = transform[:, 3, :3] |
| 171 | + V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps) |
| 172 | + log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0] |
| 173 | + |
| 174 | + return torch.cat((log_translation, log_rotation), dim=1) |
| 175 | + |
| 176 | + |
| 177 | +def _se3_V_matrix( |
| 178 | + log_rotation: torch.Tensor, |
| 179 | + log_rotation_hat: torch.Tensor, |
| 180 | + log_rotation_hat_square: torch.Tensor, |
| 181 | + rotation_angles: torch.Tensor, |
| 182 | + eps: float = 1e-4, |
| 183 | +) -> torch.Tensor: |
| 184 | + """ |
| 185 | + A helper function that computes the "V" matrix from [1], Sec 9.4.2. |
| 186 | + [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf |
| 187 | + """ |
| 188 | + |
| 189 | + V = ( |
| 190 | + torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None] |
| 191 | + + log_rotation_hat |
| 192 | + * ((1 - torch.cos(rotation_angles)) / (rotation_angles ** 2))[:, None, None] |
| 193 | + + ( |
| 194 | + log_rotation_hat_square |
| 195 | + * ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles ** 3))[ |
| 196 | + :, None, None |
| 197 | + ] |
| 198 | + ) |
| 199 | + ) |
| 200 | + |
| 201 | + return V |
| 202 | + |
| 203 | + |
| 204 | +def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4): |
| 205 | + """ |
| 206 | + A helper function that computes the input variables to the `_se3_V_matrix` |
| 207 | + function. |
| 208 | + """ |
| 209 | + nrms = (log_rotation ** 2).sum(-1) |
| 210 | + rotation_angles = torch.clamp(nrms, eps).sqrt() |
| 211 | + log_rotation_hat = hat(log_rotation) |
| 212 | + log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat) |
| 213 | + return log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles |
0 commit comments