Merged
docs/source/en/tutorials/basic_training.mdx (2 changes: 1 addition & 1 deletion)
@@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce

... # Sample a random timestep for each image
... timesteps = torch.randint(
- ... 0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
+ ... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
... ).long()

... # Add noise to the clean images according to the noise magnitude at each timestep
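The pattern in this first hunk recurs throughout the PR: hyperparameters are read from the object's frozen `config` namespace instead of from direct instance attributes. A minimal, self-contained sketch of the new access pattern (the scheduler settings and batch size here are illustrative, not taken from the tutorial):

```python
import torch
from diffusers import DDPMScheduler

# Hyperparameters live on the scheduler's frozen config.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

bs = 4  # illustrative batch size
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (bs,)
).long()
assert timesteps.shape == (bs,)
```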
docs/source/en/using-diffusers/contribute_pipeline.mdx (4 changes: 2 additions & 2 deletions)
@@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):

def __call__(self):
image = torch.randn(
- (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
)
timestep = 1

@@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):

def __call__(self):
image = torch.randn(
- (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
)
timestep = 1

docs/source/en/using-diffusers/custom_pipeline_overview.mdx (4 changes: 3 additions & 1 deletion)
@@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
# Sample gaussian noise to begin loop
-        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+        image = torch.randn(
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
+        )

image = image.to(self.device)

examples/community/bit_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -238,7 +238,7 @@ def __call__(
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
latents = torch.randn(
- (batch_size, self.unet.in_channels, height, width),
+ (batch_size, self.unet.config.in_channels, height, width),
generator=generator,
)
latents = decimal_to_bits(latents) * self.bit_scale
examples/community/clip_guided_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -254,7 +254,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/community/clip_guided_stable_diffusion_img2img.py (2 changes: 1 addition & 1 deletion)
@@ -414,7 +414,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/community/composable_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -513,7 +513,7 @@ def __call__(
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
- num_channels_latents = self.unet.in_channels
+ num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
examples/community/imagic_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -424,7 +424,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if self.device.type == "mps":
# randn does not exist on mps
examples/community/interpolate_stable_diffusion.py (4 changes: 2 additions & 2 deletions)
@@ -320,7 +320,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
@@ -416,7 +416,7 @@ def embed_text(self, text):
def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
"""Takes in random seed and returns corresponding noise vector"""
return torch.randn(
- (1, self.unet.in_channels, height // 8, width // 8),
+ (1, self.unet.config.in_channels, height // 8, width // 8),
generator=torch.Generator(device=self.device).manual_seed(seed),
device=self.device,
dtype=dtype,
examples/community/lpw_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -627,7 +627,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, dev
if image is None:
shape = (
batch_size,
- self.unet.in_channels,
+ self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
examples/community/lpw_stable_diffusion_onnx.py (4 changes: 2 additions & 2 deletions)
@@ -486,7 +486,7 @@ def __init__(
self.__init__additional__()

def __init__additional__(self):
- self.unet_in_channels = 4
+ self.unet.config.in_channels = 4
self.vae_scale_factor = 8

def _encode_prompt(
@@ -621,7 +621,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, gen
if image is None:
shape = (
batch_size,
- self.unet_in_channels,
+ self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
examples/community/magic_mix.py (2 changes: 1 addition & 1 deletion)
@@ -93,7 +93,7 @@ def __call__(

torch.manual_seed(seed)
noise = torch.randn(
- (1, self.unet.in_channels, height // 8, width // 8),
+ (1, self.unet.config.in_channels, height // 8, width // 8),
).to(self.device)

latents = self.scheduler.add_noise(
examples/community/multilingual_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -355,7 +355,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/community/sd_text2img_k_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -433,7 +433,7 @@ def __call__(
sigmas = sigmas.to(text_embeddings.dtype)

# 5. Prepare latent variables
- num_channels_latents = self.unet.in_channels
+ num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
examples/community/seed_resize_stable_diffusion.py (4 changes: 2 additions & 2 deletions)
@@ -262,8 +262,8 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
- latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/community/speech_to_image_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -190,7 +190,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/community/wildcard_stable_diffusion.py (2 changes: 1 addition & 1 deletion)
@@ -337,7 +337,7 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
examples/instruct_pix2pix/train_instruct_pix2pix.py (2 changes: 1 addition & 1 deletion)
@@ -794,7 +794,7 @@ def collate_fn(examples):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
@@ -794,7 +794,7 @@ def collate_fn(examples):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
@@ -641,7 +641,7 @@ def collate_fn(examples):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
examples/text_to_image/train_text_to_image.py (2 changes: 1 addition & 1 deletion)
@@ -804,7 +804,7 @@ def collate_fn(examples):

bsz = latents.shape[0]
# Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
examples/text_to_image/train_text_to_image_lora.py (2 changes: 1 addition & 1 deletion)
@@ -707,7 +707,7 @@ def collate_fn(examples):

bsz = latents.shape[0]
# Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
src/diffusers/configuration_utils.py (7 changes: 0 additions & 7 deletions)
@@ -109,13 +109,6 @@ def register_to_config(self, **kwargs):
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
@patrickvonplaten (Contributor, Author) commented on Mar 28, 2023:

I think we should delete this. I don't know why I added it back then, but it's not a good design looking at it now.

In short, we currently have the following logic:

All diffusers models and pipelines shall inherit from ConfigMixin, which gives some nice config utils:

  • the `.config` property
  • automatic saving of arguments passed to the class constructor
  • load / save config
  • from_config

The API is designed to allow the following:

class MyModel(ConfigMixin):

    @register_to_config
    def __init__(self, attr1, attr2):
        super().__init__()

This currently does two things:

a) it automatically saves attr1 and attr2 as self.config.attr1 and self.config.attr2
b) however, it also saves attr1 and attr2 as self.attr1 and self.attr2

I'm not sure why I thought b) was a good idea back then, but I think it's a pretty bad design choice now, as it's not intuitive and also prone to errors (see the raise-error statement here). It also doesn't really help, since every attribute can be accessed with self.config. And it badly entangles object members such as pipe.unet with configs, which are not models.

=> Long story short, I think we should remove option b). This will clearly be a breaking change, but I don't think anybody really used b) much before, as it's quite weird to assume that config attributes are direct members of the class instance.

@pcuenca @sayakpaul @patil-suraj @williamberman @yiyixuxu ok for you if I remove this with a big 🚨 header?

As you can see many tests are failing, but they are all easy to fix.

Member:
+100 to your reasoning. I am okay with this.

Member:
This way, we also encourage our users to access configuration variables from .config, which reduces the cognitive burden and helps establish a clean and uniform design philosophy.

Member:
I don't remember why this was added in the first place, but I'm all for simplifying the API. Thanks for revisiting!

Collaborator:
yep!

Contributor:
This change broke this pipeline.

I think it could break many more. Instead of not setting the attributes anymore, we should set getters that tell the user these properties are deprecated. Should I open a PR?

Member:
That's fair. I think that would be nice. Cc @patrickvonplaten

Contributor (Author):
Too many breaking changes because of this PR, will try to fix and make a patch release

Contributor (Author):
This PR should work for the patch release: #3129

Contributor (Author):
@remorses sorry about the breaking change, we just made a patch release for it

-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err

if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
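To make the migration discussed in the thread above concrete, here is a hedged sketch of the change user code needs; the checkpoint name is an assumption, and any UNet checkpoint works:

```python
from diffusers import UNet2DModel

# Illustrative checkpoint; substitute your own.
unet = UNet2DModel.from_pretrained("google/ddpm-cat-256")

# Deprecated after this PR: config values as direct instance attributes.
# in_channels = unet.in_channels

# Supported going forward: read them from the frozen config namespace.
in_channels = unet.config.in_channels
sample_size = unet.config.sample_size
```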
src/diffusers/image_processor.py (20 changes: 12 additions & 8 deletions)
@@ -99,8 +99,8 @@ def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
"""
w, h = images.size
- w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
- images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+ w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
+ images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
return images

def preprocess(
@@ -119,7 +119,7 @@ def preprocess(
)

if isinstance(image[0], PIL.Image.Image):
- if self.do_resize:
+ if self.config.do_resize:
image = [self.resize(i) for i in image]
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
image = np.stack(image, axis=0) # to np
@@ -129,23 +129,27 @@
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
_, _, height, width = image.shape
-        if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+        if self.config.do_resize and (
+            height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+        ):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)

elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, _, height, width = image.shape
-        if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+        if self.config.do_resize and (
+            height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+        ):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)

# expected range [0,1], normalize to [-1,1]
- do_normalize = self.do_normalize
+ do_normalize = self.config.do_normalize
if image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
src/diffusers/models/autoencoder_kl.py (14 changes: 12 additions & 2 deletions)
@@ -18,7 +18,7 @@
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
- from ..utils import BaseOutput, apply_forward_hook
+ from ..utils import BaseOutput, apply_forward_hook, deprecate
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -120,9 +120,19 @@ def __init__(
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

+    @property
+    def block_out_channels(self):
+        deprecate(
+            "block_out_channels",
+            "1.0.0",
+            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
+            standard_warn=False,
+        )
+        return self.config.block_out_channels

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
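Assuming a Stable Diffusion checkpoint is available, the new deprecation property can be exercised as in this sketch (the checkpoint name is an assumption):

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# Emits the 1.0.0 deprecation warning, then returns the config value.
channels = vae.block_out_channels
assert channels == vae.config.block_out_channels
```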
src/diffusers/models/unet_1d.py (12 changes: 11 additions & 1 deletion)
@@ -19,7 +19,7 @@
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
- from ..utils import BaseOutput
+ from ..utils import BaseOutput, deprecate
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,6 +190,16 @@ def __init__(
fc_dim=block_out_channels[-1] // 4,
)

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels

Member (on lines +193 to +202):
Doing this via __getattr__ we could deprecate all properties, not just in_channels.

Contributor (Author):
Ah yeah good idea

Contributor (Author):
Ok, I played around with it a bit and don't feel super comfortable overwriting __getattr__, actually. Many Python methods such as hasattr(...) rely on __getattr__, and I didn't manage to escape infinite recursions while keeping nice standard error messages.

Added deprecation properties for all the important direct accesses; I think that should be good enough.

def forward(
self,
sample: torch.FloatTensor,
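For reference, here is a minimal, self-contained sketch of the `__getattr__` approach discussed in the thread above; all names are illustrative, and this is not what was merged:

```python
import warnings
from types import SimpleNamespace


class DeprecatedAttrModel:
    def __init__(self, **kwargs):
        # Config values live only on the config namespace, not the instance.
        self.config = SimpleNamespace(**kwargs)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so `self.config`
        # set in __init__ does not recurse. Note that hasattr() also routes
        # through here, one of the pitfalls mentioned above.
        config = self.__dict__.get("config")
        if config is not None and hasattr(config, name):
            warnings.warn(
                f"Accessing `{name}` directly is deprecated; use `config.{name}`.",
                FutureWarning,
            )
            return getattr(config, name)
        raise AttributeError(f"{type(self).__name__!r} has no attribute {name!r}")


model = DeprecatedAttrModel(in_channels=4, sample_size=64)
assert model.config.in_channels == 4  # preferred access
_ = model.in_channels                 # still works, but warns
```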
src/diffusers/models/unet_2d.py (12 changes: 11 additions & 1 deletion)
@@ -18,7 +18,7 @@
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
- from ..utils import BaseOutput
+ from ..utils import BaseOutput, deprecate
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -215,6 +215,16 @@ def __init__(
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels

def forward(
self,
sample: torch.FloatTensor,