32 | 32 | - [Parallel ViT](#parallel-vit)
33 | 33 | - [Learnable Memory ViT](#learnable-memory-vit)
34 | 34 | - [Dino](#dino)
| 35 | +- [EsViT](#esvit) |
35 | 36 | - [Accessing Attention](#accessing-attention)
36 | 37 | - [Research Ideas](#research-ideas)
37 | 38 |   * [Efficient Attention](#efficient-attention)
@@ -1076,6 +1077,80 @@ for _ in range(100):
1076 | 1077 | torch.save(model.state_dict(), './pretrained-net.pt')
1077 | 1078 | ```
1078 | 1079 |
| 1080 | +## EsViT |
| 1081 | + |
| 1082 | +<img src="./images/esvit.png" width="350px"></img> |
| 1083 | + |
| 1084 | +<a href="https://arxiv.org/abs/2106.09785">`EsViT`</a> is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling, by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput. |
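The regional loss mentioned above pairs each local (patch-level) feature from one augmented view with its best-matching feature from the other view and applies the usual Dino cross-entropy to the matched pair. Below is only a rough sketch of that matching idea, not the trainer's actual implementation; `student_head` and `teacher_head` are hypothetical stand-ins for the projection networks, and centering of the teacher outputs is omitted.

```python
import torch
import torch.nn.functional as F

def region_level_loss(
    student_regions,     # (batch, regions, dim) - local features from the student's view
    teacher_regions,     # (batch, regions, dim) - local features from the teacher's view
    student_head,        # hypothetical student projection network -> (batch, regions, K)
    teacher_head,        # hypothetical teacher projection network -> (batch, regions, K)
    student_temp = 0.9,
    teacher_temp = 0.04
):
    # match each teacher region to its most similar student region by cosine similarity
    sim = torch.einsum(
        'b i d, b j d -> b i j',
        F.normalize(teacher_regions, dim = -1),
        F.normalize(student_regions, dim = -1)
    )
    best = sim.argmax(dim = -1)                                          # (batch, regions)
    index = best.unsqueeze(-1).expand(-1, -1, student_regions.shape[-1])
    matched = torch.gather(student_regions, 1, index)                    # best-matching student features

    # Dino-style cross entropy between teacher targets and matched student predictions
    teacher_probs = F.softmax(teacher_head(teacher_regions) / teacher_temp, dim = -1).detach()
    student_logprobs = F.log_softmax(student_head(matched) / student_temp, dim = -1)
    return -(teacher_probs * student_logprobs).sum(dim = -1).mean()
```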
| 1085 | + |
| 1086 | +Even though it is named as though it were a new `ViT` variant, it is really just a training strategy for any multistage `ViT` (the paper focuses on Swin). The example below shows how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient `ViT` that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits. |
| 1087 | + |
| 1088 | +```python |
| 1089 | +import torch |
| 1090 | +from vit_pytorch.cvt import CvT |
| 1091 | +from vit_pytorch.es_vit import EsViTTrainer |
| 1092 | + |
| 1093 | +cvt = CvT( |
| 1094 | + num_classes = 1000, |
| 1095 | + s1_emb_dim = 64, |
| 1096 | + s1_emb_kernel = 7, |
| 1097 | + s1_emb_stride = 4, |
| 1098 | + s1_proj_kernel = 3, |
| 1099 | + s1_kv_proj_stride = 2, |
| 1100 | + s1_heads = 1, |
| 1101 | + s1_depth = 1, |
| 1102 | + s1_mlp_mult = 4, |
| 1103 | + s2_emb_dim = 192, |
| 1104 | + s2_emb_kernel = 3, |
| 1105 | + s2_emb_stride = 2, |
| 1106 | + s2_proj_kernel = 3, |
| 1107 | + s2_kv_proj_stride = 2, |
| 1108 | + s2_heads = 3, |
| 1109 | + s2_depth = 2, |
| 1110 | + s2_mlp_mult = 4, |
| 1111 | + s3_emb_dim = 384, |
| 1112 | + s3_emb_kernel = 3, |
| 1113 | + s3_emb_stride = 2, |
| 1114 | + s3_proj_kernel = 3, |
| 1115 | + s3_kv_proj_stride = 2, |
| 1116 | + s3_heads = 4, |
| 1117 | + s3_depth = 10, |
| 1118 | + s3_mlp_mult = 4, |
| 1119 | + dropout = 0. |
| 1120 | +) |
| 1121 | + |
| 1122 | +learner = EsViTTrainer( |
| 1123 | + cvt, |
| 1124 | + image_size = 256, |
| 1125 | + hidden_layer = 'layers', # hidden layer name or index, from which to extract the embedding |
| 1126 | + projection_hidden_size = 256, # projector network hidden dimension |
| 1127 | + projection_layers = 4, # number of layers in projection network |
| 1128 | + num_classes_K = 65336, # output logits dimensions (referenced as K in paper) |
| 1129 | + student_temp = 0.9, # student temperature |
| 1130 | + teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs |
| 1131 | + local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper |
| 1132 | + global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper |
| 1133 | + moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok |
| 1134 | + center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok |
| 1135 | +) |
| 1136 | + |
| 1137 | +opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4) |
| 1138 | + |
| 1139 | +def sample_unlabelled_images(): |
| 1140 | + return torch.randn(8, 3, 256, 256) |
| 1141 | + |
| 1142 | +for _ in range(1000): |
| 1143 | + images = sample_unlabelled_images() |
| 1144 | + loss = learner(images) |
| 1145 | + opt.zero_grad() |
| 1146 | + loss.backward() |
| 1147 | + opt.step() |
| 1148 | + learner.update_moving_average() # update moving average of teacher encoder and teacher centers |
| 1149 | + |
| 1150 | +# save your improved network |
| 1151 | +torch.save(cvt.state_dict(), './pretrained-net.pt') |
| 1152 | +``` |
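If you are unsure what to pass for `hidden_layer`, one simple way to find a candidate is to print the named modules of your network and pick the one that emits the pre-pooling feature map; this is plain PyTorch and nothing EsViT-specific is assumed.

```python
# list candidate module names for `hidden_layer`
# (`cvt` is the CvT instance constructed in the example above)
for name, module in cvt.named_modules():
    print(name, '->', type(module).__name__)
```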
| 1153 | + |
1079 | 1154 | ## Accessing Attention
1080 | 1155 |
1081 | 1156 | If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
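In short, you wrap the trained network in the `Recorder` class shipped with this repository and read the attention maps off the forward pass. A minimal sketch, where the exact import path and the shape convention of the returned tensor are assumptions to check against the full procedure:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap the network - forward passes now also return the post-softmax attention maps
v = Recorder(v)

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)    # attns assumed to be (batch, depth, heads, patches + 1, patches + 1)

# remove the recorder hooks when done
v = v.eject()
```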
@@ -1584,6 +1659,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
1584 | 1659 | }
1585 | 1660 | ```
1586 | 1661 |
| 1662 | +```bibtex |
| 1663 | +@article{Li2021EfficientSV, |
| 1664 | + title = {Efficient Self-supervised Vision Transformers for Representation Learning}, |
| 1665 | + author = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao}, |
| 1666 | + journal = {ArXiv}, |
| 1667 | + year = {2021}, |
| 1668 | + volume = {abs/2106.09785} |
| 1669 | +} |
| 1670 | +``` |
| 1671 | + |
1587 | 1672 | ```bibtex
1588 | 1673 | @misc{vaswani2017attention,
1589 | 1674 | title = {Attention Is All You Need},