Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bitmind/synthetic_data_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .synthetic_data_generator import SyntheticDataGenerator
from .prompt_generator import PromptGenerator
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
disable_progress_bar()


class ImageAnnotationGenerator:
class PromptGenerator:
"""
A class for generating and moderating image annotations using transformer models.

Expand All @@ -31,10 +31,9 @@ class ImageAnnotationGenerator:

def __init__(
self,
model_name: str,
text_moderation_model_name: str,
vlm_name: str,
llm_name: str,
device: str = 'cuda',
apply_moderation: bool = True
) -> None:
"""
Initialize the ImageAnnotationGenerator with specific models and device settings.
Expand All @@ -47,126 +46,74 @@ def __init__(
apply_moderation: Flag to determine whether text moderation should be
applied to captions.
"""
self.model_name = model_name
self.processor = Blip2Processor.from_pretrained(
self.model_name,
cache_dir=HUGGINGFACE_CACHE_DIR
)

self.apply_moderation = apply_moderation
self.text_moderation_model_name = text_moderation_model_name
self.text_moderation_pipeline = None
self.model = None
self.vlm_name = vlm_name
self.llm_name = llm_name
self.vlm_processor = None
self.vlm = None
self.llm_pipeline = None
self.device = device

def is_model_loaded(self) -> bool:
return self.model is not None
def are_models_loaded(self) -> bool:
    """Report whether both the captioning VLM and the moderation LLM pipeline are loaded."""
    return all(component is not None for component in (self.vlm, self.llm_pipeline))

def load_models(self) -> None:
"""
Load the necessary models for image annotation and text moderation onto
the specified device.
"""
if self.is_model_loaded():
bt.logging.warning(
f"Image annotation model {self.model_name} is already loaded"
)
if self.are_models_loaded():
bt.logging.warning(f"Models already loaded")
return

bt.logging.info(f"Loading image annotation model {self.model_name}")
self.model = Blip2ForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
bt.logging.info(f"Loading caption generation model {self.vlm_name}")
self.vlm_processor = Blip2Processor.from_pretrained(
self.vlm_name,
cache_dir=HUGGINGFACE_CACHE_DIR
)
self.model.to(self.device)
bt.logging.info(f"Loaded image annotation model {self.model_name}")
bt.logging.info(
f"Loading annotation moderation model {self.text_moderation_model_name}..."
self.vlm = Blip2ForConditionalGeneration.from_pretrained(
self.vlm_name,
torch_dtype=torch.float16,
cache_dir=HUGGINGFACE_CACHE_DIR
)
if self.apply_moderation:
model = AutoModelForCausalLM.from_pretrained(
self.text_moderation_model_name,
torch_dtype=torch.bfloat16,
cache_dir=HUGGINGFACE_CACHE_DIR
)
self.vlm.to(self.device)
bt.logging.info(f"Loaded image annotation model {self.vlm_name}")

tokenizer = AutoTokenizer.from_pretrained(
self.text_moderation_model_name,
cache_dir=HUGGINGFACE_CACHE_DIR
)
model = model.to(self.device)
self.text_moderation_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer
)
bt.logging.info(
f"Loaded annotation moderation model {self.text_moderation_model_name}."
bt.logging.info(f"Loading caption moderation model {self.llm_name}")
llm = AutoModelForCausalLM.from_pretrained(
self.llm_name,
torch_dtype=torch.bfloat16,
cache_dir=HUGGINGFACE_CACHE_DIR
)
tokenizer = AutoTokenizer.from_pretrained(
self.llm_name,
cache_dir=HUGGINGFACE_CACHE_DIR
)
llm = llm.to(self.device)
self.llm_pipeline = pipeline(
"text-generation",
model=llm,
tokenizer=tokenizer
)
bt.logging.info(f"Loaded caption moderation model {self.llm_name}")

def clear_gpu(self) -> None:
"""
Clear GPU memory by moving models back to CPU and deleting them,
followed by collecting garbage.
"""
bt.logging.info("Clearing GPU memory after generating image annotation")
self.model.to('cpu')
del self.model
self.model = None
if self.text_moderation_pipeline:
self.text_moderation_pipeline.model.to('cpu')
del self.text_moderation_pipeline
self.text_moderation_pipeline = None
gc.collect()
torch.cuda.empty_cache()

def moderate(self, description: str, max_new_tokens: int = 80) -> str:
"""
Use the text moderation pipeline to make the description more concise
and neutral.

Args:
description: The text description to be moderated.
max_new_tokens: Maximum number of new tokens to generate in the
moderated text.

Returns:
The moderated description text, or the original description if
moderation fails.
"""
messages = [
{
"role": "system",
"content": (
"[INST]You always concisely rephrase given descriptions, "
"eliminate redundancy, and remove all specific references to "
"individuals by name. You do not respond with anything other "
"than the revised description.[/INST]"
)
},
{
"role": "user",
"content": description
}
]
try:
moderated_text = self.text_moderation_pipeline(
messages,
max_new_tokens=max_new_tokens,
pad_token_id=self.text_moderation_pipeline.tokenizer.eos_token_id,
return_full_text=False
)
bt.logging.info("Clearing GPU memory after prompt generation")
if self.vlm:
self.vlm.to('cpu')
del self.vlm
self.vlm = None

if isinstance(moderated_text, list):
return moderated_text[0]['generated_text']
if self.llm_pipeline:
self.llm_pipeline.model.to('cpu')
del self.llm_pipeline
self.llm_pipeline = None

bt.logging.error("Moderated text did not return a list.")
return description

except Exception as e:
bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True)
return description
gc.collect()
torch.cuda.empty_cache()

def generate(
self,
Expand Down Expand Up @@ -200,17 +147,17 @@ def generate(

for i, prompt in enumerate(prompts):
description += prompt + ' '
inputs = self.processor(
inputs = self.vlm_processor(
image,
text=description,
return_tensors="pt"
).to(self.device, torch.float16)

generated_ids = self.model.generate(
generated_ids = self.vlm.generate(
**inputs,
max_new_tokens=max_new_tokens
)
answer = self.processor.batch_decode(
answer = self.vlm_processor.batch_decode(
generated_ids,
skip_special_tokens=True
)[0].strip()
Expand All @@ -237,8 +184,102 @@ def generate(
if not description.endswith('.'):
description += '.'

if self.apply_moderation:
moderated_description = self.moderate(description)
return moderated_description
moderated_description = self.moderate(description)
enhanced_description = self.enhance(description)
return enhanced_description

return description
def moderate(self, description: str, max_new_tokens: int = 80) -> str:
    """
    Rephrase a description into a more concise, neutral form via the LLM pipeline.

    Args:
        description: The text description to be moderated.
        max_new_tokens: Maximum number of new tokens to generate in the
            moderated text.

    Returns:
        The moderated description text, or the original description if
        moderation fails.
    """
    system_instruction = (
        "[INST]You always concisely rephrase given descriptions, "
        "eliminate redundancy, and remove all specific references to "
        "individuals by name. You do not respond with anything other "
        "than the revised description.[/INST]"
    )
    chat = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": description},
    ]
    try:
        # return_full_text=False yields only the newly generated continuation.
        outputs = self.llm_pipeline(
            chat,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
            return_full_text=False
        )
        return outputs[0]['generated_text']
    except Exception as e:
        # Best-effort: on any pipeline failure fall back to the unmoderated text.
        bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True)
        return description

def enhance(self, description: str, max_new_tokens: int = 80) -> str:
    """
    Enhance a static image description to make it suitable for video generation
    by adding dynamic elements and motion.

    Args:
        description: The static image description to enhance.
        max_new_tokens: Maximum number of new tokens to generate in the enhanced text.

    Returns:
        An enhanced description suitable for video generation, or the original
        description if enhancement fails.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "[INST]You are an expert at converting static image descriptions "
                "into dynamic video prompts. Enhance the given description by "
                "adding natural motion and temporal elements while preserving the "
                "core scene. Follow these rules:\n"
                "1. Maintain the essential elements of the original description\n"
                "2. Add smooth, continuous motions that work well in video\n"
                "3. For portraits: Add natural facial movements or expressions\n"
                "4. For non-portrait images with people: Add contextually appropriate "
                "actions (e.g., for a beach scene, people might be walking along "
                "the shoreline or playing in the waves; for a cafe scene, people "
                "might be sipping drinks or engaging in conversation)\n"
                "5. For landscapes: Add environmental motion like wind or water\n"
                "6. For urban scenes: Add dynamic elements like people or traffic\n"
                "7. Keep the description concise but descriptive\n"
                "8. Focus on gradual, natural transitions\n"
                "Only respond with the enhanced description.[/INST]"
            )
        },
        {
            "role": "user",
            "content": description
        }
    ]

    try:
        # return_full_text=False yields only the newly generated continuation.
        enhanced_text = self.llm_pipeline(
            messages,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
            return_full_text=False
        )
        return enhanced_text[0]['generated_text']

    except Exception as e:
        # Fix: route errors through the project logger (consistent with
        # moderate()) instead of a bare print(), and keep the traceback.
        bt.logging.error(f"An error occurred during motion enhancement: {e}", exc_info=True)
        return description
27 changes: 13 additions & 14 deletions bitmind/synthetic_data_generation/synthetic_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_modality
)
from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long
from bitmind.synthetic_data_generation.image_annotation_generator import ImageAnnotationGenerator
from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator
from bitmind.validator.cache import ImageCache


Expand Down Expand Up @@ -59,7 +59,7 @@ class SyntheticDataGenerator:
prompt_type: The type of prompt generation strategy ('random', 'annotation').
prompt_generator_name: Name of the prompt generation model.
t2vis_model_name: Name of the t2v or t2i model.
image_annotation_generator: The generator object for annotating images if required.
prompt_generator: The vlm/llm pipeline for generating input prompts for t2i/t2v models
output_dir: Directory to write generated data.
"""

Expand Down Expand Up @@ -106,20 +106,20 @@ def __init__(
self.t2vis_model_name = None

self.prompt_type = prompt_type
if self.prompt_type == 'annotation':
self.image_annotation_generator = ImageAnnotationGenerator(
model_name=IMAGE_ANNOTATION_MODEL,
text_moderation_model_name=TEXT_MODERATION_MODEL
)
else:
raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}")
self.image_cache = image_cache
if self.prompt_type == 'annotation' and self.image_cache is None:
raise ValueError(f"image_cache cannot be None if prompt_type == 'annotation'")

self.prompt_generator = PromptGenerator(
vlm_name=IMAGE_ANNOTATION_MODEL,
llm_name=TEXT_MODERATION_MODEL
)

self.output_dir = Path(output_dir) if output_dir else None
if self.output_dir:
(self.output_dir / "video").mkdir(parents=True, exist_ok=True)
(self.output_dir / "image").mkdir(parents=True, exist_ok=True)

self.image_cache = image_cache

def batch_generate(self, batch_size: int = 5) -> None:
"""
Expand All @@ -136,7 +136,6 @@ def batch_generate(self, batch_size: int = 5) -> None:
prompts.append(self.generate_prompt(image=image_sample['image'], clear_gpu=i==batch_size-1))
bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")


# shuffle and interleave models
t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
Expand Down Expand Up @@ -206,10 +205,10 @@ def generate_prompt(
raise ValueError(
"image can't be None if self.prompt_type is 'annotation'"
)
self.image_annotation_generator.load_models()
prompt = self.image_annotation_generator.generate(image)
self.prompt_generator.load_models()
prompt = self.prompt_generator.generate(image)
if clear_gpu:
self.image_annotation_generator.clear_gpu()
self.prompt_generator.clear_gpu()
else:
raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}")
return prompt
Expand Down
Loading