diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py index b3ef1807227..4936fa81385 100644 --- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py @@ -121,7 +121,7 @@ def get_observer(self) -> "torch.quantization.FakeQuantize": @validator("strategy") def validate_strategy(cls, value): - valid_scopes = ["tensor", "channel"] + valid_scopes = ["tensor", "channel", "histogram"] if value not in valid_scopes: raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}") return value @@ -307,6 +307,14 @@ def get_observer( qscheme=qscheme, reduce_range=reduce_range, ) + elif strategy == "histogram": + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine + observer_cls = torch_quantization.HistogramObserver + observer_kwargs = dict( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + ) else: # default to tensor strategy qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine observer_cls = torch_quantization.MovingAverageMinMaxObserver