Skip to content

Commit cac4e22

Browse files
authored
Make get_model_builder public (#6560)
1 parent a67cc87 commit cac4e22

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

test/test_extended_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ def test_get_model(name, model_class):
2929
assert isinstance(models.get_model(name), model_class)
3030

3131

32+
@pytest.mark.parametrize(
33+
"name, model_fn",
34+
[
35+
("resnet50", models.resnet50),
36+
("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
37+
("raft_large", models.optical_flow.raft_large),
38+
("quantized_resnet50", models.quantization.resnet50),
39+
("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
40+
("mvit_v1_b", models.video.mvit_v1_b),
41+
],
42+
)
43+
def test_get_model_builder(name, model_fn):
44+
assert models.get_model_builder(name) == model_fn
45+
46+
3247
@pytest.mark.parametrize(
3348
"name, weight",
3449
[

test/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
from _utils_internal import get_relative_path
1818
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
1919
from torchvision import models
20-
from torchvision.models._api import find_model, list_models
20+
from torchvision.models import get_model_builder, list_models
2121

2222

2323
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
2424
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
2525

2626

2727
def list_model_fns(module):
28-
return [find_model(name) for name in list_models(module)]
28+
return [get_model_builder(name) for name in list_models(module)]
2929

3030

3131
@pytest.fixture

torchvision/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
from .vision_transformer import *
1515
from .swin_transformer import *
1616
from . import detection, optical_flow, quantization, segmentation, video
17-
from ._api import get_model, get_model_weights, get_weight, list_models
17+
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models

torchvision/models/_api.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .._internally_replaced_utils import load_state_dict_from_url
1414

1515

16-
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
16+
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
1717

1818

1919
@dataclass
@@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W:
127127
Returns:
128128
weights_enum (W): The weights enum class associated with the model.
129129
"""
130-
model = find_model(name) if isinstance(name, str) else name
130+
model = get_model_builder(name) if isinstance(name, str) else name
131131
return cast(W, _get_enum_from_fn(model))
132132

133133

@@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
199199
return sorted(models)
200200

201201

202-
def find_model(name: str) -> Callable[..., M]:
202+
def get_model_builder(name: str) -> Callable[..., M]:
203+
"""
204+
Gets the model name and returns the model builder method.
205+
206+
.. betastatus:: function
207+
208+
Args:
209+
name (str): The name under which the model is registered.
210+
211+
Returns:
212+
fn (Callable): The model builder method.
213+
"""
203214
name = name.lower()
204215
try:
205216
fn = BUILTIN_MODELS[name]
@@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M:
221232
Returns:
222233
model (nn.Module): The initialized model.
223234
"""
224-
fn = find_model(name)
235+
fn = get_model_builder(name)
225236
return fn(**config)

0 commit comments

Comments
 (0)