Commit 455689a

Nathan Lambert authored and Prathik Rao committed
Clean up resnet.py file (huggingface#780)
* clean up resnet.py
* make style and quality
* minor formatting
1 parent afb4294 commit 455689a

File tree: 1 file changed, +63 −57 lines changed

src/diffusers/models/resnet.py

+63 −57
@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
     """
     An upsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
-        applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
+    Parameters:
+        channels: channels in the inputs and outputs.
+        use_conv: a bool determining if a convolution is applied.
+        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
     """

     def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
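For context, a minimal usage sketch of the module whose docstring is cleaned up above. The constructor signature is taken from the hunk's context line; the import path and the factor-2 output shape are assumptions based on this file:

import torch
from diffusers.models.resnet import Upsample2D

# use_conv=True applies a convolution after the spatial upsampling.
upsample = Upsample2D(channels=64, use_conv=True)
hidden_states = torch.randn(1, 64, 32, 32)  # [N, C, H, W]
output = upsample(hidden_states)
print(output.shape)  # torch.Size([1, 64, 64, 64])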
@@ -61,9 +62,10 @@ class Downsample2D(nn.Module):
     """
    A downsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
-        applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        downsampling occurs in the inner-two dimensions.
+    Parameters:
+        channels: channels in the inputs and outputs.
+        use_conv: a bool determining if a convolution is applied.
+        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
     """

     def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
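And the matching sketch for Downsample2D (signature from the context line above; with use_conv=True and padding=1 a stride-2 convolution halves H and W — again an assumption from this file, not something the diff itself shows):

import torch
from diffusers.models.resnet import Downsample2D

downsample = Downsample2D(channels=64, use_conv=True, padding=1)
hidden_states = torch.randn(1, 64, 32, 32)  # [N, C, H, W]
output = downsample(hidden_states)
print(output.shape)  # torch.Size([1, 64, 16, 16])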
@@ -115,21 +117,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
     def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `upsample_2d()` followed by `Conv2d()`.

-        Args:
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
-        order.
-            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-            C]`.
-            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
-            outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-            factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+            factor: Integer upsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).

         Returns:
-            Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
-            `x`.
+            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+            datatype as `hidden_states`.
         """

         assert isinstance(factor, int) and factor >= 1
@@ -164,7 +167,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1
                 output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
             )
             assert output_padding[0] >= 0 and output_padding[1] >= 0
-            inC = weight.shape[1]
             num_groups = hidden_states.shape[1] // inC

             # Transpose weights.
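The two hunks above sit inside the FIR-based upsampler (FirUpsample2D in this file; the second hunk drops a redundant reassignment of `inC`, which is already set earlier in the branch). A hedged usage sketch — the class name and its `fir_kernel=(1, 3, 3, 1)` default are assumed from the surrounding file, not shown in this diff:

import torch
from diffusers.models.resnet import FirUpsample2D

# use_conv=True takes the fused upsample_2d() + Conv2d() path documented above.
upsample = FirUpsample2D(channels=32, use_conv=True)
hidden_states = torch.randn(2, 32, 16, 16)
output = upsample(hidden_states)
print(output.shape)  # torch.Size([2, 32, 32, 32]), factor 2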
@@ -214,20 +216,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=

     def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `Conv2d()` followed by `downsample_2d()`.
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.

         Args:
-        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
-        order.
-            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
-            filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
-            numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
-            factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
-            Scaling factor for signal magnitude (default: 1.0).
+            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight:
+                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+                factor`, which corresponds to average pooling.
+            factor: Integer downsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).

         Returns:
-            Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
-            datatype as `x`.
+            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
+            same datatype as `x`.
         """

         assert isinstance(factor, int) and factor >= 1
@@ -251,17 +256,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain
                 torch.tensor(kernel, device=hidden_states.device),
                 pad=((pad_value + 1) // 2, pad_value // 2),
             )
-            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
         else:
             pad_value = kernel.shape[0] - factor
-            hidden_states = upfirdn2d_native(
+            output = upfirdn2d_native(
                 hidden_states,
                 torch.tensor(kernel, device=hidden_states.device),
                 down=factor,
                 pad=((pad_value + 1) // 2, pad_value // 2),
            )

-        return hidden_states
+        return output

     def forward(self, hidden_states):
         if self.use_conv:
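Mirroring the upsample case, a hedged sketch for the FIR-based downsampler these hunks belong to (FirDownsample2D, name and defaults assumed from the surrounding file). Note the rename from `hidden_states` to `output` above, which stops the method from reusing its input argument as its result variable:

import torch
from diffusers.models.resnet import FirDownsample2D

# use_conv=True takes the fused Conv2d() + downsample_2d() path documented above.
downsample = FirDownsample2D(channels=32, use_conv=True)
hidden_states = torch.randn(2, 32, 16, 16)
output = downsample(hidden_states)
print(output.shape)  # torch.Size([2, 32, 8, 8]), factor 2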
@@ -393,20 +398,20 @@ def forward(self, hidden_states):

 def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.
-
-    Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
-    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
-    multiple of the upsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-        C]`.
-        k: FIR filter of the shape `[firH, firW]` or `[firN]`
+    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+    a: multiple of the upsampling factor.
+
+    Args:
+        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
         (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-    factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+        factor: Integer upsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).

     Returns:
-        Tensor of the shape `[N, C, H * factor, W * factor]`
+        output: Tensor of the shape `[N, C, H * factor, W * factor]`
     """
     assert isinstance(factor, int) and factor >= 1
     if kernel is None:
@@ -419,30 +424,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):

     kernel = kernel * (gain * (factor**2))
     pad_value = kernel.shape[0] - factor
-    return upfirdn2d_native(
+    output = upfirdn2d_native(
         hidden_states,
         kernel.to(device=hidden_states.device),
         up=factor,
         pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
     )
+    return output


 def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Downsample2D a batch of 2D images with the given filter.
-
-    Args:
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
     specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
     shape is a multiple of the downsampling factor.
-        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-        C]`.
+
+    Args:
+        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
         kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
         (separable). The default is `[1] * factor`, which corresponds to average pooling.
-    factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+        factor: Integer downsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).

     Returns:
-        Tensor of the shape `[N, C, H // factor, W // factor]`
+        output: Tensor of the shape `[N, C, H // factor, W // factor]`
     """

     assert isinstance(factor, int) and factor >= 1
@@ -456,34 +462,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):

     kernel = kernel * gain
     pad_value = kernel.shape[0] - factor
-    return upfirdn2d_native(
+    output = upfirdn2d_native(
         hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
     )
+    return output


-def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
     up_x = up_y = up
     down_x = down_y = down
     pad_x0 = pad_y0 = pad[0]
     pad_x1 = pad_y1 = pad[1]

-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-    # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)
+    _, channel, in_h, in_w = tensor.shape
+    tensor = tensor.reshape(-1, in_h, in_w, 1)

-    _, in_h, in_w, minor = input.shape
+    _, in_h, in_w, minor = tensor.shape
     kernel_h, kernel_w = kernel.shape

-    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = tensor.view(-1, in_h, 1, in_w, 1, minor)

     # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
-    if input.device.type == "mps":
+    if tensor.device.type == "mps":
         out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)

     out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
-    out = out.to(input.device)  # Move back to mps if necessary
+    out = out.to(tensor.device)  # Move back to mps if necessary
     out = out[
         :,
         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
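The hunk above renames the function's `input` argument to `tensor`, resolving the linter warning about shadowing the `input` builtin. For readers new to the op itself: upfirdn is upsample (zero insertion) → FIR filter → downsample. A compact reference sketch of that decomposition — not this file's implementation, and it assumes non-negative padding and a 2D float kernel matching the input dtype:

import torch
import torch.nn.functional as F

def upfirdn2d_reference(tensor, kernel, up=1, down=1, pad=(0, 0)):
    # 1) Upsample by inserting (up - 1) zeros after every pixel.
    n, c, h, w = tensor.shape
    x = tensor.view(n, c, h, 1, w, 1)
    x = F.pad(x, [0, up - 1, 0, 0, 0, up - 1])
    x = x.reshape(n, c, h * up, w * up)
    # 2) Pad, then correlate each channel with the flipped FIR kernel.
    x = F.pad(x, [pad[0], pad[1], pad[0], pad[1]])
    weight = torch.flip(kernel, dims=(0, 1)).view(1, 1, *kernel.shape).repeat(c, 1, 1, 1)
    x = F.conv2d(x, weight, groups=c)
    # 3) Downsample by keeping every down-th sample.
    return x[:, :, ::down, ::down]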
