
Commit 19eb6d4

3d version of navit nested tensor
1 parent bed48b5 commit 19eb6d4

File tree

2 files changed, +333 -1 lines changed


setup.py (+1, -1)
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.7',
+  version = '1.7.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
New file (+332)
@@ -0,0 +1,332 @@
from __future__ import annotations

from typing import List
from functools import partial

import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
    print('nested tensor NaViT was tested on pytorch 2.4')

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nested import nested_tensor

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim, bias = False),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, bias = False)

        dim_inner = heads * dim_head
        self.heads = heads
        self.dim_head = dim_head

        self.to_queries = nn.Linear(dim, dim_inner, bias = False)
        self.to_keys = nn.Linear(dim, dim_inner, bias = False)
        self.to_values = nn.Linear(dim, dim_inner, bias = False)

        # in the paper, they employ qk rmsnorm, a way to stabilize attention
        # layernorm is used in place of rmsnorm here, which has been shown to work in certain papers
        # (rmsnorm would require l2norm over the non-ragged dimension to be supported for nested tensors)

        self.query_norm = nn.LayerNorm(dim_head, bias = False)
        self.key_norm = nn.LayerNorm(dim_head, bias = False)

        self.dropout = dropout

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        context: Tensor | None = None
    ):

        x = self.norm(x)

        # for attention pooling, a single query pools over the entire sequence

        context = default(context, x)

        # queries, keys, values

        query = self.to_queries(x)
        key = self.to_keys(context)
        value = self.to_values(context)

        # split heads

        def split_heads(t):
            return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()

        query, key, value = map(split_heads, (query, key, value))

        # qk norm for attention stability

        query = self.query_norm(query)
        key = self.key_norm(key)

        # attention

        out = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p = self.dropout if self.training else 0.
        )

        # merge heads

        out = out.transpose(1, 2).flatten(-2)

        return self.to_out(out)
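
# a shape sketch of the attention above on jagged inputs (illustrative sizes only:
# two sequences of 8 and 5 tokens, dim 512, 8 heads of dim_head 64):
#
#   x                                       : nested tensor of shape (2, j1, 512), j1 being the ragged token dimension
#   query / key / value after split_heads   : (2, 8, j1, 64)
#
# F.scaled_dot_product_attention operates on the jagged layout directly, so each sequence
# attends only over its own length and no padding tokens or attention masks are needed
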
class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        self.norm = nn.LayerNorm(dim, bias = False)

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(Module):
    def __init__(
        self,
        *,
        image_size,
        max_frames,
        patch_size,
        frame_patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob: float | None = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # what fraction of tokens to drop out during training
        # only a constant dropout probability (float) is supported in this version

        self.token_dropout_prob = token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
        assert divisible_by(max_frames, frame_patch_size), 'Max frames must be divisible by the frame patch size.'

        patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)

        patch_dim = channels * (patch_size ** 2) * frame_patch_size

        self.channels = channels
        self.patch_size = patch_size
        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
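
        # worked example of the patching above, using the quick-test settings at the bottom of this file
        # (patch_size = 32, frame_patch_size = 2, channels = 3):
        #   a volume of shape (3, 8, 256, 256) becomes patches of shape (4, 8, 8, 6144),
        #   since f = 8 / 2, h = w = 256 / 32 and patch_dim = 3 * 32 * 32 * 2 = 6144
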
        self.to_patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
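
        # the positional embedding is factorized: in the forward pass, a patch at (f, h, w) receives
        # pos_embed_frame[f] + pos_embed_height[h] + pos_embed_width[w], so only
        # (patch_frame_dim + patch_height_dim + patch_width_dim) * dim parameters are learned
        # rather than a full patch_frame_dim * patch_height_dim * patch_width_dim * dim table
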
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim, bias = False),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        volumes: List[Tensor], # different resolution volumes (videos / CT scans)
    ):
        batch, device = len(volumes), self.device
        arange = partial(torch.arange, device = device)

        assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and 4 dimensions (channels, frames, height, width)'

        all_patches = [self.to_patches(volume) for volume in volumes]

        # prepare factorized positional embedding frame / height / width indices

        positions = []

        for patches in all_patches:
            patch_frame, patch_height, patch_width = patches.shape[:3]
            fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
            fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')

            positions.append(fhw_indices)

        # need the sizes to compute token dropout + positional embedding

        tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]

        # handle token dropout

        seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)

        if self.training and exists(self.token_dropout_prob) and self.token_dropout_prob > 0.:

            keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)

            kept_tokens = []
            kept_positions = []

            for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
                keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

                one_image_kept_tokens = one_image_tokens[keep_indices]
                one_image_kept_positions = one_image_positions[keep_indices]

                kept_tokens.append(one_image_kept_tokens)
                kept_positions.append(one_image_kept_positions)

            tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
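
        # a small worked example of the token dropout above (assuming token_dropout_prob = 0.1):
        #   a volume that yields 256 tokens keeps int(0.9 * 256) = 230 of them, selected as the topk
        #   indices of random scores; the matching position indices are kept in sync, so the factorized
        #   positional embedding below still lines up with every surviving token
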
        # add all frame, height and width factorized positions

        frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
        frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]

        pos_embed = frame_embed + height_embed + width_embed

        # use nested tensor for transformers and save on padding computation

        tokens = torch.cat(tokens)

        # linear projection to patch embeddings

        tokens = self.to_patch_embedding(tokens)

        # absolute positions

        tokens = tokens + pos_embed

        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
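
        # a minimal illustration of the jagged layout above (sizes are assumed for the sketch):
        #   nested_tensor([torch.randn(230, 1024), torch.randn(128, 1024)], layout = torch.jagged)
        # yields a tensor of shape (2, j1, 1024) with j1 a ragged dimension, so the transformer below
        # processes variable length sequences together without any padding
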
        # embedding dropout

        tokens = self.dropout(tokens)

        # transformer

        tokens = self.transformer(tokens)

        # attention pooling
        # the queries are also made into a jagged tensor, as SDPA requires either all of its inputs or none of them to be jagged

        attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch

        attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)

        pooled = self.attn_pool(attn_pool_queries, tokens)

        # back to unjagged

        logits = torch.stack(pooled.unbind())

        logits = rearrange(logits, 'b 1 d -> b d')

        logits = self.to_latent(logits)

        return self.mlp_head(logits)

# quick test

if __name__ == '__main__':

    # works for torch 2.4

    v = NaViT(
        image_size = 256,
        max_frames = 8,
        patch_size = 32,
        frame_patch_size = 2,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob = 0.1
    )

    # 5 volumes (videos or CT scans) of different resolutions - List[Tensor]

    volumes = [
        torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
        torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
        torch.randn(3, 4, 64, 256)
    ]

    assert v(volumes).shape == (5, 1000)
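
A minimal training-style sketch of how the 3D NaViT above might be driven end to end. It reuses v and volumes as defined in the quick test; the labels, optimizer and learning rate are illustrative assumptions, not anything prescribed by this commit.

import torch
import torch.nn.functional as F

# v and volumes are assumed to be the NaViT instance and the list of 5 variable resolution
# volumes from the quick test above; labels and hyperparameters are made up for illustration

optimizer = torch.optim.Adam(v.parameters(), lr = 3e-4)

labels = torch.randint(0, 1000, (5,))    # one random class label per volume

v.train()                                # token dropout is only applied in training mode

logits = v(volumes)                      # (5, 1000) - one row of logits per variable sized volume
loss = F.cross_entropy(logits, labels)

loss.backward()
optimizer.step()
optimizer.zero_grad()

# at inference time, switch to eval mode so all tokens are kept

v.eval()

with torch.no_grad():
    preds = v(volumes).argmax(dim = -1)  # (5,) predicted class indices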
