Skip to content

Commit 3a6caba

Browse files
anijain2305sayakpaulgithub-actions[bot]
authored
[gguf] Refactor __torch_function__ to avoid unnecessary computation (#11551)
* [gguf] Refactor __torch_function__ to avoid unnecessary computation This helps with torch.compile compilation lantency. Avoiding unnecessary computation should also lead to a slightly improved eager latency. * Apply style fixes --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4267d8f commit 3a6caba

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/diffusers/quantizers/gguf/utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -408,29 +408,32 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
408408
def as_tensor(self):
409409
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
410410

411+
@staticmethod
412+
def _extract_quant_type(args):
413+
# When converting from original format checkpoints we often use splits, cats etc on tensors
414+
# this method ensures that the returned tensor type from those operations remains GGUFParameter
415+
# so that we preserve quant_type information
416+
for arg in args:
417+
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
418+
return arg[0].quant_type
419+
if isinstance(arg, GGUFParameter):
420+
return arg.quant_type
421+
return None
422+
411423
@classmethod
412424
def __torch_function__(cls, func, types, args=(), kwargs=None):
413425
if kwargs is None:
414426
kwargs = {}
415427

416428
result = super().__torch_function__(func, types, args, kwargs)
417429

418-
# When converting from original format checkpoints we often use splits, cats etc on tensors
419-
# this method ensures that the returned tensor type from those operations remains GGUFParameter
420-
# so that we preserve quant_type information
421-
quant_type = None
422-
for arg in args:
423-
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
424-
quant_type = arg[0].quant_type
425-
break
426-
if isinstance(arg, GGUFParameter):
427-
quant_type = arg.quant_type
428-
break
429430
if isinstance(result, torch.Tensor):
431+
quant_type = cls._extract_quant_type(args)
430432
return cls(result, quant_type=quant_type)
431433
# Handle tuples and lists
432-
elif isinstance(result, (tuple, list)):
434+
elif type(result) in (list, tuple):
433435
# Preserve the original type (tuple or list)
436+
quant_type = cls._extract_quant_type(args)
434437
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
435438
return type(result)(wrapped)
436439
else:

0 commit comments

Comments
 (0)