@@ -163,6 +163,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        num_registers = 4,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
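This hunk threads a new num_registers hyperparameter through the constructor, which appears to follow "Vision Transformers Need Registers" (Darcet et al.): a handful of learned, input-independent tokens prepended to the sequence soak up the high-norm artifact activations a ViT otherwise dumps into patch tokens. As a minimal sketch of the core idea, independent of this repo (all names and shapes below are illustrative; only the num_registers concept comes from the diff), here is the dense-batch form of prepending a shared register parameter:

    import torch
    import torch.nn as nn

    num_registers, dim = 4, 512
    register_tokens = nn.Parameter(torch.zeros(num_registers, dim))

    x = torch.randn(8, 196, dim)                            # (batch, seq, dim) patch tokens
    r = register_tokens.unsqueeze(0).expand(x.shape[0], -1, -1)
    x = torch.cat((r, x), dim = 1)                          # (batch, 4 + 196, dim)

The forward pass in this diff does the ragged-batch version of the same prepend, shown in the last hunk below.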
@@ -193,9 +194,18 @@ def __init__(
             nn.LayerNorm(dim),
         )
 
-        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
-        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
-        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+        self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+        self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+        self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
+
+        # register tokens
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+        nn.init.normal_(self.pos_embed_frame, std = 0.02)
+        nn.init.normal_(self.pos_embed_height, std = 0.02)
+        nn.init.normal_(self.pos_embed_width, std = 0.02)
+        nn.init.normal_(self.register_tokens, std = 0.02)
 
 
         self.dropout = nn.Dropout(emb_dropout)
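The positional tables switch from torch.randn, which samples at unit standard deviation, to a zeros allocation followed by an in-place nn.init.normal_ at std 0.02, the small embedding-init scale common in transformer codebases, and the new register_tokens parameter is initialized the same way. A minimal standalone sketch of the pattern (shapes illustrative):

    import torch
    import torch.nn as nn

    # allocate the parameter, then overwrite it in place with a small gaussian,
    # instead of sampling at unit variance via torch.randn
    pos_embed = nn.Parameter(torch.zeros(16, 512))
    nn.init.normal_(pos_embed, std = 0.02)

    print(pos_embed.std())   # roughly 0.02, where torch.randn would give roughly 1.0

nn.init.normal_ runs under no_grad, so it is safe to call on a freshly constructed nn.Parameter.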
@@ -275,8 +285,6 @@ def forward(
 
         pos_embed = frame_embed + height_embed + width_embed
 
-        # use nested tensor for transformers and save on padding computation
-
         tokens = torch.cat(tokens)
 
         # linear projection to patch embeddings
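The context line pos_embed = frame_embed + height_embed + width_embed shows the factorized spacetime positional embedding: one learned table per axis, summed per patch, rather than a single (frames * height * width, dim) table. A sketch of that sum for a dense grid, under the assumption that the per-axis tables are broadcast over the other two axes (the actual forward gathers per-patch indices for variable-sized sequences, which this sketch does not reproduce):

    import torch

    f, h, w, dim = 4, 8, 8, 512                  # illustrative spacetime patch grid
    frame_embed  = torch.randn(f, dim)
    height_embed = torch.randn(h, dim)
    width_embed  = torch.randn(w, dim)

    pos_embed = (
        frame_embed[:, None, None, :]
        + height_embed[None, :, None, :]
        + width_embed[None, None, :, :]
    ).reshape(f * h * w, dim)                    # one position vector per patch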
@@ -287,7 +295,15 @@ def forward(
 
         tokens = tokens + pos_embed
 
-        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+        # add register tokens
+
+        tokens = tokens.split(seq_lens.tolist())
+
+        tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+        # use nested tensor for transformers and save on padding computation
+
+        tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
 
         # embedding dropout
 
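The replaced line used to split and pack in one expression; the new code splits the flat token buffer back into per-sample sequences, prepends the same learned register tokens to each, and only then packs the ragged list into a jagged-layout nested tensor so the transformer attends without padding. A standalone sketch of those three steps (dimensions and lengths invented for illustration, with a plain tensor standing in for the learned parameter):

    import torch
    from torch.nested import nested_tensor

    dim, num_registers = 512, 4
    register_tokens = torch.zeros(num_registers, dim)   # stands in for the nn.Parameter

    seq_lens = torch.tensor([7, 3, 5])                  # ragged per-sample token counts
    tokens = torch.randn(int(seq_lens.sum()), dim)      # all samples packed flat

    tokens = tokens.split(seq_lens.tolist())            # tuple of (len_i, dim) tensors
    tokens = [torch.cat((register_tokens, t)) for t in tokens]

    tokens = nested_tensor(tokens, layout = torch.jagged)   # ragged batch, no padding

Since the registers land at the front of every sequence, a readout that pools over patch tokens would presumably slice off the first num_registers positions after the transformer; that part is outside this diff.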