Refactor preset transforms #5562

Merged
merged 2 commits on Mar 7, 2022

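This PR renames the prototype transform presets from dataset-specific names to task-specific ones: ImageNetEval becomes ImageClassificationEval, CocoEval becomes ObjectDetectionEval, RaftEval becomes OpticalFlowEval, VocEval becomes SemanticSegmentationEval, and Kinect400Eval becomes VideoClassificationEval. A minimal before/after sketch of caller code, assuming a torchvision build that exposes the prototype namespace used by the reference scripts below:

    from torchvision import prototype

    # Before this PR the eval preset was named after the dataset it was tuned on:
    # preprocessing = prototype.transforms.ImageNetEval(crop_size=224)

    # After this PR the preset is named after the task it serves:
    preprocessing = prototype.transforms.ImageClassificationEval(crop_size=224)
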
2 changes: 1 addition & 1 deletion references/classification/train.py
@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
         weights = prototype.models.get_weight(args.weights)
         preprocessing = weights.transforms()
     else:
-        preprocessing = prototype.transforms.ImageNetEval(
+        preprocessing = prototype.transforms.ImageClassificationEval(
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )

2 changes: 1 addition & 1 deletion references/detection/train.py
@@ -57,7 +57,7 @@ def get_transform(train, args):
             weights = prototype.models.get_weight(args.weights)
             return weights.transforms()
         else:
-            return prototype.transforms.CocoEval()
+            return prototype.transforms.ObjectDetectionEval()


 def get_args_parser(add_help=True):

2 changes: 1 addition & 1 deletion references/optical_flow/train.py
@@ -137,7 +137,7 @@ def validate(model, args):
             weights = prototype.models.get_weight(args.weights)
             preprocessing = weights.transforms()
         else:
-            preprocessing = prototype.transforms.RaftEval()
+            preprocessing = prototype.transforms.OpticalFlowEval()
     else:
         preprocessing = OpticalFlowPresetEval()

2 changes: 1 addition & 1 deletion references/segmentation/train.py
@@ -42,7 +42,7 @@ def get_transform(train, args):
             weights = prototype.models.get_weight(args.weights)
             return weights.transforms()
         else:
-            return prototype.transforms.VocEval(resize_size=520)
+            return prototype.transforms.SemanticSegmentationEval(resize_size=520)


 def criterion(inputs, target):

2 changes: 1 addition & 1 deletion references/video_classification/train.py
@@ -157,7 +157,7 @@ def main(args):
         weights = prototype.models.get_weight(args.weights)
         transform_test = weights.transforms()
     else:
-        transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
+        transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")

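All five reference scripts above follow the same pattern: when --weights is passed, the preset bundled with the weights is used; otherwise the script falls back to constructing the task-named preset directly. A condensed sketch of that pattern, using the classification preset; the helper name and the default sizes are illustrative, not part of the PR:

    from torchvision import prototype

    def build_eval_preprocessing(args, val_crop_size=224, val_resize_size=256):
        if args.weights:
            # The weights entry carries its own preset, so callers never name one.
            weights = prototype.models.get_weight(args.weights)
            return weights.transforms()
        # Otherwise fall back to the renamed task preset.
        return prototype.transforms.ImageClassificationEval(
            crop_size=val_crop_size, resize_size=val_resize_size
        )
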
4 changes: 2 additions & 2 deletions torchvision/prototype/models/alexnet.py
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.alexnet import AlexNet
@@ -16,7 +16,7 @@
 class AlexNet_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "AlexNet",

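For context, the transforms entry of each Weights record is a callable; calling weights.transforms() (as the reference scripts above do) instantiates the preset with the arguments baked into the partial. A small usage sketch; the enum import path and the tensor input are assumptions, not shown in this diff:

    import torch
    from torchvision.prototype.models import AlexNet_Weights

    weights = AlexNet_Weights.IMAGENET1K_V1
    preprocessing = weights.transforms()  # equivalent to ImageClassificationEval(crop_size=224)
    img = torch.rand(3, 256, 256)         # stand-in for a decoded image
    batch = preprocessing(img)
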
10 changes: 5 additions & 5 deletions torchvision/prototype/models/convnext.py
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, List, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.convnext import ConvNeXt, CNBlockConfig
@@ -56,7 +56,7 @@ def _convnext(
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
         meta={
             **_COMMON_META,
             "num_params": 28589128,
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
 class ConvNeXt_Small_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
         meta={
             **_COMMON_META,
             "num_params": 50223688,
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
 class ConvNeXt_Base_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 88591464,
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
 class ConvNeXt_Large_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 197767336,

10 changes: 5 additions & 5 deletions torchvision/prototype/models/densenet.py
@@ -3,7 +3,7 @@
 from typing import Any, Optional, Tuple

 import torch.nn as nn
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.densenet import DenseNet
@@ -78,7 +78,7 @@ def _densenet(
 class DenseNet121_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 7978856,
@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
 class DenseNet161_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 28681000,
@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
 class DenseNet169_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 14149480,
@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
 class DenseNet201_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet201-c1103571.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 20013928,

8 changes: 4 additions & 4 deletions torchvision/prototype/models/detection/faster_rcnn.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional, Union

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.faster_rcnn import (
@@ -43,7 +43,7 @@
 class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 41755286,
@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,
@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,

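Note that the detection weights store the preset class itself (transforms=ObjectDetectionEval) rather than a partial, since the detection preset is constructed without size arguments in the reference scripts. A sketch of the equivalent call site; the enum import path is an assumption:

    from torchvision.prototype.models.detection import FasterRCNN_ResNet50_FPN_Weights

    weights = FasterRCNN_ResNet50_FPN_Weights.COCO_V1
    preprocessing = weights.transforms()  # instantiates ObjectDetectionEval()
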
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/fcos.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.fcos import (
@@ -27,7 +27,7 @@
 class FCOS_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "FCOS",

6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.keypoint_rcnn import (
@@ -37,7 +37,7 @@
 class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_LEGACY = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,
@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     )
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,

4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/mask_rcnn.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.mask_rcnn import (
@@ -27,7 +27,7 @@
 class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "MaskRCNN",

4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/retinanet.py
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.retinanet import (
@@ -28,7 +28,7 @@
 class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "RetinaNet",

4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/ssd.py
@@ -1,7 +1,7 @@
 import warnings
 from typing import Any, Optional

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssd import (
@@ -25,7 +25,7 @@
 class SSD300_VGG16_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSD",

4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/ssdlite.py
@@ -3,7 +3,7 @@
 from typing import Any, Callable, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssdlite import (
@@ -30,7 +30,7 @@
 class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSDLite",