@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
9
9
"""
10
10
An upsampling layer with an optional convolution.
11
11
12
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
13
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
14
- upsampling occurs in the inner-two dimensions.
12
+ Parameters:
13
+ channels: channels in the inputs and outputs.
14
+ use_conv: a bool determining if a convolution is applied.
15
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
15
16
"""
16
17
17
18
def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
61
62
"""
62
63
A downsampling layer with an optional convolution.
63
64
64
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
65
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
66
- downsampling occurs in the inner-two dimensions.
65
+ Parameters:
66
+ channels: channels in the inputs and outputs.
67
+ use_conv: a bool determining if a convolution is applied.
68
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
67
69
"""
68
70
69
71
def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -115,21 +117,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
115
117
def _upsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
116
118
"""Fused `upsample_2d()` followed by `Conv2d()`.
117
119
118
- Args:
119
120
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
120
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
121
- order.
122
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
123
- C]`.
124
- weight: Weight tensor of the shape `[filterH, filterW, inChannels,
125
- outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
126
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
127
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
128
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
121
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
122
+ arbitrary order.
123
+
124
+ Args:
125
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
126
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
127
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
128
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
129
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
130
+ factor: Integer upsampling factor (default: 2).
131
+ gain: Scaling factor for signal magnitude (default: 1.0).
129
132
130
133
Returns:
131
- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
132
- `x `.
134
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
135
+ datatype as `hidden_states`.
133
136
"""
134
137
135
138
assert isinstance (factor , int ) and factor >= 1
@@ -164,7 +167,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1
164
167
output_shape [1 ] - (hidden_states .shape [3 ] - 1 ) * stride [1 ] - convW ,
165
168
)
166
169
assert output_padding [0 ] >= 0 and output_padding [1 ] >= 0
167
- inC = weight .shape [1 ]
168
170
num_groups = hidden_states .shape [1 ] // inC
169
171
170
172
# Transpose weights.
@@ -214,20 +216,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
214
216
215
217
def _downsample_2d (self , hidden_states , weight = None , kernel = None , factor = 2 , gain = 1 ):
216
218
"""Fused `Conv2d()` followed by `downsample_2d()`.
219
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
220
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
221
+ arbitrary order.
217
222
218
223
Args:
219
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
220
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary :
221
- order.
222
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
223
- filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
224
- numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
225
- factor`, which corresponds to average pooling. factor : Integer downsampling factor (default: 2). gain:
226
- Scaling factor for signal magnitude (default: 1.0).
224
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
225
+ weight:
226
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
227
+ performed by `inChannels = x.shape[0] // numGroups`.
228
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
229
+ factor`, which corresponds to average pooling.
230
+ factor: Integer downsampling factor (default: 2).
231
+ gain: Scaling factor for signal magnitude (default: 1.0).
227
232
228
233
Returns:
229
- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
230
- datatype as `x`.
234
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
235
+ same datatype as `x`.
231
236
"""
232
237
233
238
assert isinstance (factor , int ) and factor >= 1
@@ -251,17 +256,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain
251
256
torch .tensor (kernel , device = hidden_states .device ),
252
257
pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
253
258
)
254
- hidden_states = F .conv2d (upfirdn_input , weight , stride = stride_value , padding = 0 )
259
+ output = F .conv2d (upfirdn_input , weight , stride = stride_value , padding = 0 )
255
260
else :
256
261
pad_value = kernel .shape [0 ] - factor
257
- hidden_states = upfirdn2d_native (
262
+ output = upfirdn2d_native (
258
263
hidden_states ,
259
264
torch .tensor (kernel , device = hidden_states .device ),
260
265
down = factor ,
261
266
pad = ((pad_value + 1 ) // 2 , pad_value // 2 ),
262
267
)
263
268
264
- return hidden_states
269
+ return output
265
270
266
271
def forward (self , hidden_states ):
267
272
if self .use_conv :
@@ -393,20 +398,20 @@ def forward(self, hidden_states):
393
398
394
399
def upsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
395
400
r"""Upsample2D a batch of 2D images with the given filter.
396
-
397
- Args:
398
401
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
399
402
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
400
- `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
401
- multiple of the upsampling factor.
402
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
403
- C]`.
404
- k: FIR filter of the shape `[firH, firW]` or `[firN]`
403
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
404
+ a multiple of the upsampling factor.
405
+
406
+ Args:
407
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
408
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
405
409
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
406
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
410
+ factor: Integer upsampling factor (default: 2).
411
+ gain: Scaling factor for signal magnitude (default: 1.0).
407
412
408
413
Returns:
409
- Tensor of the shape `[N, C, H * factor, W * factor]`
414
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
410
415
"""
411
416
assert isinstance (factor , int ) and factor >= 1
412
417
if kernel is None :
@@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
419
424
420
425
kernel = kernel * (gain * (factor ** 2 ))
421
426
pad_value = kernel .shape [0 ] - factor
422
- return upfirdn2d_native (
427
+ output = upfirdn2d_native (
423
428
hidden_states ,
424
429
kernel .to (device = hidden_states .device ),
425
430
up = factor ,
426
431
pad = ((pad_value + 1 ) // 2 + factor - 1 , pad_value // 2 ),
427
432
)
433
+ return output
428
434
429
435
430
436
def downsample_2d (hidden_states , kernel = None , factor = 2 , gain = 1 ):
431
437
r"""Downsample2D a batch of 2D images with the given filter.
432
-
433
- Args:
434
438
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
435
439
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
436
440
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
437
441
shape is a multiple of the downsampling factor.
438
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
439
- C]`.
442
+
443
+ Args:
444
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
440
445
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
441
446
(separable). The default is `[1] * factor`, which corresponds to average pooling.
442
- factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
447
+ factor: Integer downsampling factor (default: 2).
448
+ gain: Scaling factor for signal magnitude (default: 1.0).
443
449
444
450
Returns:
445
- Tensor of the shape `[N, C, H // factor, W // factor]`
451
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
446
452
"""
447
453
448
454
assert isinstance (factor , int ) and factor >= 1
@@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
456
462
457
463
kernel = kernel * gain
458
464
pad_value = kernel .shape [0 ] - factor
459
- return upfirdn2d_native (
465
+ output = upfirdn2d_native (
460
466
hidden_states , kernel .to (device = hidden_states .device ), down = factor , pad = ((pad_value + 1 ) // 2 , pad_value // 2 )
461
467
)
468
+ return output
462
469
463
470
464
- def upfirdn2d_native (input , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
471
+ def upfirdn2d_native (tensor , kernel , up = 1 , down = 1 , pad = (0 , 0 )):
465
472
up_x = up_y = up
466
473
down_x = down_y = down
467
474
pad_x0 = pad_y0 = pad [0 ]
468
475
pad_x1 = pad_y1 = pad [1 ]
469
476
470
- _ , channel , in_h , in_w = input .shape
471
- input = input .reshape (- 1 , in_h , in_w , 1 )
472
- # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)
477
+ _ , channel , in_h , in_w = tensor .shape
478
+ tensor = tensor .reshape (- 1 , in_h , in_w , 1 )
473
479
474
- _ , in_h , in_w , minor = input .shape
480
+ _ , in_h , in_w , minor = tensor .shape
475
481
kernel_h , kernel_w = kernel .shape
476
482
477
- out = input .view (- 1 , in_h , 1 , in_w , 1 , minor )
483
+ out = tensor .view (- 1 , in_h , 1 , in_w , 1 , minor )
478
484
479
485
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
480
- if input .device .type == "mps" :
486
+ if tensor .device .type == "mps" :
481
487
out = out .to ("cpu" )
482
488
out = F .pad (out , [0 , 0 , 0 , up_x - 1 , 0 , 0 , 0 , up_y - 1 ])
483
489
out = out .view (- 1 , in_h * up_y , in_w * up_x , minor )
484
490
485
491
out = F .pad (out , [0 , 0 , max (pad_x0 , 0 ), max (pad_x1 , 0 ), max (pad_y0 , 0 ), max (pad_y1 , 0 )])
486
- out = out .to (input .device ) # Move back to mps if necessary
492
+ out = out .to (tensor .device ) # Move back to mps if necessary
487
493
out = out [
488
494
:,
489
495
max (- pad_y0 , 0 ) : out .shape [1 ] - max (- pad_y1 , 0 ),
0 commit comments