
Bug in culling: The size of tensor a (3) must match the size of tensor b (XXX) at non-singleton dimension 1 #631


Closed
Algomorph opened this issue Apr 8, 2021 · 2 comments
Assignees: nikhilaravi
Labels: potential-bug (Potential bug to flag an issue that needs to be looked into)


@Algomorph

Instructions To Reproduce the Issue:

Try running this code:

import numpy as np
import pytorch3d.utils
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
import torch

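torch_device = torch.device("cuda:0")

# 4x4 intrinsic matrix (fx, fy, cx, cy in pixels), batched to shape (1, 4, 4)
intrinsic_matrix = np.array([[575.548, 0.,     323.172, 0.],
                             [0.,      577.46, 236.417, 0.],
                             [0.,      0.,     1.,      0.],
                             [0.,      0.,     0.,      1.]])
intrinsic_matrix_torch = torch.from_numpy(intrinsic_matrix).to(dtype=torch.float32, device=torch_device).unsqueeze(0)

# Torus mesh: tube radius 20, ring radius 85, 32 sides, 16 rings
meshes_torch3d = pytorch3d.utils.torus(20.0, 85.0, 32, 16, device=torch_device)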
image_size = (480, 640)

camera_rotation = torch.eye(3, dtype=torch.float32, device=torch_device)[None]  # (1, 3, 3)
camera_translation = torch.zeros(1, 3, dtype=torch.float32, device=torch_device)  # (1, 3)

cameras = PerspectiveCameras(device=torch_device,
                             R=camera_rotation,
                             T=camera_translation,
                             K=intrinsic_matrix_torch,
                             image_size=[image_size])

rasterization_settings = RasterizationSettings(image_size=image_size, cull_backfaces=True, cull_to_frustum=True)
rasterizer = MeshRasterizer(cameras, raster_settings=rasterization_settings)
fragments = rasterizer.forward(meshes_torch3d)

The output will be:

Traceback (most recent call last):
  File "/home/algomorph/Workbench/NeuralTracking/pipeline/rendering_test.py", line 51, in <module>
    sys.exit(main())
  File "/home/algomorph/Workbench/NeuralTracking/pipeline/rendering_test.py", line 44, in main
    fragments = rasterizer.forward(meshes_torch3d)
  File "/home/algomorph/.local/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterizer.py", line 164, in forward
    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
  File "/home/algomorph/.local/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterize_meshes.py", line 175, in rasterize_meshes
    clipped_faces = clip_faces(
  File "/home/algomorph/.local/lib/python3.8/site-packages/pytorch3d/renderer/mesh/clip.py", line 400, in clip_faces
    cases1_unclipped = (faces_num_clipped_verts == 0) & faces_unculled
RuntimeError: The size of tensor a (3) must match the size of tensor b (1024) at non-singleton dimension 1

Version info:
pytorch3d: built from commit cc08c6b (from this morning)
pytorch3d.__version__: '0.4.0'
torch.__version__: '1.7.0a0+57bffc3' (== 1.7.1 release, built from source)
torch.version.cuda: '11.1'
Python: 3.8.5
OS: Ubuntu 20.04

@Algomorph
Author

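I suspect this line in clip_faces (pytorch3d/renderer/mesh/clip.py) is the source of the bug:

faces_num_clipped_verts = torch.zeros([F, 3], device=device)

It should instead be:

faces_num_clipped_verts = torch.zeros([F], device=device)

With shape (F, 3), faces_num_clipped_verts cannot be combined with the per-face mask faces_unculled, which has shape (F,). Broadcasting aligns the trailing dimensions, so the size 3 is compared against F (1024 faces for this torus), which is exactly the mismatch reported in the traceback.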
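Here is a minimal, CPU-only sketch of that shape mismatch. It assumes the torus helper builds two triangles per (side, ring) quad, giving 32 x 16 x 2 = 1024 faces, which is consistent with the size in the traceback. It only mimics the failing expression from clip.py with placeholder tensors; `combine` and `NUM_FACES` are hypothetical names, not PyTorch3D code.

```python
import torch

# Assumed torus face count: 2 triangles per (side, ring) quad, 32 sides x 16 rings
NUM_FACES = 2 * 32 * 16  # 1024


def combine(num_clipped_shape):
    """Mimic `(faces_num_clipped_verts == 0) & faces_unculled` for a given shape."""
    faces_unculled = torch.ones(NUM_FACES, dtype=torch.bool)  # per-face mask, shape (F,)
    faces_num_clipped_verts = torch.zeros(num_clipped_shape)  # placeholder counts
    return (faces_num_clipped_verts == 0) & faces_unculled


try:
    combine([NUM_FACES, 3])  # shape used by the suspected line: (F, 3)
except RuntimeError as err:
    # Broadcasting compares the trailing sizes 3 and 1024 and fails
    print("buggy:", err)

print("fixed:", combine([NUM_FACES]).shape)  # shape used by the proposed fix: (F,)
```

With the (F,) shape the expression yields one boolean per face, which is what the downstream case logic in clip_faces presumably expects.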

@nikhilaravi nikhilaravi self-assigned this Apr 8, 2021
@nikhilaravi nikhilaravi added the potential-bug Potential bug to flag an issue that needs to be looked into label Apr 8, 2021
@nikhilaravi
Contributor

This should now be fixed by a0f7931.
