diff --git a/bitmind/__init__.py b/bitmind/__init__.py index f65f125c..f2841d24 100644 --- a/bitmind/__init__.py +++ b/bitmind/__init__.py @@ -18,7 +18,7 @@ # DEALINGS IN THE SOFTWARE. -__version__ = "2.2.4" +__version__ = "2.2.6" version_split = __version__.split(".") __spec_version__ = ( (1000 * int(version_split[0])) diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py index b910aa47..19b60607 100644 --- a/bitmind/synthetic_data_generation/synthetic_data_generator.py +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -418,6 +418,15 @@ def load_model(self, model_name: Optional[str] = None, modality: Optional[str] = **pipeline_args, add_watermarker=False ) + + # Load LoRA weights if specified + if 'lora_model_id' in model_config: + bt.logging.info(f"Loading LoRA weights from {model_config['lora_model_id']}") + lora_loading_args = model_config.get('lora_loading_args', {}) + self.model.load_lora_weights( + model_config['lora_model_id'], + **lora_loading_args + ) # Load scheduler if specified if 'scheduler' in model_config: diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py index f7ba906e..9d7e3333 100644 --- a/bitmind/validator/config.py +++ b/bitmind/validator/config.py @@ -115,10 +115,13 @@ class Modality(StrEnum): {"path": "bitmind/caltech-256"}, {"path": "bitmind/caltech-101"}, {"path": "bitmind/dtd"} - ], "semisynthetic": [ {"path": "bitmind/face-swap"} + ], + "synthetic": [ + {"path": "bitmind/JourneyDB"}, + {"path": "bitmind/GenImage_MidJourney"} ] } @@ -178,6 +181,19 @@ class Modality(StrEnum): }, "enable_model_cpu_offload": False }, + "runwayml/stable-diffusion-v1-5-midjourney-v6": { + "pipeline_cls": StableDiffusionPipeline, + "from_pretrained_args": { + "model_id": "runwayml/stable-diffusion-v1-5", + "use_safetensors": True, + "torch_dtype": torch.float16, + }, + "lora_model_id": "Kvikontent/midjourney-v6", + "lora_loading_args": { + "use_peft_backend": True + }, + "enable_model_cpu_offload": False + }, "prompthero/openjourney-v4" : { "pipeline_cls": StableDiffusionPipeline, "from_pretrained_args": { @@ -442,4 +458,4 @@ def select_random_model(task: Optional[str] = None) -> str: elif task == 'i2i': return np.random.choice(I2I_MODEL_NAMES) else: - raise NotImplementedError(f"Unsupported task: {task}") + raise NotImplementedError(f"Unsupported task: {task}") \ No newline at end of file diff --git a/bitmind/validator/scripts/run_cache_updater.py b/bitmind/validator/scripts/run_cache_updater.py index 1d4d1be3..1747b3a6 100644 --- a/bitmind/validator/scripts/run_cache_updater.py +++ b/bitmind/validator/scripts/run_cache_updater.py @@ -17,6 +17,7 @@ VIDEO_ZIP_CACHE_UPDATE_INTERVAL, REAL_VIDEO_CACHE_DIR, REAL_IMAGE_CACHE_DIR, + SYNTH_IMAGE_CACHE_DIR, SEMISYNTH_VIDEO_CACHE_DIR, SEMISYNTH_IMAGE_CACHE_DIR, MAX_COMPRESSED_GB, @@ -52,21 +53,21 @@ async def main(args): max_compressed_size_gb=MAX_COMPRESSED_GB ) semisynth_image_cache.start_updater() - - if args.modality in ['all', 'video']: - bt.logging.info("Starting semisynthetic video cache updater") - semisynth_video_cache = VideoCache( - cache_dir=SEMISYNTH_VIDEO_CACHE_DIR, - datasets=VIDEO_DATASETS['semisynthetic'], - video_update_interval=args.video_interval, - zip_update_interval=args.video_zip_interval, - num_zips_per_dataset=2, - num_videos_per_zip=100, + + bt.logging.info("Starting synthetic image cache updater") + synth_image_cache = ImageCache( + cache_dir=SYNTH_IMAGE_CACHE_DIR, + datasets=IMAGE_DATASETS['synthetic'], + parquet_update_interval=args.image_parquet_interval, + image_update_interval=args.image_interval, + num_parquets_per_dataset=5, + num_images_per_source=100, max_extracted_size_gb=MAX_EXTRACTED_GB, max_compressed_size_gb=MAX_COMPRESSED_GB ) - semisynth_video_cache.start_updater() - + synth_image_cache.start_updater() + + if args.modality in ['all', 'video']: bt.logging.info("Starting real video cache updater") real_video_cache = VideoCache( cache_dir=REAL_VIDEO_CACHE_DIR, @@ -79,6 +80,19 @@ async def main(args): max_compressed_size_gb=100, ) real_video_cache.start_updater() + + bt.logging.info("Starting semisynthetic video cache updater") + semisynth_video_cache = VideoCache( + cache_dir=SEMISYNTH_VIDEO_CACHE_DIR, + datasets=VIDEO_DATASETS['semisynthetic'], + video_update_interval=args.video_interval, + zip_update_interval=args.video_zip_interval, + num_zips_per_dataset=2, + num_videos_per_zip=100, + max_extracted_size_gb=MAX_EXTRACTED_GB, + max_compressed_size_gb=MAX_COMPRESSED_GB + ) + semisynth_video_cache.start_updater() while True: bt.logging.info(f"Running cache updaters for: {args.modality}") diff --git a/requirements.txt b/requirements.txt index 50aaf9ab..eb68fdf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ timm==1.0.12 einops==0.8.0 ultralytics==8.3.44 janus @ git+https://github.com/deepseek-ai/Janus.git +peft==0.15.0 # Image/Video processing datasets==3.1.0