
Adding multiweight support to Quantized InceptionV3 #4850


Merged
131 changes: 65 additions & 66 deletions torchvision/models/quantization/inception.py
@@ -24,72 +24,6 @@
}


def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableInception3":
Contributor Author: Moving the builder to the bottom of the file to use proper typing (the return annotation can then reference QuantizableInception3 directly instead of the string forward reference "QuantizableInception3").


r"""Inception v3 model architecture from
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

.. note::
**Important**: In contrast to the other models the inception_v3 expects tensors with a size of
N x 3 x 299 x 299, so ensure your images are sized accordingly.

Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
aux_logits (bool): If True, add an auxiliary branch that can improve training.
Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" in kwargs:
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
else:
original_aux_logits = False

model = QuantizableInception3(**kwargs)
_replace_relu(model)

if quantize:
# TODO use pretrained as a string to specify the backend
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]

if pretrained:
if quantize:
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model_url = quant_model_urls["inception_v3_google_" + backend]
else:
model_url = inception_module.model_urls["inception_v3_google"]

state_dict = load_state_dict_from_url(model_url, progress=progress)

model.load_state_dict(state_dict)

if not quantize:
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
return model


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
@@ -237,3 +171,68 @@ def fuse_model(self) -> None:
        for m in self.modules():
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model()
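For illustration (not part of the diff): fuse_model() is the hook that torchvision's quantize_model() calls before observers are inserted. A minimal sketch of the eager-mode post-training-quantization flow it enables, assuming the stock torch.quantization APIs:

```python
import torch
from torchvision.models.quantization import inception_v3

# Build the float QuantizableInception3 (no download, no quantization yet).
model = inception_v3(pretrained=False, quantize=False)
model.eval()

model.fuse_model()  # fuses Conv2d + BatchNorm2d + ReLU inside each QuantizableBasicConv2d
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
# ... run representative calibration batches through the model here ...
torch.quantization.convert(model, inplace=True)
```

This is roughly what quantize_model(model, backend) performs in one call in the builder below.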


def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
Contributor Author: No review, just copy-paste.

r"""Inception v3 model architecture from
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

.. note::
**Important**: In contrast to the other models the inception_v3 expects tensors with a size of
N x 3 x 299 x 299, so ensure your images are sized accordingly.

Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
aux_logits (bool): If True, add an auxiliary branch that can improve training.
Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" in kwargs:
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
else:
original_aux_logits = False

model = QuantizableInception3(**kwargs)
_replace_relu(model)

if quantize:
# TODO use pretrained as a string to specify the backend
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]

if pretrained:
if quantize:
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model_url = quant_model_urls["inception_v3_google_" + backend]
else:
model_url = inception_module.model_urls["inception_v3_google"]

state_dict = load_state_dict_from_url(model_url, progress=progress)

model.load_state_dict(state_dict)

if not quantize:
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
return model
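For context (not part of the diff), a minimal usage sketch of the builder above, assuming a torchvision build that ships the fbgemm-quantized checkpoint; as the docstring notes, the quantized model is inference-only and CPU-only:

```python
import torch
from torchvision.models.quantization import inception_v3

model = inception_v3(pretrained=True, quantize=True)  # 8-bit weights, fbgemm backend
model.eval()

batch = torch.rand(1, 3, 299, 299)  # inception_v3 expects N x 3 x 299 x 299
with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```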
10 changes: 5 additions & 5 deletions torchvision/prototype/models/inception.py
@@ -10,10 +10,10 @@
from ._meta import _IMAGENET_CATEGORIES


__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]


-class Inception3Weights(Weights):
+class InceptionV3Weights(Weights):
    ImageNet1K_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
@@ -28,11 +28,11 @@ class Inception3Weights(Weights):
    )


-def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
Contributor Author: The naming is inconsistent: the model class is called Inception3, while the model builder is inception_v3, so it's unclear how we should name the weights. Following the convention from the ResNets, this would have been Inception_V3, which is ugly. I propose not to spend more time on this now and to tackle it at #4652; I've already added a reference for this issue.

Contributor: Agree! Let's tackle it in the cleanup.

if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = Inception3Weights.verify(weights)
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = InceptionV3Weights.verify(weights)

original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None:
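As an illustration (not part of the diff), the deprecation shim above maps the legacy pretrained flag onto the new weights argument. A sketch of both call styles, assuming the prototype namespace re-exports these symbols as the __init__ change below suggests:

```python
import warnings

from torchvision.prototype.models import InceptionV3Weights, inception_v3

# New style: pass the weights enum explicitly.
model = inception_v3(weights=InceptionV3Weights.ImageNet1K_TFV1)

# Old style: still accepted, but warns and resolves to the same weight entry.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = inception_v3(pretrained=True)
print(caught[0].message)  # The argument pretrained is deprecated, please use weights instead.
```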
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/__init__.py
@@ -1,2 +1,3 @@
from .googlenet import *
+from .inception import *
from .resnet import *
87 changes: 87 additions & 0 deletions torchvision/prototype/models/quantization/inception.py
@@ -0,0 +1,87 @@
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.inception import (
    QuantizableInception3,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..inception import InceptionV3Weights


__all__ = [
    "QuantizableInception3",
    "QuantizedInceptionV3Weights",
    "inception_v3",
]


class QuantizedInceptionV3Weights(Weights):
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            "size": (299, 299),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "fbgemm",
            "quantization": "ptq",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
            "acc@1": 77.176,
            "acc@5": 93.354,
        },
    )


def inception_v3(
    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedInceptionV3Weights.verify(weights)
    else:
        weights = InceptionV3Weights.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        kwargs["aux_logits"] = True
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        if quantize and not original_aux_logits:
Contributor Author: Special handling needed here. See the original implementation.

            model.aux_logits = False
            model.AuxLogits = None
        model.load_state_dict(weights.state_dict(progress=progress))
        if not quantize and not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
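For context (not part of the diff), a minimal sketch of the new multi-weight builder, assuming the prototype Weights API exposes transforms() for the preset declared in the WeightEntry above:

```python
import torch

from torchvision.prototype.models.quantization import QuantizedInceptionV3Weights, inception_v3

weights = QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1
model = inception_v3(weights=weights, quantize=True)
model.eval()

# The weight entry carries its own preprocessing recipe (resize to 342, crop to 299).
preprocess = weights.transforms()
img = torch.rand(3, 400, 400)  # stand-in for a real image tensor
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    logits = model(batch)
```

Passing an InceptionV3Weights entry with quantize=False returns the float model instead, matching the Union type on the weights parameter.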