
Adding new ResNet50 weights #4734


Merged · 10 commits · Oct 25, 2021
29 changes: 25 additions & 4 deletions references/classification/train.py
@@ -14,6 +14,12 @@
 from torchvision.transforms.functional import InterpolationMode
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")

Comment (Contributor Author): Try to import the prototype models but without failing.
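The guarded import degrades gracefully when the prototype namespace is absent. A minimal self-contained sketch of the same pattern (the module name `nonexistent_prototype_pkg` is a deliberate placeholder that fails to import, and `build` is a hypothetical helper, not part of the script):

```python
# Optional-import pattern: bind the module if present, else a None sentinel.
try:
    import nonexistent_prototype_pkg as PM  # placeholder name, fails on purpose
except ImportError:
    PM = None


def build(name, weights=None):
    # Callers must check the sentinel before touching the optional namespace.
    if weights is not None and PM is None:
        raise ImportError("prototype module unavailable")
    return f"stable:{name}" if weights is None else f"prototype:{name}:{weights}"


print(build("resnet50"))  # stable:resnet50 -- falls back to the stable path
```

The sentinel keeps the script usable on stable installs; only the code paths that actually need the optional feature raise.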
@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
     else:
+        if not args.weights:
+            preprocessing = presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            )
+        else:
+            fn = PM.__dict__[args.model]
+            weights = PM._api.get_weight(fn, args.weights)
+            preprocessing = weights.transforms()
+
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(
-                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
-            ),
+            preprocessing,
         )
     if args.cache_dataset:
         print("Saving dataset_test to {}".format(cache_path))

Comment (Contributor Author): Which preprocessing we will use depends on whether weights are defined.

Comment (Contributor Author): Having a definition of the weights means we will be accessing the prototype models. Those have the preprocessing attached to the weights, so we fetch them and construct the preprocessing class.
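The "preprocessing attached to the weights" idea can be illustrated with a stripped-down stand-in. All names below (`EvalPreset`, `WeightEntrySketch`, the URL) are illustrative stand-ins, not the prototype API:

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable


@dataclass
class EvalPreset:  # stand-in for an eval preset such as ImageNetEval
    crop_size: int
    resize_size: int = 256


@dataclass
class WeightEntrySketch:  # stand-in for the prototype WeightEntry
    url: str
    transforms: Callable  # zero-argument factory that builds the preset


entry = WeightEntrySketch(
    url="https://example.com/weights.pth",  # placeholder URL
    transforms=partial(EvalPreset, crop_size=224, resize_size=232),
)

# weights.transforms() in the diff does exactly this: call the stored factory.
preprocessing = entry.transforms()
print(preprocessing)  # EvalPreset(crop_size=224, resize_size=232)
```

Storing a `partial` rather than an instance means the preset is built lazily, and each weight entry carries exactly the preprocessing it was trained/evaluated with.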
@@ -206,7 +219,12 @@ def main(args):
     )
 
     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    if not args.weights:
+        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
     model.to(device)
 
     if args.distributed and args.sync_bn:
@@ -455,6 +473,9 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     return parser
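Because `--weights` defaults to `None`, its mere presence selects the prototype code path, as the `if not args.weights:` branches above rely on. A small argparse sketch of that switch (standalone, not the full parser):

```python
import argparse

# Minimal reproduction of the new flag: a None default acts as the
# "legacy pretrained path" sentinel; any string selects the prototype path.
parser = argparse.ArgumentParser()
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

args = parser.parse_args(["--weights", "ImageNet1K_RefV2"])
print(bool(args.weights))  # True  -> prototype models path

args_default = parser.parse_args([])
print(bool(args_default.weights))  # False -> legacy pretrained path
```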


6 changes: 6 additions & 0 deletions test/test_prototype_models.py
@@ -12,6 +12,12 @@ def get_available_classification_models():
     return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
 
 
+def test_get_weight():
+    fn = models.resnet50
+    weight_name = "ImageNet1K_RefV2"
+    assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
+
+
 @pytest.mark.parametrize("model_name", get_available_classification_models())
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
38 changes: 37 additions & 1 deletion torchvision/prototype/models/_api.py
@@ -1,12 +1,13 @@
 from collections import OrderedDict
 from dataclasses import dataclass, fields
 from enum import Enum
+from inspect import signature
 from typing import Any, Callable, Dict
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 
 
-__all__ = ["Weights", "WeightEntry"]
+__all__ = ["Weights", "WeightEntry", "get_weight"]
 
 
 @dataclass
@@ -74,3 +75,38 @@ def __getattr__(self, name):
             if f.name == name:
                 return object.__getattribute__(self.value, name)
         return super().__getattr__(name)
+
+
+def get_weight(fn: Callable, weight_name: str) -> Weights:
+    """
+    Gets the weight enum of a specific model builder method and weight name combination.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+        weight_name (str): The name of the weight enum entry of the specific model.
+
+    Returns:
+        Weights: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_class = None
+    if isinstance(ann, type) and issubclass(ann, Weights):
+        weights_class = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, Weights):
+                weights_class = t
+                break
+
+    if weights_class is None:
+        raise ValueError(
+            "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return weights_class.from_str(weight_name)

Comment (Contributor Author): For now I consider it a private method. We will eventually need to make it public because getting the enum class from a string is useful, but it's unclear whether we should do it by passing the model_builder and then the weight_name, or construct it via the fully qualified name.

Comment (Member, @NicolasHug, Oct 25, 2021): Sorry, I only got a chance to look at it now. Relying on the model_builder's annotation seems like a pretty involved way of retrieving the weights. Should we go simple here and just register all the weights in some sort of private _AVAILABLE_WEIGHTS dict? get_weight() would then just be a query into this private dict. (This is my only comment, the rest of the PR looks great!)

Comment (Contributor Author): @NicolasHug Thanks for looking at it. FYI, I merged after Prabhat's review so that we pass this to the FBsync, but I plan to make changes in follow-up PRs. I agree that this is involved, and that's why I haven't exposed it as public. I've added an entry at #4652 to review the mechanism and, more specifically, to sync with you on making it TorchHub friendly. One option, as you said, is to have a registration mechanism to keep track of method/weight combos and also flag the "best/latest" weights. I have purposely omitted all the versioning parts of the original RFC to allow discussions across Audio and Text to continue and see if we can adopt a common solution. But I think they are currently looking into moving toward a different direction that has no model builders, so we might be able to bring this feature sooner.
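The annotation-introspection trick at the heart of get_weight can be reproduced in a self-contained sketch. Every name suffixed `_sketch`, plus `WeightsBase`, is a stand-in invented here, not a torchvision API; the point is only to show how `inspect.signature` recovers the enum class from the `weights` parameter, including through an `Optional[...]` annotation:

```python
from enum import Enum
from inspect import signature
from typing import Optional


class WeightsBase(Enum):  # stand-in for the prototype Weights enum base
    @classmethod
    def from_str(cls, value: str):
        for v in cls:
            if v.name == value:
                return v
        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")


class ResNet50WeightsSketch(WeightsBase):  # illustrative entries only
    ImageNet1K_RefV1 = "v1"
    ImageNet1K_RefV2 = "v2"


def resnet50_sketch(weights: Optional[ResNet50WeightsSketch] = None, num_classes: int = 1000):
    return ("resnet50", weights, num_classes)


def get_weight_sketch(fn, weight_name: str):
    ann = signature(fn).parameters["weights"].annotation
    # Optional[T] is Union[T, None], so scan __args__ when ann is not a plain class.
    candidates = [ann] if isinstance(ann, type) else list(getattr(ann, "__args__", ()))
    for t in candidates:
        if isinstance(t, type) and issubclass(t, WeightsBase):
            return t.from_str(weight_name)
    raise ValueError("No Weights-typed 'weights' parameter found.")


w = get_weight_sketch(resnet50_sketch, "ImageNet1K_RefV2")
print(w is ResNet50WeightsSketch.ImageNet1K_RefV2)  # True
```

The registry alternative raised in the review would replace the signature inspection with a lookup into a prepopulated dict mapping builder functions to their enum classes, trading introspection cleverness for an explicit registration step.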
8 changes: 4 additions & 4 deletions torchvision/prototype/models/resnet.py
@@ -92,13 +92,13 @@ class ResNet50Weights(Weights):
         },
     )
     ImageNet1K_RefV2 = WeightEntry(
-        url="https://download.pytorch.org/models/resnet50-tmp.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
         meta={
             **_common_meta,
             "recipe": "https://github.com/pytorch/vision/issues/3995",
-            "acc@1": 80.352,
-            "acc@5": 95.148,
+            "acc@1": 80.674,
+            "acc@5": 95.166,
         },
     )
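The only preprocessing change for the new weights is resize_size=232: the shorter side is resized to 232 (keeping aspect ratio) before the 224x224 center crop, rather than the default 256. A stdlib sketch of that geometry, under the assumed resize-shorter-side-then-center-crop semantics of the eval preset (`eval_geometry` is a hypothetical helper, not torchvision code):

```python
def eval_geometry(h, w, resize_size=232, crop_size=224):
    # Resize so the shorter side equals resize_size, preserving aspect ratio,
    # then take a central crop_size x crop_size crop from the resized image.
    scale = resize_size / min(h, w)
    rh, rw = round(h * scale), round(w * scale)
    top, left = (rh - crop_size) // 2, (rw - crop_size) // 2
    return (rh, rw), (top, left, crop_size, crop_size)


# A 480x640 image is resized to 232x309, then cropped to 224x224 at (4, 42).
print(eval_geometry(480, 640))  # ((232, 309), (4, 42, 224, 224))
```

A smaller resize-to-crop ratio keeps more of the object in frame at eval time, which is part of how the RefV2 recipe lifts acc@1 from 80.352 to 80.674 without changing the architecture.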
