Is your feature request related to a problem? Please describe.
Currently there are many casts to float32 so that the VAE can run without producing zeroed-out values in float16 mode. When the pipeline is in float16 mode, or when dtype=torch.bfloat16 is requested, we should run the VAE in torch.bfloat16 if the environment supports it. This would reduce VRAM usage by approximately half (relative to the float32 upcast) and in some cases improve performance.
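For context, this is roughly the upcast path that pipelines take today before decoding (paraphrased from the SDXL pipeline's decode step; the exact attribute names may differ slightly):

# Current behaviour (paraphrased): upcast the whole VAE to float32 before
# decoding, because the SDXL VAE overflows in float16.
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
    self.upcast_vae()  # casts the VAE back to float32
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]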
Describe the solution you'd like
# ...somewhere, in a pipeline...
if self.vae.dtype == torch.float16 and torch.cuda.is_bf16_supported():
    self.vae.to(dtype=torch.bfloat16)

# Make sure the VAE is in float32 mode, as it overflows in float16.
# We don't need to do the upcasting and float32 dance if we have
# access to bfloat16, in which case we can just directly use the recast
# latents.
if self.vae.dtype == torch.bfloat16:
    latents = latents.to(dtype=torch.bfloat16)
elif self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
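As a self-contained illustration of the proposed behaviour with a standalone VAE (the model id, latent shape, and device handling below are only illustrative assumptions, not part of the proposal):

import torch
from diffusers import AutoencoderKL

# Load a float16 VAE; "stabilityai/sdxl-vae" and the latent shape are just examples.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sdxl-vae", torch_dtype=torch.float16
).to("cuda")
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.float16)

if torch.cuda.is_bf16_supported():
    # Proposed path: decode in bfloat16, no float32 upcast needed.
    vae.to(dtype=torch.bfloat16)
    latents = latents.to(dtype=torch.bfloat16)
else:
    # Existing path: upcast everything to float32 because float16 overflows.
    vae.to(dtype=torch.float32)
    latents = latents.to(dtype=torch.float32)

with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]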