diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 3c6e821b7770..13eb4f745f70 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -226,17 +226,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class RAEDecoder(nn.Module): - """ - Decoder implementation ported from RAE-main to keep checkpoint compatibility. - - Key attributes (must match checkpoint keys): - - decoder_embed - - decoder_pos_embed - - decoder_layers - - decoder_norm - - decoder_pred - - trainable_cls_token - """ + """Lightweight RAE decoder; attribute names must match RAE checkpoint keys.""" def __init__( self, @@ -291,7 +281,7 @@ def __init__( def _initialize_weights(self, num_patches: int): # Skip initialization when parameters are on meta device (e.g. during # accelerate.init_empty_weights() used by low_cpu_mem_usage loading). - # The weights will be loaded from the checkpoint afterwards. + # Weights are materialized later, when the checkpoint state dict is loaded. if self.decoder_pos_embed.device.type == "meta": return @@ -487,6 +477,31 @@ def __init__( f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}" ) + if encoder_input_size % encoder_patch_size != 0: + raise ValueError( + f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." 
+ ) + + decoder_patch_size = patch_size + if decoder_patch_size <= 0: + raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") + + num_patches = (encoder_input_size // encoder_patch_size) ** 2 + grid = int(sqrt(num_patches)) + if grid * grid != num_patches: + raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") + + derived_image_size = decoder_patch_size * grid + if image_size is None: + image_size = derived_image_size + else: + image_size = int(image_size) + if image_size != derived_image_size: + raise ValueError( + f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " + f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." + ) + def _to_config_compatible(value: Any) -> Any: if isinstance(value, torch.Tensor): return value.detach().cpu().tolist() @@ -511,21 +526,6 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens latents_std=_to_config_compatible(latents_std), ) - self.encoder_input_size = encoder_input_size - self.noise_tau = float(noise_tau) - self.reshape_to_2d = bool(reshape_to_2d) - self.use_encoder_loss = bool(use_encoder_loss) - - # Validate early, before building the (potentially large) encoder/decoder. - encoder_patch_size = int(encoder_patch_size) - if self.encoder_input_size % encoder_patch_size != 0: - raise ValueError( - f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." 
- ) - decoder_patch_size = int(patch_size) - if decoder_patch_size <= 0: - raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") - # Frozen representation encoder (built from config, no downloads) self.encoder: nn.Module = _build_encoder( encoder_type=encoder_type, @@ -534,22 +534,7 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens num_hidden_layers=encoder_num_hidden_layers, ) self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type] - num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 - - grid = int(sqrt(num_patches)) - if grid * grid != num_patches: - raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") - - derived_image_size = decoder_patch_size * grid - if image_size is None: - image_size = derived_image_size - else: - image_size = int(image_size) - if image_size != derived_image_size: - raise ValueError( - f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " - f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." 
- ) + num_patches = (encoder_input_size // encoder_patch_size) ** 2 # Encoder input normalization stats (ImageNet defaults) if encoder_norm_mean is None: @@ -584,6 +569,7 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens num_channels=int(num_channels), image_size=int(image_size), ) + self.num_patches = int(num_patches) self.decoder_patch_size = int(decoder_patch_size) self.decoder_image_size = int(image_size) @@ -593,16 +579,19 @@ def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tens def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: # Per-sample random sigma in [0, noise_tau] - noise_sigma = self.noise_tau * torch.rand( + noise_sigma = self.config.noise_tau * torch.rand( (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator ) return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype) def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor: _, _, h, w = x.shape - if h != self.encoder_input_size or w != self.encoder_input_size: + if h != self.config.encoder_input_size or w != self.config.encoder_input_size: x = F.interpolate( - x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False + x, + size=(self.config.encoder_input_size, self.config.encoder_input_size), + mode="bicubic", + align_corners=False, ) mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) std = self.encoder_std.to(device=x.device, dtype=x.dtype) @@ -631,10 +620,10 @@ def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> else: tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C) - if self.training and self.noise_tau > 0: + if self.training and self.config.noise_tau > 0: tokens = self._noising(tokens, generator=generator) - if self.reshape_to_2d: + if self.config.reshape_to_2d: b, n, c = tokens.shape side = int(sqrt(n)) if side * side 
!= n: @@ -671,7 +660,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor: z = self._denormalize_latents(z) - if self.reshape_to_2d: + if self.config.reshape_to_2d: b, c, h, w = z.shape tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C) else: