
Commit 3fa2414

Add SWAG model weights where only the linear head is fine-tuned on ImageNet-1K (#5793)
* Add SWAG models where only the linear classifier head is fine-tuned, with the trunk weights frozen
* Add accuracies from experiments
* Change the name from SWAG_LC to SWAG_LINEAR
* Add a comment on the SWAG_LINEAR weights
* Remove the comment docs (moved to the PR description) and add the PR URL as the recipe; also rename the previous SWAG models to SWAG_E2E_V1
1 parent c399c3f commit 3fa2414
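Because the existing SWAG enum members are renamed rather than duplicated, downstream code that pins IMAGENET1K_SWAG_V1 must move to the new name. A minimal sketch of the rename's effect (assuming torchvision's multi-weight API, to which these enums belong):

    from torchvision.models import RegNet_Y_16GF_Weights

    # Before this commit: RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_V1
    # After this commit, the end-to-end fine-tuned checkpoint is:
    weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
    print(weights.url)  # .../regnet_y_16gf_swag-43afe44d.pth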

File tree

torchvision/models/regnet.py
torchvision/models/vision_transformer.py

2 files changed: +101 −8 lines


torchvision/models/regnet.py

Lines changed: 43 additions & 4 deletions
@@ -575,7 +575,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
             "acc@5": 96.328,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -587,6 +587,19 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
             "acc@5": 98.054,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 83590140,
+            "acc@1": 83.976,
+            "acc@5": 97.244,
+        },
+    )
     DEFAULT = IMAGENET1K_V2


@@ -613,7 +626,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
             "acc@5": 96.498,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -625,11 +638,24 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
             "acc@5": 98.362,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 145046770,
+            "acc@1": 84.622,
+            "acc@5": 97.480,
+        },
+    )
     DEFAULT = IMAGENET1K_V2


 class RegNet_Y_128GF_Weights(WeightsEnum):
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -641,7 +667,20 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
             "acc@5": 98.682,
         },
     )
-    DEFAULT = IMAGENET1K_SWAG_V1
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 644812894,
+            "acc@1": 86.068,
+            "acc@5": 97.844,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_E2E_V1


 class RegNet_X_400MF_Weights(WeightsEnum):
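The new linear-probe entries load like any other weight in the enum. A minimal, hedged sketch (the builder name and the transforms() helper are assumed from torchvision's multi-weight API, v0.13+; the preprocessing and accuracy values come from the meta dicts above):

    import torch
    from torchvision.models import regnet_y_16gf, RegNet_Y_16GF_Weights

    # Linear-probe checkpoint: SWAG trunk frozen, only the classifier head fine-tuned.
    weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
    model = regnet_y_16gf(weights=weights).eval()

    # The bundled transforms apply the 224x224 bicubic resize/crop declared above.
    preprocess = weights.transforms()
    batch = preprocess(torch.rand(3, 500, 400)).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)  # shape [1, 1000]: ImageNet-1K logits

Note that, unlike RegNet_Y_128GF, the 16GF and 32GF classes keep DEFAULT = IMAGENET1K_V2, so both SWAG variants remain opt-in there.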

torchvision/models/vision_transformer.py

Lines changed: 58 additions & 4 deletions
@@ -349,7 +349,7 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
         transforms=partial(
             ImageClassification,
@@ -366,6 +366,24 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 97.650,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 86567656,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 81.886,
+            "acc@5": 96.180,
+        },
+    )
     DEFAULT = IMAGENET1K_V1


@@ -400,7 +418,7 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
         transforms=partial(
             ImageClassification,
@@ -417,6 +435,24 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 98.512,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 304326632,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 85.146,
+            "acc@5": 97.422,
+        },
+    )
     DEFAULT = IMAGENET1K_V1


@@ -438,7 +474,7 @@ class ViT_L_32_Weights(WeightsEnum):


 class ViT_H_14_Weights(WeightsEnum):
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
         transforms=partial(
             ImageClassification,
@@ -455,7 +491,25 @@ class ViT_H_14_Weights(WeightsEnum):
             "acc@5": 98.694,
         },
     )
-    DEFAULT = IMAGENET1K_SWAG_V1
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 632045800,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 85.708,
+            "acc@5": 97.730,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_E2E_V1


 @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
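The ViT additions follow the same pattern, with extra size/min_size metadata for the fixed input resolution of the patch and position embeddings. A hedged usage sketch for the ViT-B/16 linear-probe weights (builder and enum names as in torchvision's multi-weight API; 224x224 bicubic preprocessing per the meta above):

    import torch
    from torchvision.models import vit_b_16, ViT_B_16_Weights

    weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1
    model = vit_b_16(weights=weights).eval()

    # min_size == (224, 224): the model expects exactly the resolution
    # produced by the bundled transforms.
    preprocess = weights.transforms()
    with torch.no_grad():
        logits = model(preprocess(torch.rand(3, 256, 256)).unsqueeze(0))
    top5 = logits.topk(5).indices  # ImageNet-1K class ids

Here too the defaults stay conservative: ViT_B_16 and ViT_L_16 keep DEFAULT = IMAGENET1K_V1, while ViT_H_14, which ships only SWAG checkpoints, defaults to the end-to-end variant.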
