Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 38 additions & 49 deletions src/diffusers/models/autoencoders/autoencoder_rae.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class RAEDecoder(nn.Module):
"""
Decoder implementation ported from RAE-main to keep checkpoint compatibility.

Key attributes (must match checkpoint keys):
- decoder_embed
- decoder_pos_embed
- decoder_layers
- decoder_norm
- decoder_pred
- trainable_cls_token
"""
"""Lightweight RAE decoder."""

def __init__(
self,
Expand Down Expand Up @@ -291,7 +281,7 @@ def __init__(
def _initialize_weights(self, num_patches: int):
# Skip initialization when parameters are on meta device (e.g. during
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
# The weights will be loaded from the checkpoint afterwards.
# The weights are initialized.
if self.decoder_pos_embed.device.type == "meta":
return

Expand Down Expand Up @@ -487,6 +477,31 @@ def __init__(
f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
)

if encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)

decoder_patch_size = patch_size
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")

num_patches = (encoder_input_size // encoder_patch_size) ** 2
grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")

derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)

def _to_config_compatible(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
Expand All @@ -511,21 +526,6 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens
latents_std=_to_config_compatible(latents_std),
)

self.encoder_input_size = encoder_input_size
self.noise_tau = float(noise_tau)
self.reshape_to_2d = bool(reshape_to_2d)
self.use_encoder_loss = bool(use_encoder_loss)

# Validate early, before building the (potentially large) encoder/decoder.
encoder_patch_size = int(encoder_patch_size)
if self.encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)
decoder_patch_size = int(patch_size)
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")

# Frozen representation encoder (built from config, no downloads)
self.encoder: nn.Module = _build_encoder(
encoder_type=encoder_type,
Expand All @@ -534,22 +534,7 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens
num_hidden_layers=encoder_num_hidden_layers,
)
self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
num_patches = (self.encoder_input_size // encoder_patch_size) ** 2

grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")

derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)
num_patches = (encoder_input_size // encoder_patch_size) ** 2

# Encoder input normalization stats (ImageNet defaults)
if encoder_norm_mean is None:
Expand Down Expand Up @@ -584,6 +569,7 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens
num_channels=int(num_channels),
image_size=int(image_size),
)

self.num_patches = int(num_patches)
self.decoder_patch_size = int(decoder_patch_size)
self.decoder_image_size = int(image_size)
Expand All @@ -593,16 +579,19 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens

def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    """Add Gaussian noise with a per-sample random sigma to the latent tokens.

    For each sample in the batch, a sigma is drawn uniformly from
    [0, config.noise_tau]; the broadcast shape (B, 1, ..., 1) applies one
    sigma across all non-batch dimensions of that sample.

    Args:
        x: Latent tensor of shape (B, ...).
        generator: Optional RNG for reproducible sigma draws and noise.

    Returns:
        `x` perturbed by `sigma * N(0, I)` noise, same shape/dtype/device as `x`.
    """
    # NOTE: the diff paste contained both the old `self.noise_tau` line and the
    # new `self.config.noise_tau` line; only the config-backed version is kept,
    # consistent with the other `self.config.*` migrations in this change set.
    broadcast_shape = (x.size(0),) + (1,) * (x.ndim - 1)
    noise_sigma = self.config.noise_tau * torch.rand(
        broadcast_shape, device=x.device, dtype=x.dtype, generator=generator
    )
    return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)

def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
_, _, h, w = x.shape
if h != self.encoder_input_size or w != self.encoder_input_size:
if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
x = F.interpolate(
x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False
x,
size=(self.config.encoder_input_size, self.config.encoder_input_size),
mode="bicubic",
align_corners=False,
)
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
Expand Down Expand Up @@ -631,10 +620,10 @@ def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) ->
else:
tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C)

if self.training and self.noise_tau > 0:
if self.training and self.config.noise_tau > 0:
tokens = self._noising(tokens, generator=generator)

if self.reshape_to_2d:
if self.config.reshape_to_2d:
b, n, c = tokens.shape
side = int(sqrt(n))
if side * side != n:
Expand Down Expand Up @@ -671,7 +660,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor:

z = self._denormalize_latents(z)

if self.reshape_to_2d:
if self.config.reshape_to_2d:
b, c, h, w = z.shape
tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C)
else:
Expand Down