-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Adding multiweight support to Quantized InceptionV3 #4850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bff540c
b2c71c9
0ce4670
aa1adfb
c6245f7
a11b20d
bc51d53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,72 +24,6 @@ | |
} | ||
|
||
|
||
def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableInception3":
    r"""Build a quantizable Inception v3 network, as introduced in
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: unlike the other models, inception_v3 expects input tensors of size
        N x 3 x 299 x 299, so make sure your images are sized accordingly.

    With ``quantize=True`` the returned model is quantized with 8 bit weights.
    Quantized models only support inference and run on CPUs; GPU inference is
    not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*. When ``pretrained`` is True and this argument is not passed,
            the auxiliary branch is removed from the returned model.
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (set to True when ``pretrained`` is True
            and the argument is not passed).
    """
    if pretrained:
        kwargs.setdefault("transform_input", True)
        # Remember what the caller asked for; an explicitly passed value is
        # overridden with True so the model is built with the auxiliary branch.
        original_aux_logits = kwargs.get("aux_logits", False)
        if "aux_logits" in kwargs:
            kwargs["aux_logits"] = True

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if not pretrained:
        return model

    strip_aux = not original_aux_logits
    if quantize:
        # Quantized path: the auxiliary branch is dropped *before* loading the weights.
        if strip_aux:
            model.aux_logits = False
            model.AuxLogits = None
        model_url = quant_model_urls["inception_v3_google_" + backend]
    else:
        model_url = inception_module.model_urls["inception_v3_google"]

    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))

    # Float path: the auxiliary branch is dropped only *after* the weights are loaded.
    if not quantize and strip_aux:
        model.aux_logits = False
        model.AuxLogits = None
    return model
|
||
|
||
class QuantizableBasicConv2d(inception_module.BasicConv2d): | ||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
@@ -237,3 +171,68 @@ def fuse_model(self) -> None: | |
for m in self.modules(): | ||
if type(m) is QuantizableBasicConv2d: | ||
m.fuse_model() | ||
|
||
|
||
def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    # Review note (from the pull request): "No review, just copy-paste."
    r"""Build a quantizable Inception v3 network, as introduced in
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: unlike the other models, inception_v3 expects input tensors of size
        N x 3 x 299 x 299, so make sure your images are sized accordingly.

    With ``quantize=True`` the returned model is quantized with 8 bit weights.
    Quantized models only support inference and run on CPUs; GPU inference is
    not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*. When ``pretrained`` is True and this argument is not passed,
            the auxiliary branch is removed from the returned model.
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (set to True when ``pretrained`` is True
            and the argument is not passed).
    """
    if pretrained:
        kwargs.setdefault("transform_input", True)
        # Keep the caller's request; an explicitly passed value is replaced by
        # True so the network is constructed with its auxiliary branch.
        original_aux_logits = kwargs.get("aux_logits", False)
        if "aux_logits" in kwargs:
            kwargs["aux_logits"] = True

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if not pretrained:
        return model

    strip_aux = not original_aux_logits
    if quantize:
        # Quantized path: remove the auxiliary branch *before* loading the weights.
        if strip_aux:
            model.aux_logits = False
            model.AuxLogits = None
        model_url = quant_model_urls["inception_v3_google_" + backend]
    else:
        model_url = inception_module.model_urls["inception_v3_google"]

    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))

    # Float path: remove the auxiliary branch only *after* the weights are loaded.
    if not quantize and strip_aux:
        model.aux_logits = False
        model.AuxLogits = None
    return model
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,10 +10,10 @@ | |
from ._meta import _IMAGENET_CATEGORIES | ||
|
||
|
||
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"] | ||
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] | ||
|
||
|
||
class Inception3Weights(Weights): | ||
class InceptionV3Weights(Weights): | ||
ImageNet1K_TFV1 = WeightEntry( | ||
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", | ||
transforms=partial(ImageNetEval, crop_size=299, resize_size=342), | ||
|
@@ -28,11 +28,11 @@ class Inception3Weights(Weights): | |
) | ||
|
||
|
||
def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: | ||
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Review comment: The naming is inconsistent. The model class is called `Inception3` [the remainder of this sentence was lost when the page was extracted].
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reply: Agree! Let's tackle it in the cleanup. |
||
if "pretrained" in kwargs: | ||
warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None | ||
weights = Inception3Weights.verify(weights) | ||
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None | ||
weights = InceptionV3Weights.verify(weights) | ||
|
||
original_aux_logits = kwargs.get("aux_logits", True) | ||
if weights is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .googlenet import * | ||
from .inception import * | ||
from .resnet import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import warnings | ||
from functools import partial | ||
from typing import Any, Optional, Union | ||
|
||
from torchvision.transforms.functional import InterpolationMode | ||
|
||
from ....models.quantization.inception import ( | ||
QuantizableInception3, | ||
_replace_relu, | ||
quantize_model, | ||
) | ||
from ...transforms.presets import ImageNetEval | ||
from .._api import Weights, WeightEntry | ||
from .._meta import _IMAGENET_CATEGORIES | ||
from ..inception import InceptionV3Weights | ||
|
||
|
||
__all__ = [ | ||
"QuantizableInception3", | ||
"QuantizedInceptionV3Weights", | ||
"inception_v3", | ||
] | ||
|
||
|
||
class QuantizedInceptionV3Weights(Weights):
    """Weight entries that can be passed to the quantized ``inception_v3`` builder."""

    # Single entry: the "fbgemm" Inception v3 checkpoint hosted under models/quantized/.
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
        # Preprocessing preset, pre-configured with a 299 crop and a 342 resize.
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            "size": (299, 299),
            # The builder uses len() of this value as ``num_classes``.
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            # Read by the builder and forwarded to ``quantize_model``.
            "backend": "fbgemm",
            # Presumably "post-training quantization" (cf. the recipe URL anchor) - confirm.
            "quantization": "ptq",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            # Presumably the float weights this entry was quantized from - confirm.
            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
            # Presumably top-1 / top-5 accuracy in percent - confirm.
            "acc@1": 77.176,
            "acc@5": 93.354,
        },
    )
|
||
|
||
def inception_v3(
    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    """Build a quantizable Inception v3 model, optionally quantized and with loaded weights.

    Args:
        weights: Weights entry to load, or None to leave the model uninitialised.
            It is verified against ``QuantizedInceptionV3Weights`` when ``quantize``
            is True and against ``InceptionV3Weights`` otherwise.
        progress: Forwarded to ``weights.state_dict`` when weights are loaded.
        quantize: If True, ``quantize_model`` is applied to the model before any
            weights are loaded.
        **kwargs: Forwarded to ``QuantizableInception3``. Special keys:
            ``pretrained`` (deprecated; emits a warning and selects the default
            weights entry for the chosen mode when truthy, or None when falsy),
            ``backend`` (removed before construction; defaults to ``"fbgemm"`` and
            is overridden by the weights' ``meta["backend"]`` when present) and
            ``aux_logits`` (when weights are given the model is always built with
            the auxiliary branch, which is removed afterwards unless the caller
            passed a truthy ``aux_logits``).

    Returns:
        The constructed ``QuantizableInception3`` instance.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        use_pretrained = kwargs.pop("pretrained")
        if not use_pretrained:
            weights = None
        elif quantize:
            weights = QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1
        else:
            weights = InceptionV3Weights.ImageNet1K_TFV1

    weights_enum = QuantizedInceptionV3Weights if quantize else InceptionV3Weights
    weights = weights_enum.verify(weights)

    # Captured before kwargs is overwritten below, so the caller's wish survives.
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        kwargs.setdefault("transform_input", True)
        kwargs["aux_logits"] = True
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    # "backend" is consumed here and never reaches the model constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is None:
        return model

    drop_aux = not original_aux_logits
    # Review note (from the pull request): "Special handling needed here. See
    # original implementation." In quantized mode the auxiliary branch is removed
    # *before* the state dict is loaded; in float mode it is removed *after*.
    # Presumably the quantized checkpoint carries no auxiliary weights - confirm.
    if quantize and drop_aux:
        model.aux_logits = False
        model.AuxLogits = None
    model.load_state_dict(weights.state_dict(progress=progress))
    if not quantize and drop_aux:
        model.aux_logits = False
        model.AuxLogits = None

    return model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moving the builder to the bottom of the file so that proper typing can be used.