
feat: implement rae autoencoder. #13046

Merged
sayakpaul merged 71 commits into huggingface:main from Ando233:rae
Mar 5, 2026

Conversation

@Ando233
Contributor

@Ando233 Ando233 commented Jan 28, 2026

What does this PR do?

This PR adds a new representation autoencoder implementation, AutoencoderRAE, to diffusers:

  • Implements diffusers.models.autoencoders.autoencoder_rae.AutoencoderRAE with a frozen pretrained vision encoder (DINOv2 / SigLIP2 / ViT-MAE) and a ViT-MAE-style decoder.
  • Aligns the decoder with the RAE-main GeneralDecoder parameter structure, so existing trained decoder checkpoints (e.g. model.pt) load without key mismatches when encoder/decoder settings are consistent.
  • Adds unit/integration tests under diffusers/tests/models/autoencoders/test_models_autoencoder_rae.py.
  • Registers exports so users can import directly via from diffusers import AutoencoderRAE.

Fixes #13000

Usage

import torch
from diffusers import AutoencoderRAE

# Assumes encoder_path, image_size, patch_size, num_patches, device,
# args.decoder_ckpt, and an input batch x are defined.
ae = AutoencoderRAE(
    encoder_cls="dinov2",
    encoder_name_or_path=encoder_path,
    image_size=image_size,
    encoder_input_size=image_size,
    patch_size=patch_size,
    num_patches=num_patches,
    decoder_hidden_size=1152,
    decoder_num_hidden_layers=28,
    decoder_num_attention_heads=16,
    decoder_intermediate_size=4096,
).to(device)
ae.eval()

# Load a trained decoder checkpoint; strict=False tolerates key mismatches.
state = torch.load(args.decoder_ckpt, map_location="cpu")
ae.decoder.load_state_dict(state, strict=False)

with torch.no_grad():
    recon = ae(x).sample
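Since the decoder checkpoint is loaded with strict=False, mismatched keys are silently ignored; a quick way to surface them is to inspect the named tuple that load_state_dict returns. A minimal sketch with a toy module standing in for ae.decoder (the module and key names are illustrative):

```python
import torch
from torch import nn

# Toy stand-in for the decoder; the real module is ae.decoder.
decoder = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))

# A checkpoint with one key missing and one extra key.
state = {k: v.clone() for k, v in decoder.state_dict().items()}
del state["0.bias"]                     # will show up in missing_keys
state["extra.weight"] = torch.zeros(1)  # will show up in unexpected_keys

result = decoder.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['0.bias']
print(result.unexpected_keys)  # ['extra.weight']
```

Logging these two lists after loading makes it obvious whether the encoder/decoder settings actually match the checkpoint.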

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul sayakpaul requested a review from kashif January 30, 2026 11:31
@sayakpaul
Member

@bytetriper if you could take a look?

@kashif
Contributor

kashif commented Jan 30, 2026

Nice work @Ando233, checking.

@kashif
Contributor

kashif commented Jan 30, 2026

Off the bat:

  • let's have a nice convention for the output datatype classes; have a look at the other autoencoders for the convention in diffusers
  • some of the tests might need to be marked as slow, and some paths are hard-coded

Let's sort out these things and then take another look.

@bytetriper

Agree with @kashif. Also, if possible we can bake all the params into the config so we can enable .from_pretrained(), which is more elegant and aligns with diffusers usage. I can help convert our released checkpoint to the Hugging Face format afterwards.
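The config-baking suggested here follows the usual diffusers pattern: every constructor argument is recorded in a config.json so the model can be re-instantiated without passing arguments by hand. In diffusers this comes from ConfigMixin and register_to_config; the underlying idea can be sketched with the stdlib alone (the class and field names below are illustrative, not the actual diffusers API):

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

# Illustrative stand-in for diffusers' ConfigMixin pattern; not the real API.
@dataclass
class RAEConfig:
    encoder_cls: str = "dinov2"
    decoder_hidden_size: int = 1152
    decoder_num_hidden_layers: int = 28

class TinyRAE:
    def __init__(self, config: RAEConfig):
        self.config = config

    def save_pretrained(self, save_dir: str) -> None:
        # Bake every constructor argument into config.json alongside the weights.
        path = Path(save_dir)
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(json.dumps(asdict(self.config)))

    @classmethod
    def from_pretrained(cls, save_dir: str) -> "TinyRAE":
        # Rebuild the model purely from the serialized config.
        data = json.loads((Path(save_dir) / "config.json").read_text())
        return cls(RAEConfig(**data))

save_dir = tempfile.mkdtemp()
TinyRAE(RAEConfig(decoder_hidden_size=768)).save_pretrained(save_dir)
reloaded = TinyRAE.from_pretrained(save_dir)
print(reloaded.config.decoder_hidden_size)  # 768
```

With the real AutoencoderRAE, registering all constructor arguments through ConfigMixin would make AutoencoderRAE.from_pretrained(...) work against a converted Hub checkpoint with no extra arguments.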

@sayakpaul
Member

@Ando233 we're happy to provide assistance if needed.

@kashif
Contributor

kashif commented Feb 15, 2026

@Ando233 the one remaining thing is the use of use_encoder_loss, and perhaps an example real-world training script

@kashif
Contributor

kashif commented Feb 15, 2026

@bytetriper could you kindly try to run the conversion scripts and upload the diffusers style weights to your huggingface hub for the checkpoints you have?

@kashif
Contributor

kashif commented Mar 4, 2026

@dg845 resolved the issues, thanks

@sayakpaul
Member

@Ando233 / @kashif could you merge Ando233#1?

self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
- self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
+ self.register_buffer(
+     "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size), persistent=False
+ )
Member

What happens if we directly initialize this to what we're doing in the initialize_weights() function? Could we get rid of the explicit device placement in return x_rec.to(device=z.device) then?

Member

@sayakpaul sayakpaul Mar 5, 2026

I guess a simpler solution would be to directly assign the pos_embed value (we are initializing through initialize_weights()) and just persist it in the state dict. That way, we can skip explicit device placements like return x_rec.to(device=z.device)?

This would require opening PRs to the RAE repos on the Hub, though.

@dg845 WDYT?

Collaborator

I think something like

class RAEDecoder(nn.Module):
    ...
    def __init__(...):
        ...
        grid_size = int(num_patches**0.5)
        pos_embed = get_2d_sincos_pos_embed(
            decoder_hidden_size,
            grid_size,
            cls_token=True,
            extra_tokens=1,
            output_type="pt",
        )
        self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0), persistent=False)
        ...

is generally how we would implement this, see e.g. here:

self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

I think this should handle device placement automatically and also avoid the need to change the Hub repo (although maybe the conversion script might need to be changed?).
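A quick check of what persistent=False buys here: the buffer follows the module across .to(...) like any other buffer, so no manual device placement is needed in forward, but it never enters the state dict, so Hub checkpoints don't need to carry it. A minimal sketch with a toy module (names are illustrative, not the RAEDecoder itself):

```python
import torch
from torch import nn

class WithPosEmbed(nn.Module):
    def __init__(self):
        super().__init__()
        # Deterministic buffer computed at init, like the sincos pos embed above.
        self.register_buffer("pos_embed", torch.arange(4.0), persistent=False)
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        # Buffer lives on the module's device/dtype; no explicit .to() needed.
        return self.proj(x + self.pos_embed)

m = WithPosEmbed()
print("pos_embed" in m.state_dict())  # False: excluded from serialization
print(m(torch.zeros(2, 4)).shape)     # torch.Size([2, 4])
# A checkpoint without the buffer loads cleanly even with strict=True:
m.load_state_dict(dict(m.state_dict()), strict=True)
```

Because the buffer is recomputed at init rather than loaded, the existing Hub weights stay valid; only a conversion script that expected decoder_pos_embed in the checkpoint would need updating.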

Collaborator

For test_model_parallelism specifically, the error I encountered earlier in #13046 (comment) may be the result of an accelerate bug where the decoder_pos_embed buffer doesn't end up on the device_map, so AutoencoderRAE.from_pretrained(..., device_map="auto") doesn't know where to put it and gives that error. I've opened an issue for this at huggingface/accelerate#3956.

Member

@kashif WDYT? Could we make this change?

Contributor

Yes, I can do it.

Member

Thanks! Then let's prefer this solution.

@sayakpaul sayakpaul merged commit 8ec0a5c into huggingface:main Mar 5, 2026
27 of 28 checks passed
@sayakpaul
Member

Thanks a lot @kashif for shipping RAEs!
