Skip to content

Commit 20379d9

Browse files
authored
[tests] add tests for combining layerwise upcasting and groupoffloading. (#11558)
* add tests for combining layerwise upcasting and groupoffloading. * feedback
1 parent 3a6caba commit 20379d9

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,34 @@ def run_forward(model):
15801580
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
15811581
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
15821582

1583+
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
1584+
@require_torch_accelerator
1585+
@torch.no_grad()
1586+
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
1587+
torch.manual_seed(0)
1588+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1589+
model = self.model_class(**init_dict)
1590+
1591+
if not getattr(model, "_supports_group_offloading", True):
1592+
return
1593+
1594+
model.to(torch_device)
1595+
model.eval()
1596+
_ = model(**inputs_dict)[0]
1597+
1598+
torch.manual_seed(0)
1599+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1600+
storage_dtype, compute_dtype = torch.float16, torch.float32
1601+
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
1602+
model = self.model_class(**init_dict)
1603+
model.eval()
1604+
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
1605+
model.enable_group_offload(
1606+
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
1607+
)
1608+
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
1609+
_ = model(**inputs_dict)[0]
1610+
15831611
def test_auto_model(self, expected_max_diff=5e-5):
15841612
if self.forward_requires_fresh_args:
15851613
model = self.model_class(**self.init_dict)

0 commit comments

Comments
 (0)