diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 55259bb150d..56710cac2c1 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -29,6 +29,21 @@ def test_get_model(name, model_class): assert isinstance(models.get_model(name), model_class) +@pytest.mark.parametrize( + "name, model_fn", + [ + ("resnet50", models.resnet50), + ("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2), + ("raft_large", models.optical_flow.raft_large), + ("quantized_resnet50", models.quantization.resnet50), + ("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large), + ("mvit_v1_b", models.video.mvit_v1_b), + ], +) +def test_get_model_builder(name, model_fn): + assert models.get_model_builder(name) == model_fn + + @pytest.mark.parametrize( "name, weight", [ diff --git a/test/test_models.py b/test/test_models.py index 7645dc419ff..2629c3c4dc1 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -17,7 +17,7 @@ from _utils_internal import get_relative_path from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from torchvision import models -from torchvision.models._api import find_model, list_models +from torchvision.models import get_model_builder, list_models ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" @@ -25,7 +25,7 @@ def list_model_fns(module): - return [find_model(name) for name in list_models(module)] + return [get_model_builder(name) for name in list_models(module)] @pytest.fixture diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index eb949fb3d5c..f8baa05c1f6 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_model, get_model_weights, get_weight, list_models +from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index c2886d2ed99..22073d9c28d 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"] @dataclass @@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W: Returns: weights_enum (W): The weights enum class associated with the model. """ - model = find_model(name) if isinstance(name, str) else name + model = get_model_builder(name) if isinstance(name, str) else name return cast(W, _get_enum_from_fn(model)) @@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def find_model(name: str) -> Callable[..., M]: +def get_model_builder(name: str) -> Callable[..., M]: + """ + Gets the model name and returns the model builder method. + + .. betastatus:: function + + Args: + name (str): The name under which the model is registered. + + Returns: + fn (Callable): The model builder method. + """ name = name.lower() try: fn = BUILTIN_MODELS[name] @@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M: Returns: model (nn.Module): The initialized model. """ - fn = find_model(name) + fn = get_model_builder(name) return fn(**config)