@@ -1580,6 +1580,39 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
 
+    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @require_torch_accelerator
+    @torch.no_grad()
+    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
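+        # Skip models that opt out of group offloading.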
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
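+        # Sanity-check a plain forward pass before enabling any offloading.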
+        model.to(torch_device)
+        model.eval()
+        _ = model(**inputs_dict)[0]
+
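+        # Rebuild the model, then enable group offloading together with layerwise
+        # casting; the forward pass should still run without error.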
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        storage_dtype, compute_dtype = torch.float16, torch.float32
+        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+        model = self.model_class(**init_dict)
+        model.eval()
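+        # block_level offloading needs a group size; leaf_level offloads per module.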
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        model.enable_group_offload(
+            torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
+        )
+        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+        _ = model(**inputs_dict)[0]
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)