Commit 4ef72fc

add EsViT, by popular request, an alternative to Dino that is compatible with efficient ViTs by accounting for a regional self-supervised loss
1 parent c2aab05 commit 4ef72fc

5 files changed: +458 −4 lines changed

README.md (+85)

@@ -32,6 +32,7 @@
 - [Parallel ViT](#parallel-vit)
 - [Learnable Memory ViT](#learnable-memory-vit)
 - [Dino](#dino)
+- [EsViT](#esvit)
 - [Accessing Attention](#accessing-attention)
 - [Research Ideas](#research-ideas)
 * [Efficient Attention](#efficient-attention)
@@ -1076,6 +1077,80 @@ for _ in range(100):
 torch.save(model.state_dict(), './pretrained-net.pt')
 ```

+## EsViT
+
+<img src="./images/esvit.png" width="350px"></img>
+
+<a href="https://arxiv.org/abs/2106.09785">`EsViT`</a> is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling, by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.
+
+Even though it is named as though it were a new `ViT` variant, it is actually just a strategy for training any multistage `ViT` (the paper focused on Swin). The example below shows how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits.
+
+```python
+import torch
+from vit_pytorch.cvt import CvT
+from vit_pytorch.es_vit import EsViTTrainer
+
+cvt = CvT(
+    num_classes = 1000,
+    s1_emb_dim = 64,
+    s1_emb_kernel = 7,
+    s1_emb_stride = 4,
+    s1_proj_kernel = 3,
+    s1_kv_proj_stride = 2,
+    s1_heads = 1,
+    s1_depth = 1,
+    s1_mlp_mult = 4,
+    s2_emb_dim = 192,
+    s2_emb_kernel = 3,
+    s2_emb_stride = 2,
+    s2_proj_kernel = 3,
+    s2_kv_proj_stride = 2,
+    s2_heads = 3,
+    s2_depth = 2,
+    s2_mlp_mult = 4,
+    s3_emb_dim = 384,
+    s3_emb_kernel = 3,
+    s3_emb_stride = 2,
+    s3_proj_kernel = 3,
+    s3_kv_proj_stride = 2,
+    s3_heads = 4,
+    s3_depth = 10,
+    s3_mlp_mult = 4,
+    dropout = 0.
+)
+
+learner = EsViTTrainer(
+    cvt,
+    image_size = 256,
+    hidden_layer = 'layers',             # hidden layer name or index, from which to extract the embedding
+    projection_hidden_size = 256,        # projector network hidden dimension
+    projection_layers = 4,               # number of layers in projection network
+    num_classes_K = 65336,               # output logits dimensions (referenced as K in paper)
+    student_temp = 0.9,                  # student temperature
+    teacher_temp = 0.04,                 # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
+    local_upper_crop_scale = 0.4,        # upper bound for local crop - 0.4 was recommended in the paper
+    global_lower_crop_scale = 0.5,       # lower bound for global crop - 0.5 was recommended in the paper
+    moving_average_decay = 0.9,          # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
+    center_moving_average_decay = 0.9,   # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
+)
+
+opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(8, 3, 256, 256)
+
+for _ in range(1000):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of teacher encoder and teacher centers
+
+# save your improved network
+torch.save(cvt.state_dict(), './pretrained-net.pt')
+```
+
 ## Accessing Attention

 If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
@@ -1584,6 +1659,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```

+```bibtex
+@article{Li2021EfficientSV,
+    title   = {Efficient Self-supervised Vision Transformers for Representation Learning},
+    author  = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2106.09785}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
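Note on the `teacher_temp` comment in the new README example: it says the teacher temperature should be annealed from 0.04 to 0.07 over 30 epochs, but the example keeps it fixed at 0.04. Below is a minimal sketch of such a warmup schedule, not part of the commit. It assumes `learner`, `opt`, and `sample_unlabelled_images` from the example above; `num_epochs` and `steps_per_epoch` are hypothetical; and the per-call `teacher_temp` override is an assumption borrowed from the Dino trainer in this repository, so adapt it to however `EsViTTrainer` actually exposes the temperature.

```python
# Sketch only: linear teacher-temperature warmup (0.04 -> 0.07 over 30 epochs).
# Assumes `learner`, `opt` and `sample_unlabelled_images` from the README example above.
# The `teacher_temp` keyword override is an assumption (mirroring the Dino trainer here);
# if EsViTTrainer does not accept it, set whatever attribute or argument it exposes instead.
import numpy as np

num_epochs      = 100   # hypothetical training length
steps_per_epoch = 50    # hypothetical number of batches per epoch
warmup_epochs   = 30

# linear warmup to 0.07, then hold constant for the remaining epochs
teacher_temps = np.concatenate((
    np.linspace(0.04, 0.07, warmup_epochs),
    np.full(num_epochs - warmup_epochs, 0.07)
))

for epoch in range(num_epochs):
    teacher_temp = float(teacher_temps[epoch])

    for _ in range(steps_per_epoch):
        images = sample_unlabelled_images()
        loss = learner(images, teacher_temp = teacher_temp)  # assumed override, see note above
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average()
```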

images/esvit.png (191 KB)

setup.py (+1 −1)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.2',
+  version = '0.34.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/cvt.py (+5 −3)

@@ -164,12 +164,14 @@ def __init__(

             dim = config['emb_dim']

-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.to_logits = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             Rearrange('... () () -> ...'),
             nn.Linear(dim, num_classes)
         )

     def forward(self, x):
-        return self.layers(x)
+        latents = self.layers(x)
+        return self.to_logits(latents)
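For context on the `vit_pytorch/cvt.py` change: the commit splits the backbone (`self.layers`) from the pooling / classification head (`to_logits`), so that a trainer can reach the spatial feature map through the submodule name `'layers'`, which is what `hidden_layer = 'layers'` in the README example points at. The sketch below is not part of the commit; it uses a standard PyTorch forward hook to show what that intermediate output looks like, assuming the default `CvT` stage hyperparameters, and the printed feature-map shape is a rough estimate for a 256 x 256 input with the default strides.

```python
# Sketch only (not part of this commit): capture the pre-pooling feature map
# that `hidden_layer = 'layers'` refers to, via a plain forward hook.
import torch
from vit_pytorch.cvt import CvT

cvt = CvT(num_classes = 1000)   # default stage hyperparameters assumed

features = {}

def capture(module, inputs, output):
    # after the refactor, `cvt.layers` outputs the spatial tokens before AdaptiveAvgPool2d
    features['layers'] = output

handle = cvt.layers.register_forward_hook(capture)

logits = cvt(torch.randn(1, 3, 256, 256))

print(logits.shape)              # torch.Size([1, 1000]) - classification head output
print(features['layers'].shape)  # roughly torch.Size([1, 384, 16, 16]) - the regional features EsViT's loss uses

handle.remove()
```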
