Adding new ResNet50 weights #4734
Changes from all commits: be58403, c1402ec, b6e55b6, 9e76d10, d2ebcdc, 766789e, a2f94d5, 21ec231, 4702aee, 1a72a2f
references/classification/train.py

```diff
@@ -14,6 +14,12 @@
 from torchvision.transforms.functional import InterpolationMode


+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
```

Review comment: Try to import the prototype models, but without failing.
```diff
@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
     else:
+        if not args.weights:
+            preprocessing = presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            )
+        else:
+            fn = PM.__dict__[args.model]
+            weights = PM._api.get_weight(fn, args.weights)
+            preprocessing = weights.transforms()
+
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(
-                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
-            ),
+            preprocessing,
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
```

Review comment: Which preprocessing we will use depends on whether weights are defined.

Review comment: Having a definition of the weights means we will be accessing the prototype models. Those have the preprocessing attached to the weights, so we fetch them and construct the preprocessing class.
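To illustrate the weights-driven path on its own, here is a minimal sketch using only the calls that appear in the diff above. It assumes a torchvision nightly that ships the prototype API; the weight entry name is a hypothetical placeholder, not a name confirmed by this PR.

```python
from torchvision.prototype import models as PM

model_name = "resnet50"
weight_name = "ImageNet1K_RefV2"  # hypothetical entry name, for illustration only

fn = PM.__dict__[model_name]                   # builder function looked up by name
weights = PM._api.get_weight(fn, weight_name)  # resolve the enum entry from its string name
preprocessing = weights.transforms()           # eval transforms ship attached to the weights
```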
```diff
@@ -206,7 +219,12 @@ def main(args):
     )

     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    if not args.weights:
+        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
     model.to(device)

     if args.distributed and args.sync_bn:
```
```diff
@@ -455,6 +473,9 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )

+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     return parser
```
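For reference, an invocation of the script with the new flag might look as follows. This is a hypothetical usage sketch: the weight entry name is a placeholder, while --model and --test-only are the script's existing flags.

```bash
python train.py --model resnet50 --test-only --weights ImageNet1K_RefV2
```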
torchvision/prototype/models/_api.py

```diff
@@ -1,12 +1,13 @@
 from collections import OrderedDict
 from dataclasses import dataclass, fields
 from enum import Enum
+from inspect import signature
 from typing import Any, Callable, Dict

 from ..._internally_replaced_utils import load_state_dict_from_url


-__all__ = ["Weights", "WeightEntry"]
+__all__ = ["Weights", "WeightEntry", "get_weight"]


 @dataclass
```
```diff
@@ -74,3 +75,38 @@ def __getattr__(self, name):
         if f.name == name:
             return object.__getattribute__(self.value, name)
         return super().__getattr__(name)
+
+
+def get_weight(fn: Callable, weight_name: str) -> Weights:
+    """
+    Gets the weight enum of a specific model builder method and weight name combination.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+        weight_name (str): The name of the weight enum entry of the specific model.
+
+    Returns:
+        Weights: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_class = None
+    if isinstance(ann, type) and issubclass(ann, Weights):
+        weights_class = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, Weights):
+                weights_class = t
+                break
+
+    if weights_class is None:
+        raise ValueError(
+            "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return weights_class.from_str(weight_name)
```

Review comment (author): For now I consider it a private method. We will eventually need to make it public, because getting the enum class from a string is useful, but it's unclear whether we should do it by passing the model_builder and then the weight_name, or construct it via the fully qualified name.

Review comment (reviewer): Sorry, I only got a chance to look at it now. Relying on the model_builder's annotation seems like a pretty involved way of retrieving the weights. Should we go simple here and just register all the weights in some sort of private registry? (This is my only comment, the rest of the PR looks great!)

Review comment (author): @NicolasHug Thanks for looking at it. FYI, I merged after Prabhat's review so that we pass this to the FBsync, but I plan to make changes in follow-up PRs. I agree that this is involved, and that's why I haven't exposed it as public. I've added an entry at #4652 to review the mechanism and, more specifically, to sync with you on making it TorchHub-friendly. One option, as you said, is to have a registration mechanism like the one proposed here to keep track of method/weight combos and also flag the "best/latest" weights. I have on purpose omitted all the versioning parts of the original RFC to allow discussions across Audio and Text to continue and see if we can adopt a common solution. But I think they are currently looking into moving in a different direction that has no model builders, so we might be able to bring this feature sooner.
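The annotation-driven lookup in get_weight can be demonstrated in isolation. Below is a self-contained sketch using stand-in names (FakeWeights and resnet_builder are illustrative, not torchvision's); it uses typing.get_args, which the TODO in the diff points to for Python >= 3.8.

```python
from enum import Enum
from inspect import signature
from typing import Optional, get_args


class FakeWeights(Enum):  # stand-in for a torchvision weights enum
    V1 = "v1"


def resnet_builder(weights: Optional[FakeWeights] = None):  # illustrative builder
    return weights


# Optional[FakeWeights] is really Union[FakeWeights, None], so the enum class
# has to be unwrapped from the annotation's type arguments.
ann = signature(resnet_builder).parameters["weights"].annotation
enum_cls = next(t for t in get_args(ann) if isinstance(t, type) and issubclass(t, Enum))
assert enum_cls is FakeWeights
print(enum_cls["V1"])  # prints FakeWeights.V1, roughly what from_str resolves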
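For contrast, here is a minimal sketch of the simpler registry approach floated in the review thread. The names _WEIGHTS_REGISTRY and register_weights are hypothetical and not part of this PR; the point is only that a string-to-enum lookup needs no builder introspection.

```python
from enum import Enum
from typing import Dict

# Hypothetical private registry keyed by "EnumClassName.ENTRY_NAME".
_WEIGHTS_REGISTRY: Dict[str, Enum] = {}


def register_weights(weights_enum):
    # Record every entry of the enum so it can later be resolved from a string
    # without inspecting any builder's annotations.
    for entry in weights_enum:
        _WEIGHTS_REGISTRY[f"{weights_enum.__name__}.{entry.name}"] = entry
    return weights_enum


@register_weights
class FakeWeights(Enum):  # stand-in enum for illustration
    V1 = "v1"


assert _WEIGHTS_REGISTRY["FakeWeights.V1"] is FakeWeights.V1
```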