63 changes: 63 additions & 0 deletions docs/source/en/api/pipelines/text_to_video.mdx
@@ -37,9 +37,12 @@ Resources:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis)
| [VideoToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py) | *Text-Guided Video-to-Video Generation* | [(TODO)🤗 Spaces]()

## Usage example

### `text-to-video-ms-1.7b`

Let's start by generating a short video with the default length of 16 frames (2s at 8 fps):
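
A minimal sketch of that default generation, assuming the `damo-vilab/text-to-video-ms-1.7b` checkpoint listed under "Available checkpoints" below, the `export_to_video` utility used in the Zeroscope examples, and a placeholder prompt:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# load the ModelScope text-to-video checkpoint in half precision
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

prompt = "Spiderman is surfing"  # placeholder prompt
# the default call produces 16 frames, i.e. roughly 2 seconds at 8 fps
video_frames = pipe(prompt).frames
video_path = export_to_video(video_frames)
```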

@@ -119,12 +122,72 @@ Here are some sample outputs:
</tr>
</table>

### `cerspense/zeroscope_v2_576w` & `cerspense/zeroscope_v2_XL`

The Zeroscope models are watermark-free and have been trained on specific resolutions such as `576x320` and `1024x576`.
One should first generate a video with the lower-resolution checkpoint [`cerspense/zeroscope_v2_576w`](https://huggingface.co/cerspense/zeroscope_v2_576w) and [`TextToVideoSDPipeline`],
which can then be upscaled using [`VideoToVideoSDPipeline`] and [`cerspense/zeroscope_v2_XL`](https://huggingface.co/cerspense/zeroscope_v2_XL).


```py
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# memory optimization
pipe.enable_vae_slicing()

prompt = "Darth Vader surfing a wave"
video_frames = pipe(prompt, num_frames=24).frames
video_path = export_to_video(video_frames)
video_path
```

Now the video can be upscaled:

```py
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
pipe.vae.enable_slicing()
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]

video_frames = pipe(prompt, video=video, strength=0.6).frames
video_path = export_to_video(video_frames)
video_path
```

Here are some sample outputs:

<table>
<tr>
<td ><center>
Darth Vader surfing in waves.
<br>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/darthvader_cerpense.gif"
alt="Darth vader surfing in waves."
style="width: 576px;" />
</center></td>
</tr>
</table>

## Available checkpoints

* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/)
* [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy)
* [cerspense/zeroscope_v2_576w](https://huggingface.co/cerspense/zeroscope_v2_576w)
* [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL)

## TextToVideoSDPipeline
[[autodoc]] TextToVideoSDPipeline
- all
- __call__

## VideoToVideoSDPipeline
[[autodoc]] VideoToVideoSDPipeline
- all
- __call__
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -173,6 +173,7 @@
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
)

7 changes: 6 additions & 1 deletion src/diffusers/models/autoencoder_kl.py
@@ -229,7 +229,12 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

-       h = self.encoder(x)
+       if self.use_slicing and x.shape[0] > 1:
+           encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+           h = torch.cat(encoded_slices)
+       else:
+           h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

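For reference, a rough sketch of how the new sliced-encode path is exercised; the checkpoint name below is only an assumption for illustration, and any `AutoencoderKL` behaves the same once `enable_slicing()` has been called:

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
vae.enable_slicing()  # sets use_slicing=True

# with slicing enabled, this batch of 4 images is encoded one sample at a time,
# trading a bit of speed for a lower peak memory footprint
frames = torch.randn(4, 3, 512, 512, dtype=torch.float16, device="cuda")
latents = vae.encode(frames).latent_dist.sample()
```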
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/__init__.py
@@ -89,7 +89,7 @@
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
-from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
+from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
from .versatile_diffusion import (
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/text_to_video_synthesis/__init__.py
@@ -28,5 +28,6 @@ class TextToVideoSDPipelineOutput(BaseOutput):
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
-    from .pipeline_text_to_video_synth import TextToVideoSDPipeline  # noqa: F401
+    from .pipeline_text_to_video_synth import TextToVideoSDPipeline
+    from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline  # noqa: F401
    from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
@@ -672,6 +672,9 @@ def __call__(
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

        if output_type == "latent":
            return TextToVideoSDPipelineOutput(frames=latents)

        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
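A rough usage sketch of the new early return, assuming `TextToVideoSDPipeline` is loaded as in the docs above; with `output_type="latent"` the call skips VAE decoding and returns the raw latents in `frames`:

```py
import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# returns undecoded latents instead of numpy video frames,
# e.g. to hand them to a follow-up video-to-video upscaling stage
latents = pipe("Darth Vader surfing a wave", output_type="latent").frames
```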