Skip to content

Commit ab63fc9

Browse files
authored
remove duplicated qkv computation in na_vit_nested_tensor_3d.py (#341)
1 parent c3018d1 commit ab63fc9

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

vit_pytorch/na_vit_nested_tensor_3d.py

-11
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,6 @@ def forward(
8383

8484
# split heads
8585

86-
def split_heads(t):
87-
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
88-
89-
# queries, keys, values
90-
91-
query = self.to_queries(x)
92-
key = self.to_keys(context)
93-
value = self.to_values(context)
94-
95-
# split heads
96-
9786
def split_heads(t):
9887
return t.unflatten(-1, (self.heads, self.dim_head))
9988

0 commit comments

Comments
 (0)