Commit fcb9501

add register tokens to the nested tensor 3d na vit example for researcher
1 parent c4651a3 commit fcb9501

2 files changed: +23 -7 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.11',
+  version = '1.7.12',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/na_vit_nested_tensor_3d.py

Lines changed: 22 additions & 6 deletions
@@ -163,6 +163,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        num_registers = 4,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -193,9 +194,18 @@ def __init__(
             nn.LayerNorm(dim),
         )
 
-        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
-        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
-        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+        self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+        self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+        self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
+
+        # register tokens
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+        nn.init.normal_(self.pos_embed_frame, std = 0.02)
+        nn.init.normal_(self.pos_embed_height, std = 0.02)
+        nn.init.normal_(self.pos_embed_width, std = 0.02)
+        nn.init.normal_(self.register_tokens, std = 0.02)
 
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -275,8 +285,6 @@ def forward(
 
         pos_embed = frame_embed + height_embed + width_embed
 
-        # use nested tensor for transformers and save on padding computation
-
         tokens = torch.cat(tokens)
 
         # linear projection to patch embeddings
@@ -287,7 +295,15 @@ def forward(
 
         tokens = tokens + pos_embed
 
-        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+        # add register tokens
+
+        tokens = tokens.split(seq_lens.tolist())
+
+        tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+        # use nested tensor for transformers and save on padding computation
+
+        tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
 
         # embedding dropout
 