Commit 5e808f4

3d version of navit nested tensor
1 parent bed48b5 commit 5e808f4

File tree

2 files changed: +349 −1 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.7',
+  version = '1.7.10',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
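
Once a release containing this commit is published, the bump can be checked from Python. A minimal sketch, assuming the distribution is published under the name declared in setup() above and reusing the same packaging.version helper the new module imports:

# minimal sanity check after `pip install -U vit-pytorch`
# assumes the package is published under the name declared in setup() above
import importlib.metadata
import packaging.version as pkg_version

installed = importlib.metadata.version('vit-pytorch')
assert pkg_version.parse(installed) >= pkg_version.parse('1.7.10')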
Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
from __future__ import annotations

from typing import List
from functools import partial

import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
    print('nested tensor NaViT was tested on pytorch 2.4')

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nested import nested_tensor

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim, bias = False),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, bias = False)

        dim_inner = heads * dim_head
        self.heads = heads
        self.dim_head = dim_head

        self.to_queries = nn.Linear(dim, dim_inner, bias = False)
        self.to_keys = nn.Linear(dim, dim_inner, bias = False)
        self.to_values = nn.Linear(dim, dim_inner, bias = False)

        # in the paper, they employ qk rmsnorm, a way to stabilize attention
        # layernorm is used in place of rmsnorm here, which has been shown to work in certain papers
        # rmsnorm would require l2norm on the non-ragged dimension to be supported for nested tensors

        self.query_norm = nn.LayerNorm(dim_head, bias = False)
        self.key_norm = nn.LayerNorm(dim_head, bias = False)

        self.dropout = dropout

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        context: Tensor | None = None
    ):

        x = self.norm(x)

        # for attention pooling, one query pools over the entire sequence

        context = default(context, x)

        # queries, keys, values

        query = self.to_queries(x)
        key = self.to_keys(context)
        value = self.to_values(context)

        # split heads, keeping the ragged sequence dimension at dim 1 until after qk norm

        def split_heads(t):
            return t.unflatten(-1, (self.heads, self.dim_head))

        def transpose_head_seq(t):
            return t.transpose(1, 2)

        query, key, value = map(split_heads, (query, key, value))

        # qk norm for attention stability

        query = self.query_norm(query)
        key = self.key_norm(key)

        query, key, value = map(transpose_head_seq, (query, key, value))

        # attention

        out = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p = self.dropout if self.training else 0.
        )

        # merge heads

        out = out.transpose(1, 2).flatten(-2)

        return self.to_out(out)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        self.norm = nn.LayerNorm(dim, bias = False)

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(Module):
    def __init__(
        self,
        *,
        image_size,
        max_frames,
        patch_size,
        frame_patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob: float | None = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # what percent of tokens to dropout
        # if a float is given, a constant dropout prob is assumed for every volume

        self.token_dropout_prob = token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
        assert divisible_by(max_frames, frame_patch_size), 'Max frames must be divisible by the frame patch size.'

        patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)

        patch_dim = channels * (patch_size ** 2) * frame_patch_size

        self.channels = channels
        self.patch_size = patch_size
        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)

        self.to_patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # factorized positional embeddings for the frame, height, width axes of the patch grid

        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim, bias = False),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        volumes: List[Tensor], # volumes (videos / CT scans) of varying frame count and resolution
    ):
        batch, device = len(volumes), self.device
        arange = partial(torch.arange, device = device)

        assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and 4 dimensions (channels, frames, height, width)'

        all_patches = [self.to_patches(volume) for volume in volumes]

        # prepare factorized positional embedding frame, height, width indices

        positions = []

        for patches in all_patches:
            patch_frame, patch_height, patch_width = patches.shape[:3]
            fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
            fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')

            positions.append(fhw_indices)

        # need the sequence lengths to compute token dropout + positional embedding

        tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]

        # handle token dropout

        seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)

        if self.training and exists(self.token_dropout_prob) and self.token_dropout_prob > 0.:

            keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)

            kept_tokens = []
            kept_positions = []

            for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
                keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

                one_image_kept_tokens = one_image_tokens[keep_indices]
                one_image_kept_positions = one_image_positions[keep_indices]

                kept_tokens.append(one_image_kept_tokens)
                kept_positions.append(one_image_kept_positions)

            tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens

        # add the factorized frame, height, and width positions

        frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
        frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]

        pos_embed = frame_embed + height_embed + width_embed

        # use a nested tensor for the transformer and save on padding computation

        tokens = torch.cat(tokens)

        # linear projection to patch embeddings

        tokens = self.to_patch_embedding(tokens)

        # absolute positions

        tokens = tokens + pos_embed

        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)

        # embedding dropout

        tokens = self.dropout(tokens)

        # transformer

        tokens = self.transformer(tokens)

        # attention pooling
        # a jagged tensor is used for the queries as well, since SDPA requires its inputs to be either all jagged or all dense

        attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch

        attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)

        pooled = self.attn_pool(attn_pool_queries, tokens)

        # back to unjagged

        logits = torch.stack(pooled.unbind())

        logits = rearrange(logits, 'b 1 d -> b d')

        logits = self.to_latent(logits)

        return self.mlp_head(logits)

# quick test

if __name__ == '__main__':

    # works for torch 2.4

    v = NaViT(
        image_size = 256,
        max_frames = 8,
        patch_size = 32,
        frame_patch_size = 2,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob = 0.1
    )

    # 5 volumes (videos or CT scans) of varying frame counts and resolutions - List[Tensor]

    volumes = [
        torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
        torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
        torch.randn(3, 4, 64, 256)
    ]

    assert v(volumes).shape == (5, 1000)
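
For readers unfamiliar with the jagged nested tensors the module relies on, here is a minimal standalone sketch of the same pattern: ragged token sequences batched without padding and pushed through scaled_dot_product_attention, using only the torch 2.4 calls the Attention class above already makes. The toy shapes are made up for illustration.

import torch
import torch.nn.functional as F
from torch.nested import nested_tensor

heads, dim_head = 4, 16

# three volumes' worth of tokens with ragged sequence lengths
seqs = [torch.randn(n, heads, dim_head) for n in (7, 12, 3)]

# jagged layout keeps the ragged dimension at dim 1: (batch, ragged seq, heads, dim_head)
nt = nested_tensor(seqs, layout = torch.jagged)

# move heads before the ragged sequence dim, as the Attention class does
q = k = v = nt.transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)

# transpose back and unbind to recover one tensor per sample, no padding involved
out = out.transpose(1, 2)

for per_sample, original in zip(out.unbind(), seqs):
    assert per_sample.shape == original.shape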
