From 3aa1aa7d98b4cd764b3cd80bf3efd4c9cf6d741f Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sat, 8 Oct 2022 08:24:26 -0700 Subject: [PATCH 1/3] clean up resnet.py --- src/diffusers/models/resnet.py | 109 ++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b9718e67f279..86fa177d817d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -9,9 +9,10 @@ class Upsample2D(nn.Module): """ An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -61,9 +62,10 @@ class Downsample2D(nn.Module): """ A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is - applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. 
""" def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -115,21 +117,23 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `Conv2d()`. - Args: Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as - `x`. + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + `hidden_states`. 
""" assert isinstance(factor, int) and factor >= 1 @@ -164,7 +168,6 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 - inC = weight.shape[1] num_groups = hidden_states.shape[1] // inC # Transpose weights. @@ -214,19 +217,20 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `Conv2d()` followed by `downsample_2d()`. - - Args: Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary: + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, - filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // - numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: - Scaling factor for signal magnitude (default: 1.0). + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`. """ @@ -251,17 +255,17 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain torch.tensor(kernel, device=hidden_states.device), pad=((pad_value + 1) // 2, pad_value // 2), ) - hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) else: pad_value = kernel.shape[0] - factor - hidden_states = upfirdn2d_native( + output = upfirdn2d_native( hidden_states, torch.tensor(kernel, device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2), ) - return hidden_states + return output def forward(self, hidden_states): if self.use_conv: @@ -393,20 +397,21 @@ def forward(self, hidden_states): def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. - - Args: Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: multiple of the upsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - k: FIR filter of the shape `[firH, firW]` or `[firN]` + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
+ factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). Returns: - Tensor of the shape `[N, C, H * factor, W * factor]` + output: Tensor of the shape `[N, C, H * factor, W * factor]` """ assert isinstance(factor, int) and factor >= 1 if kernel is None: @@ -419,30 +424,32 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor - return upfirdn2d_native( + output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) + return output def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. - - Args: Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the downsampling factor. - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
Returns: - Tensor of the shape `[N, C, H // factor, W // factor]` + output: Tensor of the shape `[N, C, H // factor, W // factor]` """ assert isinstance(factor, int) and factor >= 1 @@ -456,34 +463,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor - return upfirdn2d_native( + output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) + return output -def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] pad_x1 = pad_y1 = pad[1] - _, channel, in_h, in_w = input.shape - input = input.reshape(-1, in_h, in_w, 1) - # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806) + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) - _, in_h, in_w, minor = input.shape + _, in_h, in_w, minor = tensor.shape kernel_h, kernel_w = kernel.shape - out = input.view(-1, in_h, 1, in_w, 1, minor) + out = tensor.view(-1, in_h, 1, in_w, 1, minor) # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 - if input.device.type == "mps": + if tensor.device.type == "mps": out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(input.device) # Move back to mps if necessary + out = out.to(tensor.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), From 4dbdc691bf6ecb30d9ed063ef5dd19bee73e4614 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sat, 8 Oct 2022 08:24:45 -0700 Subject: [PATCH 2/3] make style and quality --- src/diffusers/models/resnet.py | 23 +++++++++++++---------- 1 file 
changed, 13 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 86fa177d817d..f2d9ffa7c45d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -118,8 +118,8 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 """Fused `upsample_2d()` followed by `Conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary - order. + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, @@ -132,7 +132,8 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 gain: Scaling factor for signal magnitude (default: 1.0). Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. """ @@ -218,20 +219,22 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `Conv2d()` followed by `downsample_2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary - order. + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 
- weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + weight: + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = hidden_states.shape[0] // numGroups`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same - datatype as `x`. + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and + same datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 @@ -399,8 +402,8 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: - multiple of the upsampling factor. + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. 
Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, From cdde0366c7ff95bb3879c827fb52b2300c5066d8 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sat, 8 Oct 2022 08:29:42 -0700 Subject: [PATCH 3/3] minor formatting --- src/diffusers/models/resnet.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f2d9ffa7c45d..dc8a91164977 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -122,8 +122,7 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 arbitrary order. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` @@ -133,8 +132,7 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1 Returns: output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as - `hidden_states`. + datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 @@ -406,8 +404,7 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): a: multiple of the upsampling factor. Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). @@ -444,8 +441,7 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): shape is a multiple of the downsampling factor. 
Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, - C]`. + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2).