Commit 477ddb6

Add option to move param to device before quantization (#699)
* add device argument to quantize_()
* fix test
* add test
* remove print
* fix
* remove timing check
1 parent b523f9f · commit 477ddb6
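For context, a minimal usage sketch of the new argument (the stand-in model below is an illustrative assumption, not from the commit):

```python
# Sketch only: assumes quantize_ and int8_weight_only are importable from
# torchao.quantization, as in the test this commit adds.
import torch
from torchao.quantization import quantize_, int8_weight_only

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

# Previous pattern: move the whole full-precision model to the GPU, then quantize.
#   quantize_(m.to(device="cuda"), int8_weight_only())
# New option: keep the model on CPU and let quantize_ move each module to the
# device right before its weight is quantized, lowering peak GPU memory.
quantize_(m, int8_weight_only(), device="cuda")
```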

File tree

2 files changed (+41 −2)


test/quantization/test_quant_api.py (+24)

```diff
@@ -54,6 +54,7 @@
 from torchao.utils import unwrap_tensor_subclass
 import copy
 import tempfile
+import gc
 from torch.testing._internal.common_utils import TestCase
 
 
@@ -680,6 +681,29 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
         res = m_copy(*example_inputs)
         self.assertEqual(res, ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_model_streaming(self):
+        def reset_memory():
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m.to(device="cuda"), int8_weight_only())
+        memory_baseline = torch.cuda.max_memory_allocated()
+
+        del m
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m, int8_weight_only(), device="cuda")
+        memory_streaming = torch.cuda.max_memory_allocated()
+
+        for param in m.parameters():
+            assert param.is_cuda
+        self.assertLess(memory_streaming, memory_baseline)
+
 
 if __name__ == "__main__":
     unittest.main()
```
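The new test compares peak CUDA memory between the two call patterns. The measurement idiom it relies on can be reproduced standalone; a minimal sketch (the workload below is an assumption for illustration):

```python
import gc
import torch

def reset_memory() -> None:
    # Drop dead Python references, release cached CUDA blocks, and zero the
    # peak counter so max_memory_allocated() only reflects what runs next.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

reset_memory()
m = torch.nn.Linear(4096, 4096).to(device="cuda")  # illustrative workload
peak = torch.cuda.max_memory_allocated()
print(f"peak CUDA memory: {peak / 1e6:.1f} MB")
```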

torchao/quantization/quant_api.py (+17 −2)

```diff
@@ -161,6 +161,7 @@ def _replace_with_custom_fn_if_matches_filter(
     replacement_fn,
     filter_fn,
     cur_fqn="",
+    device=None,
 ) -> None:
     """
     Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -171,20 +172,25 @@ def _replace_with_custom_fn_if_matches_filter(
         replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
         filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
         cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
+        device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.
 
     Returns:
         None
     """
     if filter_fn(model, cur_fqn[:-1]):
+        if device is not None:
+            model.to(device=device)  # move to device before quantization
         model = replacement_fn(model)
         return model
     else:
         for name, child in model.named_children():
             new_child = _replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
+                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
             )
             if new_child is not child:
                 setattr(model, name, new_child)
+        if device is not None:
+            model.to(device=device)  # move parent module to device
         return model
 
 
@@ -269,7 +275,13 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(
+    model: torch.nn.Module,
+    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    set_inductor_config: bool = True,
+    device: Optional[torch.types.Device] = None,
+):
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
 
     Args:
@@ -278,6 +290,8 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.
         filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
             the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
+        device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
+            Defaults to None (do not change device).
 
     Example::
@@ -329,6 +343,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         model,
         apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
+        device=device,
     )
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
```
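Distilled from the patch above, the control flow is: a module that matches `filter_fn` is moved to `device` immediately before `replacement_fn` quantizes it, and each parent is moved afterwards so parameters and buffers outside matched children also end up on the device. A simplified sketch of that recursion (not the torchao source; `replace_with` is a hypothetical name):

```python
import torch

def replace_with(model, replacement_fn, filter_fn, cur_fqn="", device=None):
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move just this module before replacing it
        return replacement_fn(model)
    for name, child in model.named_children():
        new_child = replace_with(
            child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
        )
        if new_child is not child:
            setattr(model, name, new_child)
    if device is not None:
        model.to(device=device)  # sweep up anything not covered by matched children
    return model
```

Because each matched module is quantized before the next one is moved, the device holds only the already-quantized modules plus roughly one full-precision module at a time, which is why the test's `assertLess` on peak memory passes.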
