From 5eaf0575cf53007faaa58f534cd7f0de61a04a5b Mon Sep 17 00:00:00 2001 From: Abhinav Agarwalla Date: Wed, 31 Jan 2024 20:16:35 -0500 Subject: [PATCH 1/2] Adding histogram quantizer --- .../quantization/utils/quantization_scheme.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py index b3ef1807227..6ba0b085835 100644 --- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py @@ -121,7 +121,7 @@ def get_observer(self) -> "torch.quantization.FakeQuantize": @validator("strategy") def validate_strategy(cls, value): - valid_scopes = ["tensor", "channel"] + valid_scopes = ["tensor", "channel", "histogram"] if value not in valid_scopes: raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}") return value @@ -307,6 +307,14 @@ def get_observer( qscheme=qscheme, reduce_range=reduce_range, ) + elif strategy == "histrogram": + qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine + observer_cls = torch_quantization.HistogramObserver + observer_kwargs = dict( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + ) else: # default to tensor strategy qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine observer_cls = torch_quantization.MovingAverageMinMaxObserver From 86e3764da1b8f271fcb1d92f79edb8f0cad4c883 Mon Sep 17 00:00:00 2001 From: Abhinav Agarwalla Date: Thu, 8 Feb 2024 17:34:58 -0500 Subject: [PATCH 2/2] Fixing major typo --- .../modifiers/quantization/utils/quantization_scheme.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py index 6ba0b085835..4936fa81385 100644 --- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py @@ -307,8 +307,8 @@ def get_observer( qscheme=qscheme, reduce_range=reduce_range, ) - elif strategy == "histrogram": - qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine + elif strategy == "histogram": + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine observer_cls = torch_quantization.HistogramObserver observer_kwargs = dict( dtype=dtype,