Skip to content

Commit 9e78871

Browse files
authored
Expose get_weight to Torch Hub (#6026)
* Prefixing `_get_enum_from_fn` with underscore * Exposing `get_weight` to Torch Hub
1 parent 8e5844f commit 9e78871

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

hubconf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Optional list of dependencies required by the package
22
dependencies = ["torch"]
33

4+
from torchvision.models import get_weight
45
from torchvision.models.alexnet import alexnet
56
from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
67
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161

torchvision/models/_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum:
107107
return weights_enum.from_str(value_name)
108108

109109

110-
def get_enum_from_fn(fn: Callable) -> WeightsEnum:
110+
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
111111
"""
112112
Internal method that gets the weight enum of a specific model builder method.
113113
Might be removed after the handle_legacy_interface is removed.

torchvision/models/detection/backbone_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
77

88
from .. import mobilenet, resnet
9-
from .._api import WeightsEnum, get_enum_from_fn
9+
from .._api import WeightsEnum, _get_enum_from_fn
1010
from .._utils import IntermediateLayerGetter, handle_legacy_interface
1111

1212

@@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
6262
@handle_legacy_interface(
6363
weights=(
6464
"pretrained",
65-
lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
65+
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
6666
),
6767
)
6868
def resnet_fpn_backbone(
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
177177
@handle_legacy_interface(
178178
weights=(
179179
"pretrained",
180-
lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
180+
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
181181
),
182182
)
183183
def mobilenet_backbone(

0 commit comments

Comments
 (0)