diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py index b910aa47..19b60607 100644 --- a/bitmind/synthetic_data_generation/synthetic_data_generator.py +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -418,6 +418,15 @@ def load_model(self, model_name: Optional[str] = None, modality: Optional[str] = **pipeline_args, add_watermarker=False ) + + # Load LoRA weights if specified + if 'lora_model_id' in model_config: + bt.logging.info(f"Loading LoRA weights from {model_config['lora_model_id']}") + lora_loading_args = model_config.get('lora_loading_args', {}) + self.model.load_lora_weights( + model_config['lora_model_id'], + **lora_loading_args + ) # Load scheduler if specified if 'scheduler' in model_config: diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py index f7ba906e..40a9525e 100644 --- a/bitmind/validator/config.py +++ b/bitmind/validator/config.py @@ -178,6 +178,19 @@ class Modality(StrEnum): }, "enable_model_cpu_offload": False }, + "runwayml/stable-diffusion-v1-5-midjourney-v6": { + "pipeline_cls": StableDiffusionPipeline, + "from_pretrained_args": { + "model_id": "runwayml/stable-diffusion-v1-5", + "use_safetensors": True, + "torch_dtype": torch.float16, + }, + "lora_model_id": "Kvikontent/midjourney-v6", + "lora_loading_args": { + "use_peft_backend": True + }, + "enable_model_cpu_offload": False + }, "prompthero/openjourney-v4" : { "pipeline_cls": StableDiffusionPipeline, "from_pretrained_args": { @@ -442,4 +455,4 @@ def select_random_model(task: Optional[str] = None) -> str: elif task == 'i2i': return np.random.choice(I2I_MODEL_NAMES) else: - raise NotImplementedError(f"Unsupported task: {task}") + raise NotImplementedError(f"Unsupported task: {task}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 50aaf9ab..eb68fdf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ timm==1.0.12 einops==0.8.0 ultralytics==8.3.44 janus @ git+https://github.com/deepseek-ai/Janus.git +peft==0.15.0 # Image/Video processing datasets==3.1.0