diff --git a/bitmind/generation/models.py b/bitmind/generation/models.py
index 0bc0b7ad..d3afb505 100644
--- a/bitmind/generation/models.py
+++ b/bitmind/generation/models.py
@@ -17,12 +17,15 @@
     AutoPipelineForInpainting,
     CogView4Pipeline,
     CogVideoXImageToVideoPipeline,
+    WanPipeline,
+    AutoencoderKLWan,
 )
 from bitmind.generation.model_registry import ModelRegistry
 from bitmind.generation.util.model import (
     load_hunyuanvideo_transformer,
     load_annimatediff_motion_adapter,
+    load_vae,
     JanusWrapper,
 )
 from bitmind.types import ModelConfig, ModelTask
@@ -255,6 +258,31 @@
     List of text-to-video model configurations
     """
     return [
+        ModelConfig(
+            path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+            task=ModelTask.TEXT_TO_VIDEO,
+            pipeline_cls=WanPipeline,
+            pretrained_args={
+                # NOTE(review): load_vae() executes eagerly every time this
+                # config list is built, fetching the VAE weights even if this
+                # model is never selected — confirm ModelConfig cannot defer it.
+                "vae": load_vae(
+                    vae_cls=AutoencoderKLWan,
+                    model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+                    subfolder="vae",
+                    torch_dtype=torch.float32,
+                ),
+                "torch_dtype": torch.bfloat16,
+            },
+            generate_args={
+                "resolution": [480, 832],
+                "num_frames": 81,
+                "guidance_scale": 5.0,
+            },
+            save_args={"fps": 15},
+            use_autocast=False,
+            tags=["wan2.1"],
+        ),
         ModelConfig(
             path="tencent/HunyuanVideo",
             task=ModelTask.TEXT_TO_VIDEO,
diff --git a/bitmind/generation/util/model.py b/bitmind/generation/util/model.py
index 13acae0a..4156577a 100644
--- a/bitmind/generation/util/model.py
+++ b/bitmind/generation/util/model.py
@@ -14,6 +14,24 @@
 from typing import Any, Dict, Optional
 
 
+def load_vae(vae_cls, model_id: str, subfolder: str, torch_dtype=torch.float32):
+    """Load a VAE of class ``vae_cls`` from a pretrained checkpoint.
+
+    Args:
+        vae_cls: Autoencoder class exposing a ``from_pretrained`` constructor.
+        model_id: HuggingFace repo id or local path of the parent model.
+        subfolder: Subfolder of the repo that holds the VAE weights.
+        torch_dtype: Dtype for the loaded weights (defaults to float32).
+
+    Returns:
+        The instantiated VAE model.
+    """
+    return vae_cls.from_pretrained(
+        model_id,
+        subfolder=subfolder,
+        torch_dtype=torch_dtype,
+    )
+
+
 def load_hunyuanvideo_transformer(
     model_id: str = "tencent/HunyuanVideo",
     subfolder: str = "transformer",