@@ -161,6 +161,7 @@ def _replace_with_custom_fn_if_matches_filter(
161
161
replacement_fn ,
162
162
filter_fn ,
163
163
cur_fqn = "" ,
164
+ device = None ,
164
165
) -> None :
165
166
"""
166
167
Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -171,20 +172,25 @@ def _replace_with_custom_fn_if_matches_filter(
171
172
replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
172
173
filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
173
174
cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
175
+ device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.
174
176
175
177
Returns:
176
178
None
177
179
"""
178
180
if filter_fn (model , cur_fqn [:- 1 ]):
181
+ if device is not None :
182
+ model .to (device = device ) # move to device before quantization
179
183
model = replacement_fn (model )
180
184
return model
181
185
else :
182
186
for name , child in model .named_children ():
183
187
new_child = _replace_with_custom_fn_if_matches_filter (
184
- child , replacement_fn , filter_fn , f"{ cur_fqn } { name } ."
188
+ child , replacement_fn , filter_fn , f"{ cur_fqn } { name } ." , device
185
189
)
186
190
if new_child is not child :
187
191
setattr (model , name , new_child )
192
+ if device is not None :
193
+ model .to (device = device ) # move parent module to device
188
194
return model
189
195
190
196
@@ -269,7 +275,13 @@ def insert_subclass(lin):
269
275
270
276
return insert_subclass
271
277
272
- def quantize_ (model : torch .nn .Module , apply_tensor_subclass : Callable [[torch .nn .Module ], torch .nn .Module ], filter_fn : Optional [Callable [[torch .nn .Module , str ], bool ]]= None , set_inductor_config : bool = True ):
278
+ def quantize_ (
279
+ model : torch .nn .Module ,
280
+ apply_tensor_subclass : Callable [[torch .nn .Module ], torch .nn .Module ],
281
+ filter_fn : Optional [Callable [[torch .nn .Module , str ], bool ]] = None ,
282
+ set_inductor_config : bool = True ,
283
+ device : Optional [torch .types .Device ] = None ,
284
+ ):
273
285
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
274
286
275
287
Args:
@@ -278,6 +290,8 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.
278
290
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
279
291
the weight of the module
280
292
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
293
+ device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
294
+ Defaults to None (do not change device).
281
295
282
296
Example::
283
297
@@ -329,6 +343,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
329
343
model ,
330
344
apply_tensor_subclass ,
331
345
_is_linear if filter_fn is None else filter_fn ,
346
+ device = device ,
332
347
)
333
348
334
349
def _int8_asymm_per_token_quant (x : torch .Tensor ) -> torch .Tensor :
0 commit comments