Skip to content

Commit 0b66abe

Browse files
prabhat00155datumbox
authored andcommitted
[fbsync] Adding ViT to torchvision/models (#4594)
Summary: * [vit] Adding ViT to torchvision/models * adding pre-logits layer + resolving comments * Fix the model attribute bug * Change version to arch * fix failing unittests * remove useless prints * reduce input size to fix unittests * Increase windows-cpu executor to 2xlarge * Use `batch_first=True` and remove classifier * Change resource_class back to xlarge * Remove vit_h_14 * Remove vit_h_14 from __all__ * Move vision_transformer.py into prototype * Fix formatting issue * remove arch in builder * Fix type err in model builder * address comments and trigger unittests * remove the prototype import in torchvision.models * Adding vit back to models to trigger CircleCI test * fix test_jit_forward_backward * Move all to prototype. * Adopt new helper methods and fix prototype tests. * Remove unused import. Reviewed By: NicolasHug Differential Revision: D32694316 fbshipit-source-id: fa2867555fb7ae65f8dab537517386f6694585a2 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 9e0f868 commit 0b66abe

9 files changed

+438
-6
lines changed
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.

test/test_backbone_utils.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import random
22
from itertools import chain
3+
from typing import Mapping, Sequence
34

45
import pytest
56
import torch
@@ -89,7 +90,16 @@ def _create_feature_extractor(self, *args, **kwargs):
8990

9091
def _get_return_nodes(self, model):
9192
set_rng_seed(0)
92-
exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
93+
exclude_nodes_filter = [
94+
"getitem",
95+
"floordiv",
96+
"size",
97+
"chunk",
98+
"_assert",
99+
"eq",
100+
"dim",
101+
"getattr",
102+
]
93103
train_nodes, eval_nodes = get_graph_node_names(
94104
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
95105
)
@@ -144,7 +154,16 @@ def test_forward_backward(self, model_name):
144154
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
145155
)
146156
out = model(self.inp)
147-
sum(o.mean() for o in out.values()).backward()
157+
out_agg = 0
158+
for node_out in out.values():
159+
if isinstance(node_out, Sequence):
160+
out_agg += sum(o.mean() for o in node_out if o is not None)
161+
elif isinstance(node_out, Mapping):
162+
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
163+
else:
164+
# Assume that the only other alternative at this point is a Tensor
165+
out_agg += node_out.mean()
166+
out_agg.backward()
148167

149168
def test_feature_extraction_methods_equivalence(self):
150169
model = models.resnet18(**self.model_defaults).eval()
@@ -176,7 +195,16 @@ def test_jit_forward_backward(self, model_name):
176195
)
177196
model = torch.jit.script(model)
178197
fgn_out = model(self.inp)
179-
sum(o.mean() for o in fgn_out.values()).backward()
198+
out_agg = 0
199+
for node_out in fgn_out.values():
200+
if isinstance(node_out, Sequence):
201+
out_agg += sum(o.mean() for o in node_out if o is not None)
202+
elif isinstance(node_out, Mapping):
203+
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
204+
else:
205+
# Assume that the only other alternative at this point is a Tensor
206+
out_agg += node_out.mean()
207+
out_agg.backward()
180208

181209
def test_train_eval(self):
182210
class TestModel(torch.nn.Module):

test/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
507507
}
508508
model_name = model_fn.__name__
509509
kwargs = {**defaults, **_model_params.get(model_name, {})}
510+
num_classes = kwargs.get("num_classes")
510511
input_shape = kwargs.pop("input_shape")
511512

512513
model = model_fn(**kwargs)
@@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
515516
x = torch.rand(input_shape).to(device=dev)
516517
out = model(x)
517518
_assert_expected(out.cpu(), model_name, prec=0.1)
518-
assert out.shape[-1] == 50
519+
assert out.shape[-1] == num_classes
519520
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
520521
_check_fx_compatible(model, x)
521522

test/test_prototype_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,11 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
122122
x = [x]
123123

124124
# compare with new model builder parameterized in the old fashion way
125-
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
126-
model_new = _build_model(model_fn, **kwargs).to(device=dev)
125+
try:
126+
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
127+
model_new = _build_model(model_fn, **kwargs).to(device=dev)
128+
except ModuleNotFoundError:
129+
pytest.skip(f"Model '{model_name}' not available in both modules.")
127130
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
128131

129132

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .shufflenetv2 import *
1111
from .squeezenet import *
1212
from .vgg import *
13+
from .vision_transformer import *
1314
from . import detection
1415
from . import quantization
1516
from . import segmentation

0 commit comments

Comments
 (0)