@@ -287,34 +287,36 @@ def configure_optimizers(self):
287
287
trainer .fit (model )
288
288
289
289
290
- def test_complex_nested_model ():
291
- """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
292
- directly themselves rather than exclusively their submodules containing parameters."""
290
+ class ConvBlock (nn .Module ):
291
+ def __init__ (self , in_channels , out_channels ):
292
+ super ().__init__ ()
293
+ self .conv = nn .Conv2d (in_channels , out_channels , 3 )
294
+ self .act = nn .ReLU ()
295
+ self .bn = nn .BatchNorm2d (out_channels )
293
296
294
- class ConvBlock (nn .Module ):
295
- def __init__ (self , in_channels , out_channels ):
296
- super ().__init__ ()
297
- self .conv = nn .Conv2d (in_channels , out_channels , 3 )
298
- self .act = nn .ReLU ()
299
- self .bn = nn .BatchNorm2d (out_channels )
297
+ def forward (self , x ):
298
+ x = self .conv (x )
299
+ x = self .act (x )
300
+ return self .bn (x )
300
301
301
- def forward (self , x ):
302
- x = self .conv (x )
303
- x = self .act (x )
304
- return self .bn (x )
305
302
306
- class ConvBlockParam (nn .Module ):
307
- def __init__ (self , in_channels , out_channels ):
308
- super ().__init__ ()
309
- self .module_dict = nn .ModuleDict ({"conv" : nn .Conv2d (in_channels , out_channels , 3 ), "act" : nn .ReLU ()})
310
- # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
311
- self .parent_param = nn .Parameter (torch .zeros ((1 ), dtype = torch .float ))
312
- self .bn = nn .BatchNorm2d (out_channels )
303
+ class ConvBlockParam (nn .Module ):
304
+ def __init__ (self , in_channels , out_channels ):
305
+ super ().__init__ ()
306
+ self .module_dict = nn .ModuleDict ({"conv" : nn .Conv2d (in_channels , out_channels , 3 ), "act" : nn .ReLU ()})
307
+ # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
308
+ self .parent_param = nn .Parameter (torch .zeros ((1 ), dtype = torch .float ))
309
+ self .bn = nn .BatchNorm2d (out_channels )
313
310
314
- def forward (self , x ):
315
- x = self .module_dict ["conv" ](x )
316
- x = self .module_dict ["act" ](x )
317
- return self .bn (x )
311
+ def forward (self , x ):
312
+ x = self .module_dict ["conv" ](x )
313
+ x = self .module_dict ["act" ](x )
314
+ return self .bn (x )
315
+
316
+
317
+ def test_complex_nested_model ():
318
+ """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
319
+ directly themselves rather than exclusively their submodules containing parameters."""
318
320
319
321
model = nn .Sequential (
320
322
OrderedDict (
0 commit comments