@@ -167,8 +167,10 @@ def __init__(
167
167
stride ,
168
168
padding ,
169
169
frame_stride = 1 ,
170
+ frame_padding = None ,
170
171
frame_pooling_stride = 1 ,
171
172
frame_pooling_kernel_size = 1 ,
173
+ frame_pooling_padding = None ,
172
174
pooling_kernel_size = 3 ,
173
175
pooling_stride = 2 ,
174
176
pooling_padding = 1 ,
@@ -188,16 +190,22 @@ def __init__(
188
190
189
191
n_filter_list_pairs = zip (n_filter_list [:- 1 ], n_filter_list [1 :])
190
192
193
+ if frame_padding is None :
194
+ frame_padding = frame_kernel_size // 2
195
+
196
+ if frame_pooling_padding is None :
197
+ frame_pooling_padding = frame_pooling_kernel_size // 2
198
+
191
199
self .conv_layers = nn .Sequential (
192
200
* [nn .Sequential (
193
201
nn .Conv3d (chan_in , chan_out ,
194
202
kernel_size = (frame_kernel_size , kernel_size , kernel_size ),
195
203
stride = (frame_stride , stride , stride ),
196
- padding = (frame_kernel_size // 2 , padding , padding ), bias = conv_bias ),
204
+ padding = (frame_padding , padding , padding ), bias = conv_bias ),
197
205
nn .Identity () if not exists (activation ) else activation (),
198
206
nn .MaxPool3d (kernel_size = (frame_pooling_kernel_size , pooling_kernel_size , pooling_kernel_size ),
199
207
stride = (frame_pooling_stride , pooling_stride , pooling_stride ),
200
- padding = (frame_pooling_kernel_size // 2 , pooling_padding , pooling_padding )) if max_pool else nn .Identity ()
208
+ padding = (frame_pooling_padding , pooling_padding , pooling_padding )) if max_pool else nn .Identity ()
201
209
)
202
210
for chan_in , chan_out in n_filter_list_pairs
203
211
])
@@ -324,8 +332,10 @@ def __init__(
324
332
n_conv_layers = 1 ,
325
333
frame_stride = 1 ,
326
334
frame_kernel_size = 3 ,
335
+ frame_padding = None ,
327
336
frame_pooling_kernel_size = 1 ,
328
337
frame_pooling_stride = 1 ,
338
+ frame_pooling_padding = None ,
329
339
kernel_size = 7 ,
330
340
stride = 2 ,
331
341
padding = 3 ,
@@ -342,8 +352,10 @@ def __init__(
342
352
n_output_channels = embedding_dim ,
343
353
frame_stride = frame_stride ,
344
354
frame_kernel_size = frame_kernel_size ,
355
+ frame_padding = frame_padding ,
345
356
frame_pooling_stride = frame_pooling_stride ,
346
357
frame_pooling_kernel_size = frame_pooling_kernel_size ,
358
+ frame_pooling_padding = frame_pooling_padding ,
347
359
kernel_size = kernel_size ,
348
360
stride = stride ,
349
361
padding = padding ,
0 commit comments