Skip to content

Commit 8041178

Browse files
nikhilaravi authored and facebook-github-bot committed
Fix to allow cameras in the renderer forward pass
Summary: Fix to resolve GitHub issue #796 - the cameras were being passed in the renderer forward pass instead of at initialization. The rasterizer was correctly using the cameras passed in the `kwargs` for the projection, but the `cameras` were still part of the `kwargs` forwarded to the `get_screen_to_ndc_transform` and `get_ndc_to_screen_transform` functions, which was causing duplicate-argument errors. Reviewed By: bottler Differential Revision: D30175679 fbshipit-source-id: 547e88d8439456e728fa2772722df5fa0fe4584d
1 parent 4046677 commit 8041178

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

pytorch3d/renderer/cameras.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,10 @@ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
262262
# We don't flip xy because we assume that world points are in
263263
# PyTorch3D coordinates, and thus conversion from screen to ndc
264264
# is a mere scaling from image to [-1, 1] scale.
265-
return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)
265+
image_size = kwargs.get("image_size", self.get_image_size())
266+
return get_screen_to_ndc_transform(
267+
self, with_xyflip=False, image_size=image_size
268+
)
266269

267270
def transform_points_ndc(
268271
self, points, eps: Optional[float] = None, **kwargs
@@ -318,8 +321,9 @@ def transform_points_screen(
318321
new_points: transformed points with the same shape as the input.
319322
"""
320323
points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
324+
image_size = kwargs.get("image_size", self.get_image_size())
321325
return get_ndc_to_screen_transform(
322-
self, with_xyflip=True, **kwargs
326+
self, with_xyflip=True, image_size=image_size
323327
).transform_points(points_ndc, eps=eps)
324328

325329
def clone(self):
@@ -923,7 +927,7 @@ def __init__(
923927
K: (optional) A calibration matrix of shape (N, 4, 4)
924928
If provided, don't need focal_length, principal_point
925929
image_size: (height, width) of image size.
926-
A tensor of shape (N, 2). Required for screen cameras.
930+
A tensor of shape (N, 2) or a list/tuple. Required for screen cameras.
927931
device: torch.device or string
928932
"""
929933
# The initializer formats all inputs to torch tensors and broadcasts
@@ -1044,8 +1048,9 @@ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
10441048
pr_point_fix_transform = Transform3d(
10451049
matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
10461050
)
1051+
image_size = kwargs.get("image_size", self.get_image_size())
10471052
screen_to_ndc_transform = get_screen_to_ndc_transform(
1048-
self, with_xyflip=False, **kwargs
1053+
self, with_xyflip=False, image_size=image_size
10491054
)
10501055
ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
10511056

@@ -1105,7 +1110,7 @@ def __init__(
11051110
K: Optional[torch.Tensor] = None,
11061111
device: Device = "cpu",
11071112
in_ndc: bool = True,
1108-
image_size: Optional[torch.Tensor] = None,
1113+
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
11091114
) -> None:
11101115
"""
11111116
@@ -1123,7 +1128,7 @@ def __init__(
11231128
K: (optional) A calibration matrix of shape (N, 4, 4)
11241129
If provided, don't need focal_length, principal_point, image_size
11251130
image_size: (height, width) of image size.
1126-
A tensor of shape (N, 2). Required for screen cameras.
1131+
A tensor of shape (N, 2) or list/tuple. Required for screen cameras.
11271132
device: torch.device or string
11281133
"""
11291134
# The initializer formats all inputs to torch tensors and broadcasts
@@ -1241,8 +1246,9 @@ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
12411246
pr_point_fix_transform = Transform3d(
12421247
matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
12431248
)
1249+
image_size = kwargs.get("image_size", self.get_image_size())
12441250
screen_to_ndc_transform = get_screen_to_ndc_transform(
1245-
self, with_xyflip=False, **kwargs
1251+
self, with_xyflip=False, image_size=image_size
12461252
)
12471253
ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
12481254

@@ -1537,7 +1543,9 @@ def look_at_view_transform(
15371543

15381544

15391545
def get_ndc_to_screen_transform(
1540-
cameras, with_xyflip: bool = False, **kwargs
1546+
cameras,
1547+
with_xyflip: bool = False,
1548+
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
15411549
) -> Transform3d:
15421550
"""
15431551
PyTorch3D NDC to screen conversion.
@@ -1563,7 +1571,6 @@ def get_ndc_to_screen_transform(
15631571
15641572
"""
15651573
# We require the image size, which is necessary for the transform
1566-
image_size = kwargs.get("image_size", cameras.get_image_size())
15671574
if image_size is None:
15681575
msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified."
15691576
raise ValueError(msg)
@@ -1605,7 +1612,9 @@ def get_ndc_to_screen_transform(
16051612

16061613

16071614
def get_screen_to_ndc_transform(
1608-
cameras, with_xyflip: bool = False, **kwargs
1615+
cameras,
1616+
with_xyflip: bool = False,
1617+
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
16091618
) -> Transform3d:
16101619
"""
16111620
Screen to PyTorch3D NDC conversion.
@@ -1631,6 +1640,8 @@ def get_screen_to_ndc_transform(
16311640
16321641
"""
16331642
transform = get_ndc_to_screen_transform(
1634-
cameras, with_xyflip=with_xyflip, **kwargs
1643+
cameras,
1644+
with_xyflip=with_xyflip,
1645+
image_size=image_size,
16351646
).inverse()
16361647
return transform

tests/test_render_meshes.py

+51
Original file line numberDiff line numberDiff line change
@@ -1146,3 +1146,54 @@ def test_simple_sphere_outside_zfar(self):
11461146
)
11471147

11481148
self.assertClose(rgb, image_ref, atol=0.05)
1149+
1150+
def test_cameras_kwarg(self):
1151+
"""
1152+
Test that when cameras are passed in as a kwarg the rendering
1153+
works as expected
1154+
"""
1155+
device = torch.device("cuda:0")
1156+
1157+
# Init mesh
1158+
sphere_mesh = ico_sphere(5, device)
1159+
verts_padded = sphere_mesh.verts_padded()
1160+
faces_padded = sphere_mesh.faces_padded()
1161+
feats = torch.ones_like(verts_padded, device=device)
1162+
textures = TexturesVertex(verts_features=feats)
1163+
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
1164+
1165+
# No elevation or azimuth rotation
1166+
R, T = look_at_view_transform(2.7, 0.0, 0.0)
1167+
for cam_type in (
1168+
FoVPerspectiveCameras,
1169+
FoVOrthographicCameras,
1170+
PerspectiveCameras,
1171+
OrthographicCameras,
1172+
):
1173+
cameras = cam_type(device=device, R=R, T=T)
1174+
1175+
# Init shader settings
1176+
materials = Materials(device=device)
1177+
lights = PointLights(device=device)
1178+
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
1179+
1180+
raster_settings = RasterizationSettings(
1181+
image_size=512, blur_radius=0.0, faces_per_pixel=1
1182+
)
1183+
rasterizer = MeshRasterizer(raster_settings=raster_settings)
1184+
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
1185+
1186+
shader = HardPhongShader(
1187+
lights=lights,
1188+
materials=materials,
1189+
blend_params=blend_params,
1190+
)
1191+
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
1192+
1193+
# Cameras can be passed into the renderer in the forward pass
1194+
images = renderer(sphere_mesh, cameras=cameras)
1195+
rgb = images.squeeze()[..., :3].cpu().numpy()
1196+
image_ref = load_rgb_image(
1197+
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
1198+
)
1199+
self.assertClose(rgb, image_ref, atol=0.05)

0 commit comments

Comments (0)