
Commit 8df96a9

Merge branch 'main' into model-contrib-guidelines

2 parents e06c180 + 22f8dc4


45 files changed (+754, -431 lines)

.circleci/unittest/linux/scripts/run_test.sh

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ conda activate ./env
 
 export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
-pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20
+pytest --junitxml=test-results/junit.xml -v --durations 20

.circleci/unittest/windows/scripts/run_test.sh

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@ source "$this_dir/set_cuda_envs.sh"
 
 export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
-pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20
+pytest --junitxml=test-results/junit.xml -v --durations 20

.coveragerc

Lines changed: 0 additions & 7 deletions
This file was deleted.

docs/source/datasets.rst

Lines changed: 68 additions & 22 deletions
@@ -5,7 +5,7 @@ Torchvision provides many built-in datasets in the ``torchvision.datasets``
 module, as well as utility classes for building your own datasets.
 
 Built-in datasets
-~~~~~~~~~~~~~~~~~
+-----------------
 
 All datasets are subclasses of :class:`torch.utils.data.Dataset`
 i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
@@ -25,6 +25,8 @@ All the datasets have almost similar API. They all have two common arguments:
 ``transform`` and ``target_transform`` to transform the input and target respectively.
 You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
 
+Image classification
+~~~~~~~~~~~~~~~~~~~~
 
 .. autosummary::
     :toctree: generated/
@@ -35,61 +37,105 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     CelebA
     CIFAR10
     CIFAR100
-    Cityscapes
-    CocoCaptions
-    CocoDetection
     Country211
     DTD
     EMNIST
     EuroSAT
     FakeData
     FashionMNIST
     FER2013
+    FGVCAircraft
     Flickr8k
     Flickr30k
     Flowers102
-    FlyingChairs
-    FlyingThings3D
     Food101
-    FGVCAircraft
     GTSRB
-    HD1K
-    HMDB51
-    ImageNet
     INaturalist
-    Kinetics400
-    Kitti
-    KittiFlow
+    ImageNet
     KMNIST
     LFWPeople
-    LFWPairs
     LSUN
     MNIST
     Omniglot
     OxfordIIITPet
-    PCAM
-    PhotoTour
     Places365
-    RenderedSST2
+    PCAM
     QMNIST
-    SBDataset
-    SBU
+    RenderedSST2
     SEMEION
-    Sintel
+    SBU
     StanfordCars
     STL10
     SUN397
     SVHN
-    UCF101
     USPS
+
+Image detection or segmentation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    CocoDetection
+    CelebA
+    Cityscapes
+    GTSRB
+    Kitti
+    OxfordIIITPet
+    SBDataset
     VOCSegmentation
     VOCDetection
     WIDERFace
 
+Optical Flow
+~~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    FlyingChairs
+    FlyingThings3D
+    HD1K
+    KittiFlow
+    Sintel
+
+Image pairs
+~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    LFWPairs
+    PhotoTour
+
+Image captioning
+~~~~~~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    CocoCaptions
+
+Video classification
+~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    HMDB51
+    Kinetics400
+    UCF101
+
+
 .. _base_classes_datasets:
 
 Base classes for custom datasets
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------
 
 .. autosummary::
     :toctree: generated/
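
The context lines above describe the common dataset API (`transform` and `target_transform`). A minimal sketch of that pattern, using CIFAR10 and an arbitrary local `data/` path (both my choices, not part of this commit):

```
# Sketch of the common dataset API described above (not repo code).
import torchvision.datasets as datasets
import torchvision.transforms as T

# Every built-in dataset takes root/transform/target_transform; CIFAR10 here
# is just one entry from the "Image classification" list.
train_set = datasets.CIFAR10(root="data/", train=True, download=True, transform=T.ToTensor())
img, label = train_set[0]  # datasets implement __getitem__ and __len__
print(img.shape, label)    # torch.Size([3, 32, 32]) and an integer class index
```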

docs/source/models.rst

Lines changed: 24 additions & 1 deletion
@@ -89,6 +89,10 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    convnext_tiny = models.convnext_tiny()
+    convnext_small = models.convnext_small()
+    convnext_base = models.convnext_base()
+    convnext_large = models.convnext_large()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -136,6 +140,10 @@ These can be constructed by passing ``pretrained=True``:
     vit_b_32 = models.vit_b_32(pretrained=True)
     vit_l_16 = models.vit_l_16(pretrained=True)
     vit_l_32 = models.vit_l_32(pretrained=True)
+    convnext_tiny = models.convnext_tiny(pretrained=True)
+    convnext_small = models.convnext_small(pretrained=True)
+    convnext_base = models.convnext_base(pretrained=True)
+    convnext_large = models.convnext_large(pretrained=True)
 
 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_HOME` environment variable. See
@@ -248,7 +256,10 @@ vit_b_16 81.072 95.318
 vit_b_32                         75.912        92.466
 vit_l_16                         79.662        94.638
 vit_l_32                         76.972        93.070
-convnext_tiny (prototype)        82.520        96.146
+convnext_tiny                    82.520        96.146
+convnext_small                   83.616        96.650
+convnext_base                    84.062        96.870
+convnext_large                   84.414        96.976
 ================================ ============= =============
 
 
@@ -464,6 +475,18 @@ VisionTransformer
     vit_l_16
     vit_l_32
 
+ConvNeXt
+--------
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    convnext_tiny
+    convnext_small
+    convnext_base
+    convnext_large
+
 Quantized Models
 ----------------
 
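The newly documented builders can be exercised as below. A hedged sketch: `convnext_tiny` and `pretrained=True` come from the diff, while the input shape and no-grad eval pattern are standard ImageNet practice, not taken from this commit:

```
# Sketch: construct a ConvNeXt variant and run a dummy batch (not repo code).
import torch
from torchvision import models

model = models.convnext_tiny(pretrained=True)  # downloads weights on first use
model.eval()
x = torch.randn(1, 3, 224, 224)                # dummy NCHW input
with torch.no_grad():
    logits = model(x)
print(logits.shape)                            # torch.Size([1, 1000])
```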

docs/source/ops.rst

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ Operators
     clip_boxes_to_image
     deform_conv2d
     generalized_box_iou
+    generalized_box_iou_loss
     masks_to_boxes
     nms
     ps_roi_align
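
A hedged usage sketch for the newly listed op. The box values are invented, and the call assumes `torchvision.ops.generalized_box_iou_loss` takes paired `[N, 4]` boxes in `(x1, y1, x2, y2)` format with an optional `reduction` argument:

```
# Sketch only: gIoU loss between one predicted and one target box (not repo code).
import torch
from torchvision.ops import generalized_box_iou_loss

pred = torch.tensor([[10.0, 10.0, 50.0, 50.0]], requires_grad=True)
target = torch.tensor([[12.0, 8.0, 48.0, 55.0]])

loss = generalized_box_iou_loss(pred, target, reduction="mean")
loss.backward()  # differentiable, so it can train a box-regression head
print(loss.item())
```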

hubconf.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 dependencies = ["torch"]
 
 from torchvision.models.alexnet import alexnet
+from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
 from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
 from torchvision.models.efficientnet import (
     efficientnet_b0,
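
With the hubconf entry above, the ConvNeXt builders should also be reachable through `torch.hub`. A hedged sketch; the `pytorch/vision` hub path is the usual one, though pinning a release tag may be needed in practice:

```
# Sketch: load a ConvNeXt builder exposed via hubconf.py (not repo code).
import torch

model = torch.hub.load("pytorch/vision", "convnext_tiny", pretrained=True)
model.eval()
```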

references/classification/README.md

Lines changed: 3 additions & 2 deletions
@@ -201,11 +201,12 @@ and `--batch_size 64`.
 ### ConvNeXt
 ```
 torchrun --nproc_per_node=8 train.py\
---model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
 --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
---train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
 ```
+Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step, see their `Weights` entry for their exact value.
 
 Note that the above command corresponds to training on a single node with 8 GPUs.
 For generatring the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),

references/classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ def load_data(traindir, valdir, args):
 
     print("Creating data loaders")
     if args.distributed:
-        if args.ra_sampler:
+        if hasattr(args, "ra_sampler") and args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
         else:
             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
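
The added `hasattr` guard presumably protects callers whose argument parser does not define `--ra-sampler`. A minimal illustration with a hypothetical `Namespace`, not repo code:

```
# Why the guard matters: an args namespace without `ra_sampler` would
# otherwise raise AttributeError on `args.ra_sampler`.
from argparse import Namespace

args = Namespace(distributed=True)  # no ra_sampler attribute defined
use_ra = hasattr(args, "ra_sampler") and args.ra_sampler
print(use_ra)  # False, instead of an AttributeError
```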

references/classification/train_quantization.py

Lines changed: 15 additions & 7 deletions
@@ -13,14 +13,16 @@
 
 
 try:
-    from torchvision.prototype import models as PM
+    from torchvision import prototype
 except ImportError:
-    PM = None
+    prototype = None
 
 
 def main(args):
-    if args.weights and PM is None:
+    if args.prototype and prototype is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+    if not args.prototype and args.weights:
+        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -54,14 +56,14 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    if not args.weights:
+    if not args.prototype:
         model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
     else:
-        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+        model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)
 
@@ -95,7 +97,7 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model()
+        model.fuse_model(is_qat=False)
         model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
         torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
+    parser.add_argument(
+        "--prototype",
+        dest="prototype",
+        help="Use prototype model builders instead those from main area",
+        action="store_true",
+    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
     return parser
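
The `is_qat` flag now distinguishes training-time from inference-time fusion. A hedged sketch of the two paths; the model choice and `qnnpack` backend mirror the repo's examples, but this block itself is mine:

```
# Sketch: fuse_model(is_qat=...) in the two quantization workflows (not repo code).
import torch
import torchvision

# QAT: fuse in train mode, then prepare_qat
qat = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=False)
qat.train()
qat.fuse_model(is_qat=True)
qat.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
torch.ao.quantization.prepare_qat(qat, inplace=True)

# Post-training quantization: fuse in eval mode, then prepare and calibrate
ptq = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=False)
ptq.eval()
ptq.fuse_model(is_qat=False)
ptq.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
torch.ao.quantization.prepare(ptq, inplace=True)
```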

references/classification/utils.py

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
 
     # Quantized Classification
     model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
-    model.fuse_model()
+    model.fuse_model(is_qat=True)
     model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
     _ = torch.ao.quantization.prepare_qat(model, inplace=True)
     print(store_model_weights(model, './qat.pth'))
3 binary files changed (contents not shown).

test/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -833,7 +833,7 @@ def test_quantized_classification_model(model_fn):
     model.train()
     model.qconfig = torch.ao.quantization.default_qat_qconfig
 
-    model.fuse_model()
+    model.fuse_model(is_qat=not eval_mode)
     if eval_mode:
         torch.ao.quantization.prepare(model, inplace=True)
     else:

test/test_utils.py

Lines changed: 9 additions & 1 deletion
@@ -124,7 +124,7 @@ def test_draw_boxes_vanilla():
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
     img_cp = img.clone()
     boxes_cp = boxes.clone()
-    result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
+    result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
 
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
     if not os.path.exists(path):
@@ -149,7 +149,11 @@ def test_draw_invalid_boxes():
     img_tp = ((1, 1, 1), (1, 2, 3))
     img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
     img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
+    img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
     boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
+    labels_wrong = ["one", "two"]
+    colors_wrong = ["pink", "blue"]
+
     with pytest.raises(TypeError, match="Tensor expected"):
         utils.draw_bounding_boxes(img_tp, boxes)
     with pytest.raises(ValueError, match="Tensor uint8 expected"):
@@ -158,6 +162,10 @@
         utils.draw_bounding_boxes(img_wrong2, boxes)
     with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
         utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
+    with pytest.raises(ValueError, match="Number of boxes"):
+        utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
+    with pytest.raises(ValueError, match="Number of colors"):
+        utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
 
 
 @pytest.mark.parametrize(
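
For contrast with the new negative tests, here is a call where the label and color counts match the box count, so no ValueError is raised; a hedged sketch, not part of the test file:

```
# Sketch: valid draw_bounding_boxes call -- 2 boxes, 2 labels, 2 colors.
import torch
from torchvision.utils import draw_bounding_boxes

img = torch.zeros((3, 100, 100), dtype=torch.uint8)
boxes = torch.tensor([[10, 10, 40, 40], [50, 50, 90, 90]], dtype=torch.float)
result = draw_bounding_boxes(img, boxes, labels=["cat", "dog"], colors=["pink", "blue"], width=3)
print(result.shape)  # (3, 100, 100), still uint8
```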
