Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 62 additions & 24 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@


class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
"""A 1D upsampling layer with an optional convolution.

Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv_transpose:
out_channels:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""

def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
Expand Down Expand Up @@ -62,14 +65,17 @@ def forward(self, x):


class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
"""A 1D downsampling layer with an optional convolution.

Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
out_channels:
padding:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
"""

def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
Expand All @@ -93,14 +99,17 @@ def forward(self, x):


class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
"""A 2D upsampling layer with an optional convolution.

Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv_transpose:
out_channels:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""

def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
Expand Down Expand Up @@ -162,14 +171,17 @@ def forward(self, hidden_states, output_size=None):


class Downsample2D(nn.Module):
"""
A downsampling layer with an optional convolution.
"""A 2D downsampling layer with an optional convolution.

Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
out_channels:
padding:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
"""

def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
Expand Down Expand Up @@ -209,6 +221,19 @@ def forward(self, hidden_states):


class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution.

Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""

def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
Expand Down Expand Up @@ -309,6 +334,19 @@ def forward(self, hidden_states):


class FirDownsample2D(nn.Module):
"""A 2D FIR downsampling layer with an optional convolution.

Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""

def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
Expand Down