Skip to content

Commit b7ed6ba

Browse files
authored
add option to set frame padding for 3D CCT (#339)
1 parent e7cba9b commit b7ed6ba

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

vit_pytorch/cct_3d.py

+14 -2
Original file line number | Diff line number | Diff line change
@@ -167,8 +167,10 @@ def __init__(
167167
stride,
168168
padding,
169169
frame_stride=1,
170+
frame_padding=None,
170171
frame_pooling_stride=1,
171172
frame_pooling_kernel_size=1,
173+
frame_pooling_padding=None,
172174
pooling_kernel_size=3,
173175
pooling_stride=2,
174176
pooling_padding=1,
@@ -188,16 +190,22 @@ def __init__(
188190

189191
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
190192

193+
if frame_padding is None:
194+
frame_padding = frame_kernel_size // 2
195+
196+
if frame_pooling_padding is None:
197+
frame_pooling_padding = frame_pooling_kernel_size // 2
198+
191199
self.conv_layers = nn.Sequential(
192200
*[nn.Sequential(
193201
nn.Conv3d(chan_in, chan_out,
194202
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
195203
stride=(frame_stride, stride, stride),
196-
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
204+
padding=(frame_padding, padding, padding), bias=conv_bias),
197205
nn.Identity() if not exists(activation) else activation(),
198206
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
199207
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
200-
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
208+
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
201209
)
202210
for chan_in, chan_out in n_filter_list_pairs
203211
])
@@ -324,8 +332,10 @@ def __init__(
324332
n_conv_layers=1,
325333
frame_stride=1,
326334
frame_kernel_size=3,
335+
frame_padding=None,
327336
frame_pooling_kernel_size=1,
328337
frame_pooling_stride=1,
338+
frame_pooling_padding=None,
329339
kernel_size=7,
330340
stride=2,
331341
padding=3,
@@ -342,8 +352,10 @@ def __init__(
342352
n_output_channels=embedding_dim,
343353
frame_stride=frame_stride,
344354
frame_kernel_size=frame_kernel_size,
355+
frame_padding=frame_padding,
345356
frame_pooling_stride=frame_pooling_stride,
346357
frame_pooling_kernel_size=frame_pooling_kernel_size,
358+
frame_pooling_padding=frame_pooling_padding,
347359
kernel_size=kernel_size,
348360
stride=stride,
349361
padding=padding,

0 commit comments

Comments
 (0)