6
6
"""
7
7
import os
8
8
import unittest
9
+ from collections import namedtuple
9
10
10
11
import numpy as np
11
12
import torch
53
54
DATA_DIR = get_tests_dir () / "data"
54
55
TUTORIAL_DATA_DIR = get_pytorch3d_dir () / "docs/tutorials/data"
55
56
57
+ ShaderTest = namedtuple ("ShaderTest" , ["shader" , "reference_name" , "debug_name" ])
58
+
56
59
57
60
class TestRenderMeshes (TestCaseMixin , unittest .TestCase ):
58
61
def test_simple_sphere (self , elevated_camera = False , check_depth = False ):
@@ -107,13 +110,13 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
107
110
blend_params = BlendParams (1e-4 , 1e-4 , (0 , 0 , 0 ))
108
111
109
112
# Test several shaders
110
- shaders = {
111
- "phong" : HardPhongShader ,
112
- "gouraud" : HardGouraudShader ,
113
- "flat" : HardFlatShader ,
114
- }
115
- for ( name , shader_init ) in shaders . items () :
116
- shader = shader_init (
113
+ shader_tests = [
114
+ ShaderTest ( HardPhongShader , "phong" , "hard_phong" ) ,
115
+ ShaderTest ( HardGouraudShader , "gouraud" , "hard_gouraud" ) ,
116
+ ShaderTest ( HardFlatShader , "flat" , "hard_flat" ) ,
117
+ ]
118
+ for test in shader_tests :
119
+ shader = test . shader (
117
120
lights = lights ,
118
121
cameras = cameras ,
119
122
materials = materials ,
@@ -135,7 +138,7 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
135
138
136
139
rgb = images [0 , ..., :3 ].squeeze ().cpu ()
137
140
filename = "simple_sphere_light_%s%s%s.png" % (
138
- name ,
141
+ test . reference_name ,
139
142
postfix ,
140
143
cam_type .__name__ ,
141
144
)
@@ -144,7 +147,12 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
144
147
self .assertClose (rgb , image_ref , atol = 0.05 )
145
148
146
149
if DEBUG :
147
- filename = "DEBUG_%s" % filename
150
+ debug_filename = "simple_sphere_light_%s%s%s.png" % (
151
+ test .debug_name ,
152
+ postfix ,
153
+ cam_type .__name__ ,
154
+ )
155
+ filename = "DEBUG_%s" % debug_filename
148
156
Image .fromarray ((rgb .numpy () * 255 ).astype (np .uint8 )).save (
149
157
DATA_DIR / filename
150
158
)
@@ -269,7 +277,8 @@ def test_simple_sphere_screen(self):
269
277
def test_simple_sphere_batched (self ):
270
278
"""
271
279
Test a mesh with vertex textures can be extended to form a batch, and
272
- is rendered correctly with Phong, Gouraud and Flat Shaders.
280
+ is rendered correctly with Phong, Gouraud and Flat Shaders with batched
281
+ lighting and hard and soft blending.
273
282
"""
274
283
batch_size = 5
275
284
device = torch .device ("cuda:0" )
@@ -291,24 +300,28 @@ def test_simple_sphere_batched(self):
291
300
R , T = look_at_view_transform (dist , elev , azim )
292
301
cameras = FoVPerspectiveCameras (device = device , R = R , T = T )
293
302
raster_settings = RasterizationSettings (
294
- image_size = 512 , blur_radius = 0.0 , faces_per_pixel = 1
303
+ image_size = 512 , blur_radius = 0.0 , faces_per_pixel = 4
295
304
)
296
305
297
306
# Init shader settings
298
307
materials = Materials (device = device )
299
- lights = PointLights (device = device )
300
- lights .location = torch .tensor ([0.0 , 0.0 , + 2.0 ], device = device )[None ]
308
+ lights_location = torch .tensor ([0.0 , 0.0 , + 2.0 ], device = device )
309
+ lights_location = lights_location [None ].expand (batch_size , - 1 )
310
+ lights = PointLights (device = device , location = lights_location )
301
311
blend_params = BlendParams (1e-4 , 1e-4 , (0 , 0 , 0 ))
302
312
303
313
# Init renderer
304
314
rasterizer = MeshRasterizer (cameras = cameras , raster_settings = raster_settings )
305
- shaders = {
306
- "phong" : HardPhongShader ,
307
- "gouraud" : HardGouraudShader ,
308
- "flat" : HardFlatShader ,
309
- }
310
- for (name , shader_init ) in shaders .items ():
311
- shader = shader_init (
315
+ shader_tests = [
316
+ ShaderTest (HardPhongShader , "phong" , "hard_phong" ),
317
+ ShaderTest (SoftPhongShader , "phong" , "soft_phong" ),
318
+ ShaderTest (HardGouraudShader , "gouraud" , "hard_gouraud" ),
319
+ ShaderTest (HardFlatShader , "flat" , "hard_flat" ),
320
+ ]
321
+ for test in shader_tests :
322
+ reference_name = test .reference_name
323
+ debug_name = test .debug_name
324
+ shader = test .shader (
312
325
lights = lights ,
313
326
cameras = cameras ,
314
327
materials = materials ,
@@ -317,14 +330,15 @@ def test_simple_sphere_batched(self):
317
330
renderer = MeshRenderer (rasterizer = rasterizer , shader = shader )
318
331
images = renderer (sphere_meshes )
319
332
image_ref = load_rgb_image (
320
- "test_simple_sphere_light_%s_%s.png" % (name , type (cameras ).__name__ ),
333
+ "test_simple_sphere_light_%s_%s.png"
334
+ % (reference_name , type (cameras ).__name__ ),
321
335
DATA_DIR ,
322
336
)
323
337
for i in range (batch_size ):
324
338
rgb = images [i , ..., :3 ].squeeze ().cpu ()
325
339
if i == 0 and DEBUG :
326
340
filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
327
- name ,
341
+ debug_name ,
328
342
type (cameras ).__name__ ,
329
343
)
330
344
Image .fromarray ((rgb .numpy () * 255 ).astype (np .uint8 )).save (
0 commit comments