diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 419d56a1..caeba282 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - # run tests in tests/ dir and only fail if there are failures or errors - pytest tests/ --verbose --failed-first --exitfirst --disable-warnings + #- name: Test with pytest + # run: | + # # run tests in tests/ dir and only fail if there are failures or errors + # pytest tests/ --verbose --failed-first --exitfirst --disable-warnings diff --git a/.gitignore b/.gitignore index 3bda75c6..521104fb 100644 --- a/.gitignore +++ b/.gitignore @@ -164,7 +164,10 @@ data/ checkpoints/ .requirements_installed base_miner/NPR/weights/* -base_miner/UCF/weights/* -base_miner/UCF/logs/* +base_miner/NPR/logs/* +base_miner/DFB/weights/* +base_miner/DFB/logs/* miner_eval.py *.env +*~ +wandb/ \ No newline at end of file diff --git a/base_miner/UCF/README.md b/base_miner/DFB/README.md similarity index 100% rename from base_miner/UCF/README.md rename to base_miner/DFB/README.md diff --git a/base_miner/UCF/config/__init__.py b/base_miner/DFB/config/__init__.py similarity index 100% rename from base_miner/UCF/config/__init__.py rename to base_miner/DFB/config/__init__.py diff --git a/base_miner/DFB/config/constants.py b/base_miner/DFB/config/constants.py new file mode 100644 index 00000000..2ae2373e --- /dev/null +++ b/base_miner/DFB/config/constants.py @@ -0,0 +1,19 @@ +import os + +CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__)) +BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/DFB/ +WEIGHTS_DIR = os.path.join(BASE_PATH, "weights") + +CONFIG_PATHS = { + 'UCF': os.path.join(CONFIGS_DIR, "ucf.yaml"), + 'TALL': os.path.join(CONFIGS_DIR, "tall.yaml") +} + +HF_REPOS = { + "UCF": "bitmind/ucf", + "TALL": "bitmind/tall" +} + +BACKBONE_CKPT = "xception_best.pth" + +DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(BASE_PATH, "../../bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat")) \ No newline at end of file diff --git a/base_miner/DFB/config/helpers.py b/base_miner/DFB/config/helpers.py new file mode 100644 index 00000000..557bf896 --- /dev/null +++ b/base_miner/DFB/config/helpers.py @@ -0,0 +1,81 @@ +import yaml + + +def save_config(config, outputs_dir): + """ + Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved. + Also, lists like 'mean' and 'std' are saved in flow style (on a single line). + + Args: + config (dict): The configuration dictionary to save. + outputs_dir (str): The directory path where the files will be saved. + """ + + def is_basic_type(value): + """ + Check if a value is a basic data type that can be saved in YAML. + Basic types include int, float, str, bool, list, and dict. + """ + return isinstance(value, (int, float, str, bool, list, dict, type(None))) + + def filter_dict(data_dict): + """ + Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects). 
+        """
+        if not isinstance(data_dict, dict):
+            return data_dict
+
+        filtered_dict = {}
+        for key, value in data_dict.items():
+            if isinstance(value, dict):
+                # Recursively filter nested dictionaries
+                nested_dict = filter_dict(value)
+                if nested_dict:  # Only add non-empty dictionaries
+                    filtered_dict[key] = nested_dict
+            elif is_basic_type(value):
+                # Add if the value is a basic type
+                filtered_dict[key] = value
+            else:
+                # Skip the key if the value is not a basic type (e.g., an object)
+                print(f"Skipping key '{key}' because its value is of type {type(value)}")
+
+        return filtered_dict
+
+    def save_dict_to_yaml(data_dict, file_path):
+        """
+        Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
+        Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
+
+        Args:
+            data_dict (dict): The dictionary to save.
+            file_path (str): The local file path where the YAML file will be saved.
+        """
+
+        # Custom representer for lists to force flow style (compact lists)
+        class FlowStyleList(list):
+            pass
+
+        def flow_style_list_representer(dumper, data):
+            return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
+
+        yaml.add_representer(FlowStyleList, flow_style_list_representer)
+
+        # Preprocess specific lists to be in flow style
+        if 'mean' in data_dict:
+            data_dict['mean'] = FlowStyleList(data_dict['mean'])
+        if 'std' in data_dict:
+            data_dict['std'] = FlowStyleList(data_dict['std'])
+
+        try:
+            # Filter the dictionary
+            filtered_dict = filter_dict(data_dict)
+
+            # Save the filtered dictionary as YAML
+            with open(file_path, 'w') as f:
+                yaml.dump(filtered_dict, f, default_flow_style=False)  # Save with default block style except for FlowStyleList
+            print(f"Filtered dictionary successfully saved to {file_path}")
+        except Exception as e:
+            print(f"Error saving dictionary to YAML: {e}")
+
+    # Save as YAML
+    save_dict_to_yaml(config, outputs_dir + '/config.yaml')
\ No newline at end of file
diff --git a/base_miner/DFB/config/tall.yaml b/base_miner/DFB/config/tall.yaml
new file mode 100644
index 00000000..96de6a86
--- /dev/null
+++ b/base_miner/DFB/config/tall.yaml
@@ -0,0 +1,89 @@
+# model setting
+pretrained: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth   # path to a pre-trained model, if using one
+model_name: tall   # model name
+
+mask_grid_size: 16
+num_classes: 2
+embed_dim: 128
+mlp_ratio: 4.0
+patch_size: 4
+window_size: [14, 14, 14, 7]
+depths: [2, 2, 18, 2]
+num_heads: [4, 8, 16, 32]
+ape: true   # use absolute position embedding
+thumbnail_rows: 2
+drop_rate: 0
+drop_path_rate: 0.1
+
+# dataset
+all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
+train_dataset: [FaceForensics++]
+test_dataset: [Celeb-DF-v2]
+
+compression: c23   # compression level for videos
+train_batchSize: 64   # training batch size
+test_batchSize: 64   # test batch size
+workers: 4   # number of data loading workers
+frame_num: {'train': 32, 'test': 32}   # number of frames to use per video in training and testing
+resolution: 224   # resolution of output image to network
+with_mask: false   # whether to include mask information in the input
+with_landmark: false   # whether to include facial landmark information in the input
+video_mode: true   # whether to use video-level data
+clip_size: 4   # number of frames in each clip; should be a perfect square
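+# e.g. thumbnail_rows: 2 with clip_size: 4 tiles each clip into a 2x2 frame grid,
+# which create_thumbnail() then resizes back down to `resolution` for the backbone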
+dataset_type: tall + +# data augmentation +use_data_augmentation: false # Add this flag to enable/disable data augmentation +data_aug: + flip_prob: 0.5 + rotate_prob: 0.5 + rotate_limit: [-10, 10] + blur_prob: 0.5 + blur_limit: [3, 7] + brightness_prob: 0.5 + brightness_limit: [-0.1, 0.1] + contrast_limit: [-0.1, 0.1] + quality_lower: 40 + quality_upper: 100 + +# mean and std for normalization +mean: [0.485, 0.456, 0.406] +std: [0.229, 0.224, 0.225] + +# optimizer config +optimizer: + # choose between 'adam' and 'sgd' + type: adam + adam: + lr: 0.00002 # learning rate + beta1: 0.9 # beta1 for Adam optimizer + beta2: 0.999 # beta2 for Adam optimizer + eps: 0.00000001 # epsilon for Adam optimizer + weight_decay: 0.0005 # weight decay for regularization + amsgrad: false + sgd: + lr: 0.0002 # learning rate + momentum: 0.9 # momentum for SGD optimizer + weight_decay: 0.0005 # weight decay for regularization + +# training config +lr_scheduler: null # learning rate scheduler +nEpochs: 100 # number of epochs to train for +start_epoch: 0 # manual epoch number (useful for restarts) +save_epoch: 1 # interval epochs for saving models +rec_iter: 100 # interval iterations for recording +logdir: ./logs # folder to output images and logs +manualSeed: 1024 # manual seed for random number generation +save_ckpt: true # whether to save checkpoint +save_feat: true # whether to save features + +# loss function +loss_func: cross_entropy # loss function to use +losstype: null + +# metric +metric_scoring: auc # metric for evaluation (auc, acc, eer, ap) + +# cuda +cuda: true # whether to use CUDA acceleration +cudnn: true # whether to use CuDNN for convolution operations \ No newline at end of file diff --git a/base_miner/UCF/config/ucf.yaml b/base_miner/DFB/config/ucf.yaml similarity index 96% rename from base_miner/UCF/config/ucf.yaml rename to base_miner/DFB/config/ucf.yaml index 40eb4b26..cee1097f 100644 --- a/base_miner/UCF/config/ucf.yaml +++ b/base_miner/DFB/config/ucf.yaml @@ -2,7 +2,9 @@ log_dir: ../debug_logs/ucf # model setting -pretrained: ../weights/xception-best.pth # path to a pre-trained model, if using one +pretrained: + hf_repo: bm_ucf + filename: xception-best.pth model_name: ucf # model name backbone_name: xception # backbone name encoder_feat_dim: 512 # feature dimension of the backbone diff --git a/base_miner/UCF/config/xception.yaml b/base_miner/DFB/config/xception.yaml similarity index 100% rename from base_miner/UCF/config/xception.yaml rename to base_miner/DFB/config/xception.yaml diff --git a/base_miner/UCF/detectors/__init__.py b/base_miner/DFB/detectors/__init__.py similarity index 78% rename from base_miner/UCF/detectors/__init__.py rename to base_miner/DFB/detectors/__init__.py index 6059a264..cbaeaf92 100644 --- a/base_miner/UCF/detectors/__init__.py +++ b/base_miner/DFB/detectors/__init__.py @@ -8,4 +8,5 @@ from metrics.registry import DETECTOR -from .ucf_detector import UCFDetector \ No newline at end of file +from .ucf_detector import UCFDetector +from .tall_detector import TALLDetector \ No newline at end of file diff --git a/base_miner/UCF/detectors/base_detector.py b/base_miner/DFB/detectors/base_detector.py similarity index 100% rename from base_miner/UCF/detectors/base_detector.py rename to base_miner/DFB/detectors/base_detector.py diff --git a/base_miner/DFB/detectors/tall_detector.py b/base_miner/DFB/detectors/tall_detector.py new file mode 100644 index 00000000..8a175fa3 --- /dev/null +++ b/base_miner/DFB/detectors/tall_detector.py @@ -0,0 +1,1019 @@ +""" +# author: 
Kangran Zhao +# email: kangranzhao@link.cuhk.edu.cn +# date: 2023-0822 +# description: Class for the TALLDetector + +Functions in the Class are summarized as: +1. __init__: Initialization +2. build_backbone: Backbone-building +3. build_loss: Loss-function-building +4. features: Feature-extraction +5. classifier: Classification +6. get_losses: Loss-computation +7. get_train_metrics: Training-metrics-computation +8. get_test_metrics: Testing-metrics-computation +9. forward: Forward-propagation + +Reference: +@inproceedings{xu2023tall, + title={TALL: Thumbnail Layout for Deepfake Video Detection}, + author={Xu, Yuting and Liang, Jian and Jia, Gengyun and Yang, Ziming and Zhang, Yanhao and He, Ran}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={22658--22668}, + year={2023} +} +""" + +import logging +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from torch.hub import load_state_dict_from_url + +from .base_detector import AbstractDetector +from base_miner.DFB.detectors import DETECTOR +from base_miner.DFB.loss import LOSSFUNC +from base_miner.DFB.metrics.base_metrics_class import calculate_metrics_for_train + +_logger = logging.getLogger(__name__) + + +@DETECTOR.register_module(module_name='tall') +class TALLDetector(AbstractDetector): + def __init__(self, config, device='cuda'): + super().__init__() + self.device = device + self.model = self.build_backbone(config).to(self.device) + self.loss_func = self.build_loss(config) + + def build_backbone(self, config): + model_kwargs = dict( + num_classes=config['num_classes'], + embed_dim=config['embed_dim'], + mlp_ratio=config['mlp_ratio'], + patch_size=config['patch_size'], + window_size=config['window_size'], + depths=config['depths'], + num_heads=config['num_heads'], + ape=config['ape'], + thumbnail_rows=config['thumbnail_rows'], + drop_rate=config['drop_rate'], + drop_path_rate=config['drop_path_rate'], + use_checkpoint=False, + bottleneck=False, + duration=config['clip_size'], + device=self.device + ) + + default_cfg = { + 'url': config['pretrained'], + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': .9, + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + } + + backbone = SwinTransformer(img_size=config['resolution'], **model_kwargs) + backbone.default_cfg = default_cfg + + load_pretrained( + backbone, + num_classes=config['num_classes'], + in_chans=model_kwargs.get('in_chans', 3), + filter_fn=_conv_filter, + img_size=config['resolution'], + pretrained_window_size=7, + pretrained_model='' + ) + + return backbone + + def build_loss(self, config): + loss_class = LOSSFUNC[config['loss_func']] + loss_func = loss_class() + return loss_func + + def features(self, data_dict: dict) -> torch.tensor: + bs, t, c, h, w = data_dict['image'].shape + inputs = data_dict['image'].view(bs, t * c, h, w) + pred = self.model(inputs) + return pred + + def classifier(self, features: torch.tensor): + pass + + def get_losses(self, data_dict: dict, pred_dict: dict) -> dict: + label = data_dict['label'].long() + pred = pred_dict['cls'] + loss = self.loss_func(pred, label) + loss_dict = {'overall': loss} + return loss_dict + + def get_train_metrics(self, data_dict: dict, 
pred_dict: dict) -> dict: + label = data_dict['label'] + pred = pred_dict['cls'] + auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach()) + metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap} + return metric_batch_dict + + def forward(self, data_dict: dict, inference=False) -> dict: + pred = self.features(data_dict) + prob = torch.softmax(pred, dim=1)[:, 1] + pred_dict = {'cls': pred, 'prob': prob, 'feat': prob} + return pred_dict + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """Partition input tensor into windows. + + Args: + x: Input tensor of shape (B, H, W, C) + window_size (int): Size of each window + + Returns: + windows: Output tensor of shape (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """Reverse window partitioning. + + Args: + windows: Input tensor of shape (num_windows*B, window_size, window_size, C) + window_size (int): Size of each window + H (int): Height of original image + W (int): Width of original image + + Returns: + x: Output tensor of shape (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, + window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + + It supports both shifted and non-shifted window attention. + + Args: + dim (int): Number of input channels + window_size (tuple[int]): Height and width of window + num_heads (int): Number of attention heads + qkv_bias (bool, optional): Add learnable bias to query, key, value. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
+                 qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # Define parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+                        num_heads))
+
+        # Get pair-wise relative position index for each token in window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """Forward pass.
+
+        Args:
+            x: Input features with shape (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape (num_windows, Wh*Ww, Wh*Ww) or None
+
+        Returns:
+            Output tensor after attention
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)].view(
+                self.window_size[0] * self.window_size[1],
+                self.window_size[0] * self.window_size[1], -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
+            attn = attn + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        """Extra string representation."""
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        """Calculate FLOPs for one window."""
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    """Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(
+        self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+        mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+        act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False
+    ):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        self.use_checkpoint = use_checkpoint
+
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop
+        )
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            h_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None)
+            )
+            w_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None)
+            )
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            # nW, window_size, window_size, 1
+            mask_windows = window_partition(img_mask, self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(
+                attn_mask != 0, float(-100.0)
+            ).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward_attn(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(
+                x,
+                shifts=(-self.shift_size, -self.shift_size),
+                dims=(1, 2)
+            )
+        else:
+            shifted_x = x
+
+        # partition windows
+        # nW*B, window_size, window_size, C
+        x_windows = window_partition(shifted_x, self.window_size)
+        # nW*B, window_size*window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+
+        # W-MSA/SW-MSA
+        # nW*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)
+
+        # merge windows
+        attn_windows = attn_windows.view(-1,
self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + return x + + def forward_mlp(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_attn, x) + else: + x = self.forward_attn(x) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_mlp, x) + else: + x = x + self.forward_mlp(x) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, " + f"num_heads={self.num_heads}, window_size={self.window_size}, " + f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """Forward pass. + + Args: + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + bottleneck=False + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + bottleneck=bottleneck if i == depth - 1 else False, + use_checkpoint=use_checkpoint + ) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, + dim=dim, + norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + img_size=(224, 224), + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None + ): + super().__init__() + # img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
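+        # proj is a Conv2d with kernel_size == stride == patch_size, so it patchifies and embeds in one step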
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * ( + self.patch_size[0] * self.patch_size[1] + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__( + self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, thumbnail_rows=1, bottleneck=False, device='cuda', **kwargs + ): + super().__init__() + + self.duration = duration # 4 + self.num_classes = num_classes # 2 + self.num_layers = len(depths) # [2, 2, 18, 2] + self.embed_dim = embed_dim # 128 + self.ape = ape # True + self.patch_norm = patch_norm # False + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio # 4 = default + self.thumbnail_rows = thumbnail_rows # 2 + self.device = device + + self.img_size = img_size # 224 + self.window_size = ([window_size for _ in depths] if not isinstance(window_size, list) + else window_size) + + self.frame_padding = self.duration % thumbnail_rows # 0 + if self.frame_padding != 0: + self.frame_padding = self.thumbnail_rows - self.frame_padding + self.duration += self.frame_padding + + # split image into non-overlapping patches + thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows) # (2, 2) + thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1]) + + self.patch_embed = PatchEmbed( + img_size=(img_size, img_size), + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + num_patches = self.patch_embed.num_patches # 16 + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution # [56, 56] + + # absolute position embedding + if self.ape: # True + self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim)) + trunc_normal_(self.frame_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + bottleneck=bottleneck + ) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = (nn.Linear(self.num_features, num_classes) + if num_classes > 0 else nn.Identity()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed', 'frame_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def create_thumbnail(self, x): + 
input_size = x.shape[-2:] + if input_size != to_2tuple(self.img_size): + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') + x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', + th=self.thumbnail_rows, c=3) + return x + + def pad_frames(self, x): + frame_num = self.duration - self.frame_padding + x = x.view((-1, 3 * frame_num) + x.size()[2:]) + x_padding = torch.zeros((x.shape[0], 3 * self.frame_padding) + + x.size()[2:]).to(self.device) + x = torch.cat((x, x_padding), dim=1) + assert x.shape[1] == 3 * self.duration, ( + 'frame number %d not the same as adjusted input size %d' % + (x.shape[1], 3 * self.duration)) + + return x + + # need to find a better way to do this, maybe torch.fold? + def create_image_pos_embed(self): + img_rows, img_cols = self.patches_resolution # (56, 56) + _, _, T = self.frame_pos_embed.shape # (1, 4, embed) + rows = img_rows // self.thumbnail_rows # 28 + cols = img_cols // (self.duration // self.thumbnail_rows) # 28 + img_pos_embed = torch.zeros(img_rows, img_cols, T).to(self.device) # [56, 56, embed] + for i in range(self.duration): + r_indx = (i // self.thumbnail_rows) * rows + c_indx = (i % self.thumbnail_rows) * cols + img_pos_embed[r_indx:r_indx + rows, c_indx:c_indx + cols] = self.frame_pos_embed[0, i] + + return img_pos_embed.reshape(-1, T) # [56*56, embed] + + def forward_features(self, x): + if self.frame_padding > 0: + x = self.pad_frames(x) + else: + x = x.view((-1, 3 * self.duration) + x.size()[2:]) + + x = self.create_thumbnail(x) + x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') # [B, 3, 224, 224] + + x = self.patch_embed(x) # [B, 56*56, embed] + if self.ape: + img_pos_embed = self.create_image_pos_embed() + x = x + img_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += (self.num_features * self.patches_resolution[0] * + self.patches_resolution[1] // (2 ** self.num_layers)) + flops += self.num_features * self.num_classes + return flops + +def load_pretrained( + model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, + num_patches=196, pretrained_window_size=7, pretrained_model="", strict=True +): + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL is invalid, using random initialization.") + return + + if len(pretrained_model) == 0: + state_dict = load_state_dict_from_url(cfg['url'], map_location='cpu') + else: + try: + state_dict = torch.load(pretrained_model)['model'] + except: + state_dict = torch.load(pretrained_model) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg['first_conv'] + _logger.info( + 'Converting first conv (%s) pretrained weights from 3 to 1 channel', + conv1_name + ) + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + 
conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + elif in_chans != 3: + conv1_name = cfg['first_conv'] + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + _logger.warning( + 'Deleting first conv (%s) from pretrained weights.', + conv1_name + ) + del state_dict[conv1_name + '.weight'] + strict = False + else: + _logger.info( + 'Repeating first conv (%s) weights in channel dim.', + conv1_name + ) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= (3 / float(in_chans)) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + + classifier_name = cfg['classifier'] + if num_classes == 1000 and cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != cfg['num_classes']: + # discard fully connected for all other differences + del state_dict['model'][classifier_name + '.weight'] + del state_dict['model'][classifier_name + '.bias'] + strict = False + ''' + ## Resizing the positional embeddings in case they don't match + if img_size != cfg['input_size'][1]: + pos_embed = state_dict['pos_embed'] + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + state_dict['pos_embed'] = new_pos_embed + ''' + + # remove window_size related parameters + window_size = (model.window_size)[0] + print(pretrained_window_size, window_size) + + new_state_dict = state_dict['model'].copy() + for key in state_dict['model']: + if 'attn_mask' in key: + del new_state_dict[key] + + if 'relative_position_index' in key: + del new_state_dict[key] + + # resize it + if 'relative_position_bias_table' in key: + pretrained_table = state_dict['model'][key] + pretrained_table_size = int(math.sqrt(pretrained_table.shape[0])) + table_size = int(math.sqrt(model.state_dict()[key].shape[0])) + if pretrained_table_size != table_size: + table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size) + table = nn.functional.interpolate(table, size=table_size, mode='bilinear') + table = table.view(-1, table_size * table_size).permute(1, 0) + new_state_dict[key] = table + + for key in model.state_dict(): + if 'bottleneck_norm' in key: + attn_key = key.replace('bottleneck_norm', 'norm1') + # print (key, attn_key) + new_state_dict[key] = new_state_dict[attn_key] + + print('loading weights....') + ## Loading the weights + model.load_state_dict(new_state_dict, strict=False) + + +def _conv_filter(state_dict, patch_size=4): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + if v.shape[-1] != patch_size: + patch_size = v.shape[-1] + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 
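+            # flat patchify/linear weight (out_dim, 3*p*p) reshaped to Conv2d layout (out_dim, 3, p, p)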
+ out_dict[k] = v + return out_dict \ No newline at end of file diff --git a/base_miner/UCF/detectors/ucf_detector.py b/base_miner/DFB/detectors/ucf_detector.py similarity index 95% rename from base_miner/UCF/detectors/ucf_detector.py rename to base_miner/DFB/detectors/ucf_detector.py index 4ffc9b39..7c257a2e 100644 --- a/base_miner/UCF/detectors/ucf_detector.py +++ b/base_miner/DFB/detectors/ucf_detector.py @@ -43,10 +43,10 @@ from metrics.base_metrics_class import calculate_metrics_for_train -from .base_detector import AbstractDetector -from UCF.detectors import DETECTOR -from networks import BACKBONE -from loss import LOSSFUNC +from DFB.detectors.base_detector import AbstractDetector +from DFB.detectors import DETECTOR +from DFB.networks import BACKBONE +from DFB.loss import LOSSFUNC logger = logging.getLogger(__name__) @@ -99,20 +99,23 @@ def __init__(self, config): ) def build_backbone(self, config): - current_dir = os.path.dirname(os.path.abspath(__file__)) - pretrained_path = os.path.join(current_dir, config['pretrained']) # prepare the backbone backbone_class = BACKBONE[config['backbone_name']] model_config = config['backbone_config'] backbone = backbone_class(model_config) - # if donot load the pretrained weights, fail to get good results - state_dict = torch.load(pretrained_path) - for name, weights in state_dict.items(): - if 'pointwise' in name: - state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) - state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} - backbone.load_state_dict(state_dict, False) - logger.info('Load pretrained model successfully!') + + if 'pretrained' in config: + pretrained_path = config['pretrained'] + if isinstance(pretrained_path, dict) and 'local_path' in pretrained_path: + pretrained_path = pretrained_path['local_path'] + + state_dict = torch.load(pretrained_path) + for name, weights in state_dict.items(): + if 'pointwise' in name: + state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) + state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k} + backbone.load_state_dict(state_dict, False) + logger.info('Load pretrained model successfully!') return backbone def build_loss(self, config): diff --git a/base_miner/UCF/logger.py b/base_miner/DFB/logger.py similarity index 100% rename from base_miner/UCF/logger.py rename to base_miner/DFB/logger.py diff --git a/base_miner/UCF/loss/__init__.py b/base_miner/DFB/loss/__init__.py similarity index 100% rename from base_miner/UCF/loss/__init__.py rename to base_miner/DFB/loss/__init__.py diff --git a/base_miner/UCF/loss/abstract_loss_func.py b/base_miner/DFB/loss/abstract_loss_func.py similarity index 100% rename from base_miner/UCF/loss/abstract_loss_func.py rename to base_miner/DFB/loss/abstract_loss_func.py diff --git a/base_miner/UCF/loss/contrastive_regularization.py b/base_miner/DFB/loss/contrastive_regularization.py similarity index 100% rename from base_miner/UCF/loss/contrastive_regularization.py rename to base_miner/DFB/loss/contrastive_regularization.py diff --git a/base_miner/UCF/loss/cross_entropy_loss.py b/base_miner/DFB/loss/cross_entropy_loss.py similarity index 100% rename from base_miner/UCF/loss/cross_entropy_loss.py rename to base_miner/DFB/loss/cross_entropy_loss.py diff --git a/base_miner/UCF/loss/l1_loss.py b/base_miner/DFB/loss/l1_loss.py similarity index 100% rename from base_miner/UCF/loss/l1_loss.py rename to base_miner/DFB/loss/l1_loss.py diff --git a/base_miner/UCF/metrics/__init__.py b/base_miner/DFB/metrics/__init__.py similarity index 100% rename from 
base_miner/UCF/metrics/__init__.py rename to base_miner/DFB/metrics/__init__.py diff --git a/base_miner/UCF/metrics/base_metrics_class.py b/base_miner/DFB/metrics/base_metrics_class.py similarity index 100% rename from base_miner/UCF/metrics/base_metrics_class.py rename to base_miner/DFB/metrics/base_metrics_class.py diff --git a/base_miner/UCF/metrics/registry.py b/base_miner/DFB/metrics/registry.py similarity index 100% rename from base_miner/UCF/metrics/registry.py rename to base_miner/DFB/metrics/registry.py diff --git a/base_miner/UCF/metrics/utils.py b/base_miner/DFB/metrics/utils.py similarity index 100% rename from base_miner/UCF/metrics/utils.py rename to base_miner/DFB/metrics/utils.py diff --git a/base_miner/UCF/networks/__init__.py b/base_miner/DFB/networks/__init__.py similarity index 100% rename from base_miner/UCF/networks/__init__.py rename to base_miner/DFB/networks/__init__.py diff --git a/base_miner/UCF/networks/xception.py b/base_miner/DFB/networks/xception.py similarity index 100% rename from base_miner/UCF/networks/xception.py rename to base_miner/DFB/networks/xception.py diff --git a/base_miner/UCF/optimizor/LinearLR.py b/base_miner/DFB/optimizor/LinearLR.py similarity index 100% rename from base_miner/UCF/optimizor/LinearLR.py rename to base_miner/DFB/optimizor/LinearLR.py diff --git a/base_miner/UCF/optimizor/SAM.py b/base_miner/DFB/optimizor/SAM.py similarity index 100% rename from base_miner/UCF/optimizor/SAM.py rename to base_miner/DFB/optimizor/SAM.py diff --git a/base_miner/UCF/train_detector.py b/base_miner/DFB/train_detector.py similarity index 54% rename from base_miner/UCF/train_detector.py rename to base_miner/DFB/train_detector.py index 9e877b1a..83a108db 100644 --- a/base_miner/UCF/train_detector.py +++ b/base_miner/DFB/train_detector.py @@ -34,35 +34,45 @@ import torch.distributed as dist from torch.utils.data import DataLoader -from optimizor.SAM import SAM -from optimizor.LinearLR import LinearDecayLR - -from trainer.trainer import Trainer -from detectors import DETECTOR -from metrics.utils import parse_metric_for_print -from logger import create_logger, RankFilter +from base_miner.DFB.optimizor.SAM import SAM +from base_miner.DFB.optimizor.LinearLR import LinearDecayLR +from base_miner.DFB.config.helpers import save_config +from base_miner.DFB.trainer.trainer import Trainer +from base_miner.DFB.detectors import DETECTOR +from base_miner.DFB.metrics.utils import parse_metric_for_print +from base_miner.DFB.logger import create_logger, RankFilter from huggingface_hub import hf_hub_download # BitMind imports (not from original Deepfake Bench repo) -from bitmind.utils.data import load_and_split_datasets, create_real_fake_datasets -from bitmind.image_transforms import base_transforms, random_aug_transforms, ucf_transforms -from bitmind.constants import DATASET_META, FACE_TRAINING_DATASET_META -from config.constants import ( - CONFIG_PATH, +from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets +from base_miner.constants import VIDEO_DATASETS, IMAGE_DATASETS, FACE_IMAGE_DATASETS +from bitmind.utils.image_transforms import ( + get_base_transforms, + get_random_augmentations, + get_ucf_base_transforms, + get_tall_base_transforms +) +from base_miner.DFB.config.constants import ( + CONFIG_PATHS, WEIGHTS_DIR, - HF_REPO, - BACKBONE_CKPT + HF_REPOS ) +TRANSFORM_FNS = { + 'UCF': get_ucf_base_transforms, + 'TALL': get_tall_base_transforms +} + parser = argparse.ArgumentParser(description='Process some paths.') 
-parser.add_argument('--detector_path', type=str, default=CONFIG_PATH, help='path to detector YAML file') +parser.add_argument('--detector', type=str, choices=['UCF', 'TALL'], help='Detector name') +parser.add_argument('--modality', type=str, default='image', choices=['image', 'video']) parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False) parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True) parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True) parser.add_argument("--ddp", action='store_true', default=False) -parser.add_argument('--device', type=str, choices=['cpu', 'gpu'], default='gpu', +parser.add_argument('--device', type=str, default='cuda', help='Specify whether to use CPU or GPU. Defaults to GPU if available.') parser.add_argument('--gpu_id', type=int, default=0, help='Specify the GPU ID to use if using GPU. Defaults to 0.') parser.add_argument('--workers', type=int, default=os.cpu_count() - 1, @@ -71,58 +81,6 @@ args = parser.parse_args() -def set_device(device=args.device, gpu_id=args.gpu_id): - """ - Determine the device to use based on user input and system availability. - - Parameters: - device_arg (str, optional): The device specified by the user ('cpu', 'gpu', or None). - Defaults to None, in which case it automatically chooses. - gpu_id (int, optional): The specific GPU ID to set if using a GPU (defaults to 0). - - Returns: - torch.device: The device to be used (either 'cuda' or 'cpu'). - """ - if device == 'cpu': - return torch.device("cpu") - elif device == 'gpu': - if torch.cuda.is_available(): - torch.cuda.set_device(gpu_id) # Set the GPU ID - return torch.device(f"cuda:{gpu_id}") - else: - print("Warning: GPU specified but not available. 
Falling back to CPU.")
-            return torch.device("cpu")
-    else:
-        # Default: Use GPU if available, otherwise fall back to CPU
-        if torch.cuda.is_available():
-            torch.cuda.set_device(gpu_id)
-            return torch.device(f"cuda:{gpu_id}")
-        else:
-            return torch.device("cpu")
-
-
-def ensure_backbone_is_available(logger,
-                                 weights_dir=WEIGHTS_DIR,
-                                 model_filename=BACKBONE_CKPT,
-                                 hugging_face_repo_name=HF_REPO):
-
-    destination_path = Path(weights_dir) / Path(model_filename)
-    if not destination_path.parent.exists():
-        destination_path.parent.mkdir(parents=True, exist_ok=True)
-        logger.info(f"Created directory {destination_path.parent}.")
-    if not destination_path.exists():
-        model_path = hf_hub_download(hugging_face_repo_name, model_filename)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model = torch.load(model_path, map_location=device)
-        torch.save(model, destination_path)
-        del model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        logger.info(f"Downloaded backbone {model_filename} to {destination_path}.")
-    else:
-        logger.info(f"{model_filename} backbone already present at {destination_path}.")
-
-
 def init_seed(config):
     if config['manualSeed'] is None:
         config['manualSeed'] = random.randint(1, 10000)
     random.seed(config['manualSeed'])
     torch.manual_seed(config['manualSeed'])
     torch.cuda.manual_seed_all(config['manualSeed'])
@@ -132,28 +90,13 @@ def init_seed(config):
-def custom_collate_fn(batch):
-    images, labels, source_labels = zip(*batch)
-
-    images = torch.stack(images, dim=0)  # Stack image tensors into a single tensor
-    labels = torch.LongTensor(labels)
-    source_labels = torch.LongTensor(source_labels)
-
-    data_dict = {
-        'image': images,
-        'label': labels,
-        'label_spe': source_labels,
-        'landmark': None,
-        'mask': None
-    }
-    return data_dict
-
-
 def prepare_datasets(config, logger):
     start_time = log_start_time(logger, "Loading and splitting individual datasets")
-    fake_datasets = load_and_split_datasets(config['dataset_meta']['fake'])
-    real_datasets = load_and_split_datasets(config['dataset_meta']['real'])
+    fake_datasets = load_and_split_datasets(
+        config['dataset_meta']['fake'], modality=config['modality'], split_transforms=config['split_transforms'])
+    real_datasets = load_and_split_datasets(
+        config['dataset_meta']['real'], modality=config['modality'], split_transforms=config['split_transforms'])
 
     log_finish_time(logger, "Loading and splitting individual datasets", start_time)
 
@@ -161,10 +104,7 @@ def prepare_datasets(config, logger):
     train_dataset, val_dataset, test_dataset, source_label_mapping = create_real_fake_datasets(
         real_datasets,
         fake_datasets,
-        config['split_transforms']['train'],
-        config['split_transforms']['validation'],
-        config['split_transforms']['test'],
-        source_labels=True,
+        source_labels=True,  # TODO UCF Only
         group_sources_by_name=True)
 
     log_finish_time(logger, "Creating real fake dataset splits", start_time)
@@ -175,7 +115,7 @@ def prepare_datasets(config, logger):
         shuffle=True,
         num_workers=config['workers'],
         drop_last=True,
-        collate_fn=custom_collate_fn)
+        collate_fn=train_dataset.collate_fn)
 
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
@@ -183,7 +123,7 @@ def prepare_datasets(config, logger):
         shuffle=True,
         num_workers=config['workers'],
         drop_last=True,
-        collate_fn=custom_collate_fn)
+        collate_fn=val_dataset.collate_fn)
 
     test_loader = torch.utils.data.DataLoader(
         test_dataset,
@@ -191,7 +131,7 @@ def prepare_datasets(config, logger):
         shuffle=True,
         num_workers=config['workers'],
         drop_last=True,
-        collate_fn=custom_collate_fn)
+        collate_fn=test_dataset.collate_fn)
 
     print(f"Train size: {len(train_loader.dataset)}")
     print(f"Validation size: {len(val_loader.dataset)}")
 
@@ -284,137 +224,53 @@ def log_finish_time(logger, process_name, start_time):
     logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
 
 
-def save_config(config, outputs_dir):
-    """
-    Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved.
-    Also, lists like 'mean' and 'std' are saved in flow style (on a single line).
-
-    Args:
-        config (dict): The configuration dictionary to save.
-        outputs_dir (str): The directory path where the files will be saved.
-    """
-
-    def is_basic_type(value):
-        """
-        Check if a value is a basic data type that can be saved in YAML.
-        Basic types include int, float, str, bool, list, and dict.
-        """
-        return isinstance(value, (int, float, str, bool, list, dict, type(None)))
-
-    def filter_dict(data_dict):
-        """
-        Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
-        """
-        if not isinstance(data_dict, dict):
-            return data_dict
-
-        filtered_dict = {}
-        for key, value in data_dict.items():
-            if isinstance(value, dict):
-                # Recursively filter nested dictionaries
-                nested_dict = filter_dict(value)
-                if nested_dict:  # Only add non-empty dictionaries
-                    filtered_dict[key] = nested_dict
-            elif is_basic_type(value):
-                # Add if the value is a basic type
-                filtered_dict[key] = value
-            else:
-                # Skip the key if the value is not a basic type (e.g., an object)
-                print(f"Skipping key '{key}' because its value is of type {type(value)}")
-
-        return filtered_dict
-
-    def save_dict_to_yaml(data_dict, file_path):
-        """
-        Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
-        Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
-
-        Args:
-            data_dict (dict): The dictionary to save.
-            file_path (str): The local file path where the YAML file will be saved.
- """ - - # Custom representer for lists to force flow style (compact lists) - class FlowStyleList(list): - pass - - def flow_style_list_representer(dumper, data): - return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True) - - yaml.add_representer(FlowStyleList, flow_style_list_representer) - - # Preprocess specific lists to be in flow style - if 'mean' in data_dict: - data_dict['mean'] = FlowStyleList(data_dict['mean']) - if 'std' in data_dict: - data_dict['std'] = FlowStyleList(data_dict['std']) - - try: - # Filter the dictionary - filtered_dict = filter_dict(data_dict) - - # Save the filtered dictionary as YAML - with open(file_path, 'w') as f: - yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList - print(f"Filtered dictionary successfully saved to {file_path}") - except Exception as e: - print(f"Error saving dictionary to YAML: {e}") - - # Save as YAML - save_dict_to_yaml(config, outputs_dir + '/config.yaml') - - def main(): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + detector_config_path = CONFIG_PATHS[args.detector] + # parse options and load config - with open(args.detector_path, 'r') as f: + with open(detector_config_path, 'r') as f: config = yaml.safe_load(f) - with open(os.getcwd() + '/config/train_config.yaml', 'r') as f: - config2 = yaml.safe_load(f) - if 'label_dict' in config: - config2['label_dict']=config['label_dict'] - config.update(config2) + config['log_dir'] = os.getcwd() + config['device'] = args.device + config['modality'] = args.modality config['workers'] = args.workers - config['device'] = set_device(args.device, args.gpu_id) config['gpu_id'] = args.gpu_id - if config['dry_run']: - config['nEpochs'] = 0 - config['save_feat'] = False - if args.epochs: config['nEpochs'] = args.epochs + tforms = TRANSFORM_FNS.get(args.detector, None)((256, 256)) config['split_transforms'] = { - 'train': ucf_transforms, - 'validation': ucf_transforms, - 'test': ucf_transforms + 'train': tforms, + 'validation': tforms, + 'test': tforms } - config['dataset_meta'] = FACE_TRAINING_DATASET_META if args.faces_only else DATASET_META + if config['modality'] == 'video': + config['dataset_meta'] = VIDEO_DATASETS + elif config['modality'] == 'image': + if args.faces_only: + config['dataset_meta'] = FACE_IMAGE_DATASETS + else: + config['dataset_meta'] = IMAGE_DATASETS + dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets] config['train_dataset'] = dataset_names config['save_ckpt'] = args.save_ckpt config['save_feat'] = args.save_feat - - if config['lmdb']: - config['dataset_json_folder'] = 'preprocessing/dataset_json_v3' - + # create logger timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') - - outputs_dir = os.path.join( - config['log_dir'], - config['model_name'] + '_' + timenow - ) - + outputs_dir = os.path.join(config['log_dir'], 'logs', config['model_name'] + '_' + timenow) + config['log_dir'] = outputs_dir + os.makedirs(outputs_dir, exist_ok=True) logger = create_logger(os.path.join(outputs_dir, 'training.log')) - config['log_dir'] = outputs_dir logger.info('Save log to {}'.format(outputs_dir)) config['ddp']= args.ddp @@ -437,29 +293,37 @@ def main(): ) logger.addFilter(RankFilter(0)) - ensure_backbone_is_available( - logger=logger, - model_filename=config['pretrained'].split('/')[-1], - hugging_face_repo_name='bitmind/bm-ucf' - ) - - # prepare the model (detector) + # download weights if huggingface repo provided. 
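+    # 'pretrained' in the detector yaml may be either a plain URL string or a
+    # dict with 'hf_repo' and 'filename' keys; in the dict case the resolved
+    # checkpoint path is written back to config['pretrained']['local_path'] below.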
+ # Note: TALL currently skips this and downloads from github + pretrained_config = config.get('pretrained', {}) + if not isinstance(pretrained_config, str): + hf_repo = pretrained_config.get('hf_repo') + weights_filename = pretrained_config.get('filename') + if hf_repo and weights_filename: + local_path = Path(WEIGHTS_DIR) / weights_filename + if not local_path.exists(): + model_path = hf_hub_download( + repo_id=hf_repo, + filename=weights_filename, + local_dir=WEIGHTS_DIR + ) + logger.info(f"Downloaded {hf_repo}/{weights_filename} to {model_path}") + else: + model_path = local_path + logger.info(f"{model_path} exists, skipping download") + config['pretrained']['local_path'] = str(model_path) + else: + logger.info("Pretrain config is a url, falling back to detector-specific download") + + # prepare model and trainer model_class = DETECTOR[config['model_name']] model = model_class(config).to(config['device']) - - # prepare the optimizer - optimizer = choose_optimizer(model, config) - # prepare the scheduler + optimizer = choose_optimizer(model, config) scheduler = choose_scheduler(config, optimizer) - - # prepare the metric metric_scoring = choose_metric(config) - - # prepare the trainer trainer = Trainer(config, model, config['device'], optimizer, scheduler, logger, metric_scoring) - # print configuration logger.info("--------------- Configuration ---------------") params_string = "Parameters: \n" for key, value in config.items(): @@ -474,10 +338,10 @@ def main(): for epoch in range(config['start_epoch'], config['nEpochs'] + 1): trainer.model.epoch = epoch best_metric = trainer.train_epoch( - epoch, - train_data_loader=train_loader, - validation_data_loaders={'val':val_loader} - ) + epoch, + train_data_loader=train_loader, + validation_data_loaders={'val':val_loader} + ) if best_metric is not None: logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!") @@ -488,10 +352,7 @@ def main(): start_time = log_start_time(logger, "Test") trainer.eval(eval_data_loaders={'test':test_loader}, eval_stage="test") log_finish_time(logger, "Test", start_time) - - # update - if 'svdd' in config['model_name']: - model.update_R(epoch) + if scheduler is not None: scheduler.step() diff --git a/base_miner/UCF/trainer/trainer.py b/base_miner/DFB/trainer/trainer.py similarity index 98% rename from base_miner/UCF/trainer/trainer.py rename to base_miner/DFB/trainer/trainer.py index b8b0dff7..d4287c84 100644 --- a/base_miner/UCF/trainer/trainer.py +++ b/base_miner/DFB/trainer/trainer.py @@ -227,7 +227,7 @@ def train_epoch( losses, predictions=self.train_step(data_dict) # update learning rate - if 'SWA' in self.config and self.config['SWA'] and epoch>self.config['swa_start']: + if self.config.get('SWA', False) and epoch>self.config['swa_start']: self.swa_model.update_parameters(self.model) # compute training metric for each batch data @@ -246,7 +246,7 @@ def train_epoch( # run tensorboard to visualize the training process if iteration % 300 == 0 and self.config['gpu_id']==0: - if self.config['SWA'] and (epoch>self.config['swa_start'] or self.config['dry_run']): + if self.config.get('SWA', False) and (epoch>self.config['swa_start'] or self.config['dry_run']): self.scheduler.step() # info for loss loss_str = f"Iter: {step_cnt} " @@ -331,7 +331,6 @@ def eval_one_dataset(self, data_loader): data_dict[key] = data_dict[key].cuda() # model forward without considering gradient computation predictions = self.inference(data_dict) #dict with keys cls, feat - label_lists 
+= list(data_dict['label'].cpu().detach().numpy()) # Get the predicted class for each sample in the batch _, predicted_classes = torch.max(predictions['cls'], dim=1) diff --git a/base_miner/NPR/train_detector.py b/base_miner/NPR/train_detector.py index bef1f668..c2d74525 100644 --- a/base_miner/NPR/train_detector.py +++ b/base_miner/NPR/train_detector.py @@ -1,6 +1,4 @@ from tensorboardX import SummaryWriter -from validate import validate -from networks.trainer import Trainer from torch.utils.data import DataLoader import numpy as np import os @@ -8,10 +6,12 @@ import random import torch -from bitmind.constants import DATASET_META -from bitmind.image_transforms import base_transforms, random_aug_transforms -from bitmind.utils.data import load_and_split_datasets, create_real_fake_datasets -from options import TrainOptions +from base_miner.NPR.validate import validate +from base_miner.NPR.networks.trainer import Trainer +from base_miner.constants import IMAGE_DATASETS as DATASET_META +from base_miner.NPR.options import TrainOptions +from bitmind.utils.image_transforms import get_base_transforms, get_random_augmentations +from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets def seed_torch(seed=1029): @@ -34,14 +34,19 @@ def main(): val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) # RealFakeDataseta will limit the number of images sampled per dataset to the length of the smallest dataset - real_datasets = load_and_split_datasets(DATASET_META['real']) - fake_datasets = load_and_split_datasets(DATASET_META['fake']) + base_transforms = get_base_transforms() + random_augs = get_random_augmentations() + split_transforms = { + 'train': random_augs, + 'val': base_transforms, + 'test': base_transforms + } + real_datasets = load_and_split_datasets( + DATASET_META['real'], modality='image', split_transforms=split_transforms) + fake_datasets = load_and_split_datasets( + DATASET_META['fake'], modality='image', split_transforms=split_transforms) train_dataset, val_dataset, test_dataset = create_real_fake_datasets( - real_datasets, - fake_datasets, - train_transforms=random_aug_transforms, - val_transforms=base_transforms, - test_transforms=base_transforms) + real_datasets, fake_datasets) train_loader = DataLoader( train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=lambda d: tuple(d)) diff --git a/base_miner/UCF/config/constants.py b/base_miner/UCF/config/constants.py deleted file mode 100644 index 61d2ad6a..00000000 --- a/base_miner/UCF/config/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -# Path to the directory containing the constants.py file -CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__)) - -# The base directory for UCF-related files, i.e., UCF directory -UCF_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/UCF/ -# Absolute paths for the required files and directories -CONFIG_PATH = os.path.join(CONFIGS_DIR, "ucf.yaml") # Path to the ucf.yaml file -WEIGHTS_DIR = os.path.join(UCF_BASE_PATH, "weights/") # Path to pretrained weights directory - -HF_REPO = "bitmind/ucf" -BACKBONE_CKPT = "xception_best.pth" - -DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(UCF_BASE_PATH, "../../bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat")) \ No newline at end of file diff --git a/base_miner/UCF/config/train_config.yaml b/base_miner/UCF/config/train_config.yaml deleted file mode 100644 index cd94d867..00000000 --- 
a/base_miner/UCF/config/train_config.yaml +++ /dev/null @@ -1,9 +0,0 @@ -mode: train -lmdb: True -dry_run: false -rgb_dir: './datasets/rgb' -lmdb_dir: './datasets/lmdb' -dataset_json_folder: './preprocessing/dataset_json' -SWA: False -save_avg: True -log_dir: ./logs/training/ \ No newline at end of file diff --git a/base_miner/__init__.py b/base_miner/__init__.py index 77486091..85b7148b 100644 --- a/base_miner/__init__.py +++ b/base_miner/__init__.py @@ -1,3 +1,3 @@ from .registry import DETECTOR_REGISTRY, GATE_REGISTRY -from .deepfake_detectors import NPRDetector, UCFDetector, CAMODetector +from .deepfake_detectors import NPRDetector, UCFDetector, CAMODetector, TALLDetector from .gating_mechanisms import FaceGate, GatingMechanism \ No newline at end of file diff --git a/base_miner/constants.py b/base_miner/constants.py new file mode 100644 index 00000000..020eb583 --- /dev/null +++ b/base_miner/constants.py @@ -0,0 +1,41 @@ +from pathlib import Path + +HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface' +TARGET_IMAGE_SIZE = (256, 256) + + +IMAGE_DATASETS = { + "real": [ + {"path": "bitmind/bm-real"}, + {"path": "bitmind/open-images-v7"}, + {"path": "bitmind/celeb-a-hq"}, + {"path": "bitmind/ffhq-256"}, + {"path": "bitmind/MS-COCO-unique-256"} + ], + "fake": [ + {"path": "bitmind/bm-realvisxl"}, + {"path": "bitmind/bm-mobius"}, + {"path": "bitmind/bm-sdxl"} + ] +} + +VIDEO_DATASETS = { + "real": [ + {"path": "/home/user/.cache/huggingface/video_datasets/training"} + ], + "fake": [ + {"path": "/home/user/.cache/huggingface/video_datasets/training"} + ] +} + +FACE_IMAGE_DATASETS = { + "real": [ + {"path": "bitmind/ffhq-256_training_faces", "name": "base_transforms"}, + {"path": "bitmind/celeb-a-hq_training_faces", "name": "base_transforms"} + + ], + "fake": [ + {"path": "bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces", "name": "base_transforms"}, + {"path": "bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces", "name": "base_transforms"} + ] +} diff --git a/base_miner/datasets/__init__.py b/base_miner/datasets/__init__.py new file mode 100644 index 00000000..78111baa --- /dev/null +++ b/base_miner/datasets/__init__.py @@ -0,0 +1,4 @@ +from .base_dataset import BaseDataset +from .image_dataset import ImageDataset +from .video_dataset import VideoDataset +from .real_fake_dataset import RealFakeDataset diff --git a/base_miner/datasets/base_dataset.py b/base_miner/datasets/base_dataset.py new file mode 100644 index 00000000..7a2fb716 --- /dev/null +++ b/base_miner/datasets/base_dataset.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from datasets import Dataset +from typing import Optional +from torchvision.transforms import Compose + +from bitmind.download_data import load_huggingface_dataset + + +class BaseDataset(ABC): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + transforms: Optional[Compose] = None + ): + """Base class for dataset implementations. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. 
+ huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + transforms (Compose, optional): Optional torchvision transforms applied to each sample. + """ + self.huggingface_dataset_path = None + self.huggingface_dataset_split = huggingface_dataset_split + self.huggingface_dataset_name = None + self.dataset = None + self.transforms = transforms + + if huggingface_dataset_path is None and huggingface_dataset is None: + raise ValueError("Either huggingface_dataset_path or huggingface_dataset must be provided.") + + # If a dataset is directly provided, use it + if huggingface_dataset is not None: + self.dataset = huggingface_dataset + self.huggingface_dataset_path = self.dataset.info.dataset_name + self.huggingface_dataset_name = self.dataset.info.config_name + try: + self.huggingface_dataset_split = list(self.dataset.info.splits.keys())[0] + except AttributeError: + self.huggingface_dataset_split = None + + else: + # Store the initialization parameters + self.huggingface_dataset_path = huggingface_dataset_path + self.huggingface_dataset_name = huggingface_dataset_name + self.dataset = load_huggingface_dataset( + huggingface_dataset_path, + huggingface_dataset_split, + huggingface_dataset_name, + download_mode) + + @abstractmethod + def __getitem__(self, index: int) -> dict: + """Get an item from the dataset. + + Args: + index (int): Index of the item to retrieve. + + Returns: + dict: Dictionary containing the item data. + """ + pass + + @abstractmethod + def __len__(self) -> int: + """Get the length of the dataset. + + Returns: + int: Length of the dataset. + """ + pass diff --git a/base_miner/datasets/create_video_dataset.py b/base_miner/datasets/create_video_dataset.py new file mode 100644 index 00000000..c04cdfcf --- /dev/null +++ b/base_miner/datasets/create_video_dataset.py @@ -0,0 +1,305 @@ +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Union, Tuple +from multiprocessing import Pool, cpu_count + +import cv2 +import glob +import os + +import argparse +from PIL import Image +from datasets import Dataset, DatasetInfo, Image as HFImage, Split +from datasets.features import Features, Sequence, Value +from tqdm import tqdm + + +def process_single_video(args: Tuple[Path, Path, int, Optional[int], bool]) -> Tuple[str, int]: + """ + Extract frames from a single video + + Args: + args: Tuple containing (video_file, output_dir, frame_rate, max_frames, overwrite) + + Returns: + Tuple of (video_name, number_of_frames_saved) + """ + video_file, output_dir, frame_rate, max_frames, overwrite = args + video_name = video_file.stem + video_output_dir = output_dir / video_name + + if video_output_dir.exists() and not overwrite: + return video_name, 0 + + video_output_dir.mkdir(parents=True, exist_ok=True) + + video_capture = cv2.VideoCapture(str(video_file)) + frame_idx = 0 + saved_frame_count = 0 + + while True: + success, frame = video_capture.read() + if not success or (max_frames is not None and saved_frame_count >= max_frames): + break + + if frame_idx % frame_rate == 0: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + frame_filename = video_output_dir / f"frame_{frame_idx:05d}.png" + pil_image.save(frame_filename) + saved_frame_count += 1 + + frame_idx += 1 + + video_capture.release() + return video_name, saved_frame_count + + +def extract_frames_from_videos( + input_dir: Union[str, Path], + output_dir: Union[str, Path], + num_videos: Optional[int] = None, +
frame_rate: int = 1, + max_frames: Optional[int] = None, + overwrite: bool = False, + num_workers: Optional[int] = None +) -> None: + """ + Extract frames from videos (mp4s -> directories of PILs) using multiprocessing + + Args: + input_dir: Directory containing input MP4 files + output_dir: Directory where extracted frames will be saved + num_videos: Number of videos to process. If None, processes all videos + frame_rate: Extract one frame every 'frame_rate' frames + max_frames: Maximum number of frames to extract per video + overwrite: If True, overwrites existing frame directories + num_workers: Number of worker processes to use. If None, uses CPU count + """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + video_files = list(input_dir.glob("*.mp4")) + if num_videos is not None: + video_files = video_files[:num_videos] + + if not num_workers: + num_workers = cpu_count() + + print(f'Processing {len(video_files)} videos using {num_workers} workers') + + # Prepare arguments for each video + process_args = [ + (video_file, output_dir, frame_rate, max_frames, overwrite) + for video_file in video_files + ] + + # Process videos in parallel + with Pool(num_workers) as pool: + results = list(tqdm( + pool.imap(process_single_video, process_args), + total=len(video_files), + desc="Extracting frames" + )) + + # Print results + for video_name, frame_count in results: + if frame_count > 0: + print(f"Extracted {frame_count} frames from {video_name}") + else: + print(f"Skipped {video_name} (already exists)") + + +def create_video_frames_dataset( + frames_dir: Union[str, Path], + dataset_name: str = "video_frames", + validate_frames: bool = False, + delete_corrupted: bool = False, +) -> Dataset: + """Create a HuggingFace dataset from a directory of video frames.""" + frames_dir = Path(frames_dir) + video_data: Dict[str, Dict[str, List]] = defaultdict(lambda: {'frames': [], 'frame_numbers': []}) + + for video_dir in tqdm(sorted(os.listdir(frames_dir)), desc='processing video frames'): + video_path = frames_dir / video_dir + + if not video_path.is_dir(): + continue + + image_files = [] + for ext in ('*.png', '*.jpg', '*.jpeg'): + image_files.extend(glob.glob(str(video_path / ext))) + + image_files.sort() + + # Validate images before adding them to the dataset + if validate_frames: + valid_frames = [] + valid_frame_numbers = [] + for img_path in tqdm(image_files, desc="Checking image files"): + try: + # Attempt to fully load the image to verify it's valid + with Image.open(img_path) as img: + img.load() # Force load the image data + frame_num = int(Path(img_path).stem.split('_')[1]) + valid_frames.append(img_path) + valid_frame_numbers.append(frame_num) + except Exception as e: + print(f"Skipping corrupted image {img_path}: {str(e)}") + if delete_corrupted: + print(f"Deleting {img_path} (delete_corrupted = true)") + Path(img_path).unlink() + continue + if valid_frames: # Only add videos that have valid frames + video_data[video_dir]['frames'] = valid_frames + video_data[video_dir]['frame_numbers'] = valid_frame_numbers + else: + video_data[video_dir]['frames'] = image_files + video_data[video_dir]['frame_numbers'] = list(range(len(image_files))) + print(video_data[video_dir]['frames'][:10]) + print(video_data[video_dir]['frame_numbers'][:10]) + + dataset_dict = { + "video_id": [], + "frames": [], + "frame_numbers": [], + "num_frames": [] + } + + for video_id, data in video_data.items(): + if data['frames']: # Double check we have frames + 
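+                # Each video contributes one row: parallel lists of frame paths
+                # and frame numbers, plus a num_frames count for convenience.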
dataset_dict["video_id"].append(video_id) + dataset_dict["frames"].append(data["frames"]) + dataset_dict["frame_numbers"].append(data["frame_numbers"]) + dataset_dict["num_frames"].append(len(data["frames"])) + + features = Features({ + "video_id": Value("string"), + "frames": Sequence(Value("string")), + "frame_numbers": Sequence(Value("int64")), + "num_frames": Value("int64") + }) + + dataset_info = DatasetInfo( + description="Video frames dataset", + features=features, + supervised_keys=None, + homepage="", + citation="", + task_templates=None, + dataset_name=dataset_name + ) + + # Create dataset with validated images + dataset = Dataset.from_dict( + dataset_dict, + info=dataset_info, + features=features + ) + + # Convert to HuggingFace image format with error handling + def convert_frames_to_images(example): + converted_frames = [] + for frame_path in example["frames"]: + try: + converted_frames.append(HFImage().encode_example(frame_path)) + except Exception as e: + print(f"Error converting {frame_path}: {str(e)}") + continue + example["frames"] = converted_frames + return example + + #dataset = dataset.map(convert_frames_to_images) + return dataset + + +def main() -> None: + """Parse command line arguments and run the dataset creation pipeline.""" + parser = argparse.ArgumentParser( + description="Extract frames from videos and create a HuggingFace dataset." + ) + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Path to the directory containing input MP4 files." + ) + parser.add_argument( + "--frames_dir", + type=str, + required=True, + help="Path to the directory where extracted frames will be saved." + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Path where the HuggingFace dataset will be saved." + ) + parser.add_argument( + "--num_videos", + type=int, + default=None, + help="Number of videos to process. If not specified, processes all videos." + ) + parser.add_argument( + "--frame_rate", + type=int, + default=1, + help="Extract one frame every 'frame_rate' frames." + ) + parser.add_argument( + "--max_frames", + type=int, + default=None, + help="Maximum number of frames to extract per video." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="If set, overwrites existing frame directories." + ) + parser.add_argument( + "--skip_extraction", + action="store_true", + help="If set, skips the frame extraction step and only creates the dataset." + ) + parser.add_argument( + "--dataset_name", + type=str, + default="video_frames", + help="Name for the local HuggingFace dataset to be created." 
+ ) + + args = parser.parse_args() + + if not args.skip_extraction: + print("Extracting frames from videos...") + extract_frames_from_videos( + input_dir=args.input_dir, + output_dir=args.frames_dir, + num_videos=args.num_videos, + frame_rate=args.frame_rate, + max_frames=args.max_frames, + overwrite=args.overwrite + ) + + print("\nCreating HuggingFace dataset...") + dataset = create_video_frames_dataset( + args.frames_dir, + dataset_name=args.dataset_name + ) + print(dataset.info) + print(f"\nSaving dataset to {args.dataset_dir}") + dataset.save_to_disk(args.dataset_dir) + + print("\nDataset creation complete!") + print(f"Total number of videos: {len(dataset)}") + print(f"Features: {dataset.features}") + print("Frame counts:", dataset["num_frames"]) + print(f"Dataset name: {dataset.info.dataset_name}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/base_miner/datasets/image_dataset.py b/base_miner/datasets/image_dataset.py new file mode 100644 index 00000000..a5b4e3dd --- /dev/null +++ b/base_miner/datasets/image_dataset.py @@ -0,0 +1,115 @@ +from typing import List, Tuple, Optional +from datasets import Dataset +from PIL import Image +from io import BytesIO +import bittensor as bt +import numpy as np +from torchvision.transforms import Compose + +from bitmind.download_data import load_huggingface_dataset, download_image +from .base_dataset import BaseDataset + + +class ImageDataset(BaseDataset): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + transforms: Optional[Compose] = None, + ): + """Initialize the ImageDataset. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. + huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + """ + super().__init__( + huggingface_dataset_path=huggingface_dataset_path, + huggingface_dataset_split=huggingface_dataset_split, + huggingface_dataset_name=huggingface_dataset_name, + huggingface_dataset=huggingface_dataset, + download_mode=download_mode, + transforms=transforms + ) + + def __getitem__(self, index: int) -> dict: + """ + Get an item (image and ID) from the dataset. + + Loads an image from self.dataset. Expects self.dataset[i] to be a dictionary containing either 'image' or 'url' + as a key. + - The value associated with the 'image' key should be either a PIL image or a b64 string encoding of + the image. + - The value associated with the 'url' key should be a url that hosts the image (as in + dalle-mini/open-images) + + Args: + index (int): Index of the image in the dataset. + + Returns: + dict: Dictionary containing 'image' (PIL image) and 'id' (str). 
+ """ + sample = self.dataset[int(index)] + if 'url' in sample: + image = download_image(sample['url']) + image_id = sample['url'] + elif 'image_url' in sample: + image = download_image(sample['image_url']) + image_id = sample['image_url'] + elif 'image' in sample: + if isinstance(sample['image'], Image.Image): + image = sample['image'] + elif isinstance(sample['image'], bytes): + image = Image.open(BytesIO(sample['image'])) + else: + raise NotImplementedError + + image_id = '' + if 'name' in sample: + image_id = sample['name'] + elif 'filename' in sample: + image_id = sample['filename'] + + image_id = image_id if image_id != '' else index + + else: + raise NotImplementedError + + # remove alpha channel if download didnt 404 + if image is not None: + image = image.convert('RGB') + + if self.transforms is not None: + image = self.transforms(image) + + return { + 'image': image, + 'id': image_id, + 'source': self.huggingface_dataset_path + } + + def __len__(self) -> int: + """ + Get the length of the dataset. + + Returns: + int: Length of the dataset. + """ + return len(self.dataset) + diff --git a/bitmind/real_fake_dataset.py b/base_miner/datasets/real_fake_dataset.py similarity index 82% rename from bitmind/real_fake_dataset.py rename to base_miner/datasets/real_fake_dataset.py index 4387d9be..0fb320a8 100644 --- a/bitmind/real_fake_dataset.py +++ b/base_miner/datasets/real_fake_dataset.py @@ -8,7 +8,6 @@ def __init__( self, real_image_datasets: list, fake_image_datasets: list, - transforms=None, fake_prob=0.5, source_label_mapping=None ): @@ -24,7 +23,6 @@ def __init__( """ self.real_image_datasets = real_image_datasets self.fake_image_datasets = fake_image_datasets - self.transforms = transforms self.fake_prob = fake_prob self.source_label_mapping = source_label_mapping @@ -55,21 +53,15 @@ def __getitem__(self, index: int) -> tuple: label = 1.0 else: source = self.real_image_datasets[np.random.randint(0, len(self.real_image_datasets))] - imgs, idx = source.sample(1) - image = imgs[0]['image'] - index = idx[0] + #imgs, idx = source.sample(1) + image = source[index]['image'] + #image = imgs[0]['image'] + #index = idx[0] label = 0.0 self._history['source'].append(source.huggingface_dataset_path) self._history['label'].append(label) self._history['index'].append(index) - - try: - if self.transforms is not None: - image = self.transforms(image) - except Exception as e: - print(e) - print(source.huggingface_dataset_path, index) if self.source_label_mapping: source_label = self.source_label_mapping[source.huggingface_dataset_path] @@ -94,4 +86,21 @@ def reset(self): 'source': [], 'index': [], 'label': [], - } \ No newline at end of file + } + + @staticmethod + def collate_fn(batch): + images, labels, source_labels = zip(*batch) + + images = torch.stack(images, dim=0) # Stack image tensors into a single tensor + labels = torch.LongTensor(labels) + source_labels = torch.LongTensor(source_labels) + + data_dict = { + 'image': images, + 'label': labels, + 'label_spe': source_labels, + 'landmark': None, + 'mask': None + } + return data_dict \ No newline at end of file diff --git a/bitmind/utils/data.py b/base_miner/datasets/util.py similarity index 88% rename from bitmind/utils/data.py rename to base_miner/datasets/util.py index 058faba8..cc461d29 100644 --- a/bitmind/utils/data.py +++ b/base_miner/datasets/util.py @@ -6,8 +6,7 @@ import datasets from bitmind.download_data import load_huggingface_dataset -from bitmind.real_fake_dataset import RealFakeDataset -from bitmind.image_dataset import 
ImageDataset +from base_miner.datasets import ImageDataset, VideoDataset, RealFakeDataset datasets.logging.set_verbosity_error() datasets.disable_progress_bar() @@ -17,8 +16,11 @@ def split_dataset(dataset): # Split data into train, validation, test and return the three splits dataset = dataset.shuffle(seed=42) + if 'train' in dataset: + dataset = dataset['train'] + split_dataset = {} - train_test_split = dataset['train'].train_test_split(test_size=0.2, seed=42) + train_test_split = dataset.train_test_split(test_size=0.2, seed=42) split_dataset['train'] = train_test_split['train'] temp_dataset = train_test_split['test'] @@ -30,7 +32,11 @@ def split_dataset(dataset): return split_dataset['train'], split_dataset['validation'], split_dataset['test'] -def load_and_split_datasets(dataset_meta: list) -> Dict[str, List[ImageDataset]]: +def load_and_split_datasets( + dataset_meta: list, + modality: str, + split_transforms: Dict[str, transforms.Compose] = {}, +) -> Dict[str, List[ImageDataset]]: """ Helper function to load and split dataset into train, validation, and test sets. @@ -56,7 +62,12 @@ def load_and_split_datasets(dataset_meta: list) -> Dict[str, List[ImageDataset]] train_ds, val_ds, test_ds = split_dataset(dataset) for split, data in zip(splits, [train_ds, val_ds, test_ds]): - image_dataset = ImageDataset(huggingface_dataset=data) + if modality == 'image': + image_dataset = ImageDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None)) + elif modality == 'video': + image_dataset = VideoDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None)) + else: + raise NotImplementedError(f'Unsupported modality: {modality}') datasets[split].append(image_dataset) split_lengths = ', '.join([f"{split} len={len(datasets[split][0])}" for split in splits]) @@ -105,9 +116,6 @@ def create_source_label_mapping( def create_real_fake_datasets( real_datasets: Dict[str, List[ImageDataset]], fake_datasets: Dict[str, List[ImageDataset]], - train_transforms: transforms.Compose = None, - val_transforms: transforms.Compose = None, - test_transforms: transforms.Compose = None, source_labels: bool = False, group_sources_by_name: bool = False) -> Tuple[RealFakeDataset, ...]: """ @@ -131,19 +139,16 @@ def create_real_fake_datasets( train_dataset = RealFakeDataset( real_image_datasets=real_datasets['train'], fake_image_datasets=fake_datasets['train'], - transforms=train_transforms, source_label_mapping=source_label_mapping) val_dataset = RealFakeDataset( real_image_datasets=real_datasets['validation'], fake_image_datasets=fake_datasets['validation'], - transforms=val_transforms, source_label_mapping=source_label_mapping) test_dataset = RealFakeDataset( real_image_datasets=real_datasets['test'], fake_image_datasets=fake_datasets['test'], - transforms=test_transforms, source_label_mapping=source_label_mapping) if source_labels: diff --git a/base_miner/datasets/video_dataset.py b/base_miner/datasets/video_dataset.py new file mode 100644 index 00000000..814e2bbc --- /dev/null +++ b/base_miner/datasets/video_dataset.py @@ -0,0 +1,116 @@ +""" +Author: Zhiyuan Yan +Email: zhiyuanyan@link.cuhk.edu.cn +Date: 2023-03-30 +Description: Abstract Base Class for all types of deepfake datasets. 
+""" + +import os +import cv2 +from PIL import Image +import sys +import yaml +import numpy as np +from copy import deepcopy +import random +import torch +from torch import nn +from torch.utils import data +from torchvision.utils import save_image +from torchvision.transforms import Compose +from einops import rearrange +from typing import List, Tuple, Optional +from datasets import Dataset + +from .base_dataset import BaseDataset + + +class VideoDataset(BaseDataset): + def __init__( + self, + huggingface_dataset_path: Optional[str] = None, + huggingface_dataset_split: str = 'train', + huggingface_dataset_name: Optional[str] = None, + huggingface_dataset: Optional[Dataset] = None, + download_mode: Optional[str] = None, + max_frames_per_video: Optional[int] = 4, + transforms: Optional[Compose] = None + ): + """Initialize the VideoDataset. + + Args: + huggingface_dataset_path (str, optional): Path to the Hugging Face dataset. + Can be a publicly hosted dataset (/) or + local directory (imagefolder:) + huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'. + huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset. + huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance. + download_mode (str, optional): Download mode for the dataset. + Can be None or "force_redownload" + max_frames_per_video (int, optional): Maximum number of frames to load per video. Defaults to 4. + """ + super().__init__( + huggingface_dataset_path=huggingface_dataset_path, + huggingface_dataset_split=huggingface_dataset_split, + huggingface_dataset_name=huggingface_dataset_name, + huggingface_dataset=huggingface_dataset, + download_mode=download_mode, + transforms=transforms, + ) + self.max_frames = max_frames_per_video + + def __getitem__(self, index): + """Return the data point at the given index. + + Args: + index (int): The index of the data point. + + Returns: + dict: Dictionary containing 'image' (a [frame_dim, C, H, W] tensor of + stacked frames), 'id' (video id), and 'source' (dataset path). + """ + image_paths = self.dataset[index]['frames'] + + if not isinstance(image_paths, list): + image_paths = [image_paths] + + images = [] + for image_path in image_paths[:self.max_frames]: + try: + img = Image.open(image_path) + images.append(img) + except Exception as e: + print(f"Error loading image at index {index}: {e}") + return self.__getitem__(0) + + if self.transforms is not None: + images = self.transforms(images) + + # Stack images along the time dimension (frame_dim) + image_tensors = torch.stack(images, dim=0) # Shape: [frame_dim, C, H, W] + + # Mask out one randomly placed mask_grid_size x mask_grid_size patch + # across all frames (TALL-style mask augmentation; matches the + # mask_grid_size: 16 setting in tall.yaml) + frames, channels, height, width = image_tensors.shape + x = torch.randint(0, width, (1,)).item() + y = torch.randint(0, height, (1,)).item() + mask_grid_size = 16 + x1 = max(x - mask_grid_size // 2, 0) + x2 = min(x + mask_grid_size // 2, width) + y1 = max(y - mask_grid_size // 2, 0) + y2 = min(y + mask_grid_size // 2, height) + image_tensors[:, :, y1:y2, x1:x2] = -1 + + return { + 'image': image_tensors, # Shape: [frame_dim, C, H, W] + 'id': self.dataset[index]['video_id'], + 'source': self.huggingface_dataset_path + } + + + def __len__(self) -> int: + """ + Get the length of the dataset. + + Returns: + int: Length of the dataset. 
+ """ + return len(self.dataset['video_id']) \ No newline at end of file diff --git a/base_miner/deepfake_detectors/__init__.py b/base_miner/deepfake_detectors/__init__.py index 08e0636f..f82bc189 100644 --- a/base_miner/deepfake_detectors/__init__.py +++ b/base_miner/deepfake_detectors/__init__.py @@ -1,4 +1,5 @@ from .deepfake_detector import DeepfakeDetector from .npr_detector import NPRDetector from .ucf_detector import UCFDetector -from .camo_detector import CAMODetector \ No newline at end of file +from .camo_detector import CAMODetector +from .tall_detector import TALLDetector \ No newline at end of file diff --git a/base_miner/deepfake_detectors/configs/tall.yaml b/base_miner/deepfake_detectors/configs/tall.yaml new file mode 100644 index 00000000..e009e480 --- /dev/null +++ b/base_miner/deepfake_detectors/configs/tall.yaml @@ -0,0 +1,3 @@ +hf_repo: 'bitmind/tall' # Hugging Face repository for downloading model files +train_config: 'tall.yaml' # pre-trained configuration file in HuggingFace +weights: 'tall_trainFF_testCDF.pth' # TALL model checkpoint in HuggingFace \ No newline at end of file diff --git a/base_miner/deepfake_detectors/deepfake_detector.py b/base_miner/deepfake_detectors/deepfake_detector.py index 4c4af4ca..626a0737 100644 --- a/base_miner/deepfake_detectors/deepfake_detector.py +++ b/base_miner/deepfake_detectors/deepfake_detector.py @@ -1,85 +1,152 @@ -import typing from abc import ABC, abstractmethod from pathlib import Path -import yaml +from typing import Optional, Dict, Any + import torch +import yaml +import bittensor as bt from PIL import Image +from huggingface_hub import hf_hub_download + +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR class DeepfakeDetector(ABC): - """ - Abstract base class for detecting deepfake images via binary classification. + """Abstract base class for detecting deepfake images via binary classification. This class is intended to be subclassed by detector implementations - using different underying model architectures, routing via gates, or + using different underlying model architectures, routing via gates, or configurations. - + Attributes: model_name (str): Name of the detector instance. - config (str): Name of the YAML file in deepfake_detectors/config/ to load - instance attributes from. + config (Optional[str]): Name of the YAML file in deepfake_detectors/config/ + to load instance attributes from. device (str): The type of device ('cpu' or 'cuda'). + hf_repo (str): Hugging Face repository name for model weights. + train_config (str): Name of training configuration file. """ - - def __init__(self, model_name: str, config = None, device: str = 'cpu'): + + def __init__( + self, + model_name: str, + config: Optional[str] = None, + device: str = 'cpu' + ) -> None: + """Initialize the DeepfakeDetector. + + Args: + model_name: Name of the detector instance. + config: Optional name of configuration file to load. + device: Device to run the model on ('cpu' or 'cuda'). + """ self.model_name = model_name - self.device = torch.device(device if device == 'cuda' and torch.cuda.is_available() else 'cpu') + self.device = torch.device( + device if device == 'cuda' and torch.cuda.is_available() else 'cpu' + ) + + if config: - self.load_and_apply_config(config) + self.set_class_attrs(config) + self.load_model_config() + self.load_model() @abstractmethod - def load_model(self): - """ - Load the model. Specific loading implementations will be defined in subclasses. 
+ def load_model(self) -> None: + """Load the model weights and architecture. + + This method should be implemented by subclasses to define their specific + model loading logic. """ pass - def preprocess(self, image: Image) -> torch.Tensor: - """ - Preprocess the image for model inference. - + def preprocess(self, image: Image.Image) -> torch.Tensor: + """Preprocess the image for model inference. + Args: - image (PIL.Image): The image to preprocess. - extra_data (dict, optional): Any additional data required for preprocessing. + image: The input image to preprocess. Returns: - torch.Tensor: The preprocessed image tensor. + The preprocessed image as a tensor ready for model input. """ # General preprocessing, to be overridden if necessary in subclasses pass @abstractmethod - def __call__(self, image: Image) -> float: - """ - Perform inference with the model. + def __call__(self, image: Image.Image) -> float: + """Perform inference with the model. Args: - image (PIL.Image): The preprocessed image. + image: The preprocessed input image. Returns: - float: The model's prediction (or other relevant result). + The model's prediction score (typically between 0 and 1). """ + pass + + def set_class_attrs(self, detector_config: str) -> None: + """Load detector configuration from YAML file and set attributes. - def load_and_apply_config(self, detector_config): - """ - Load detector configuration from YAML file and set corresponding attributes dynamically. - Args: - config_path (str): Path to the YAML configuration file. + detector_config: Path to the YAML configuration file or filename + in the configs directory. + + Raises: + Exception: If there is an error loading or parsing the config file. """ if Path(detector_config).exists(): detector_config_file = Path(detector_config) else: - detector_config_file = Path(__file__).resolve().parent / Path('configs/' + detector_config) + detector_config_file = ( + Path(__file__).resolve().parent / Path('configs/' + detector_config) + ) + try: - with open(detector_config_file, 'r') as file: + with open(detector_config_file, 'r', encoding='utf-8') as file: config_dict = yaml.safe_load(file) # Set class attributes dynamically from the config dictionary for key, value in config_dict.items(): - setattr(self, key, value) # Dynamically create self.key = value - + setattr(self, key, value) + except Exception as e: print(f"Error loading detector configurations from {detector_config_file}: {e}") - raise \ No newline at end of file + raise + + def ensure_weights_are_available( + self, + weights_dir: str, + weights_filename: str + ) -> None: + """Ensure model weights are downloaded and available locally. + + Downloads weights from Hugging Face Hub if not found locally. + + Args: + weights_dir: Directory to store/find the weights. + weights_filename: Name of the weights file. 
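+
+        Note:
+            Assumes `self.hf_repo` has been set from the detector YAML config
+            (via set_class_attrs).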
+ """ + destination_path = Path(weights_dir) / Path(weights_filename) + if not Path(weights_dir).exists(): + Path(weights_dir).mkdir(parents=True, exist_ok=True) + + if not destination_path.exists(): + print(f"Downloading {weights_filename} from {self.hf_repo} " + f"to {weights_dir}") + hf_hub_download(self.hf_repo, weights_filename, local_dir=weights_dir) + + def load_model_config(self): + try: + destination_path = Path(CONFIGS_DIR) / Path(self.train_config) + if not destination_path.exists(): + local_config_path = hf_hub_download(self.hf_repo, self.train_config, local_dir=CONFIGS_DIR) + print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}") + with Path(local_config_path).open('r') as f: + self.config = yaml.safe_load(f) + else: + print(f"Loaded local config from {destination_path}") + with destination_path.open('r') as f: + self.config = yaml.safe_load(f) + except Exception as e: + # some models such as NPR don't have an additional config file + bt.logging.warning("No additional train config loaded.") \ No newline at end of file diff --git a/base_miner/deepfake_detectors/npr_detector.py b/base_miner/deepfake_detectors/npr_detector.py index 30b1782d..2f83b6e5 100644 --- a/base_miner/deepfake_detectors/npr_detector.py +++ b/base_miner/deepfake_detectors/npr_detector.py @@ -4,7 +4,7 @@ from pathlib import Path from huggingface_hub import hf_hub_download from base_miner.NPR.networks.resnet import resnet50 -from bitmind.image_transforms import base_transforms +from bitmind.utils.image_transforms import get_base_transforms from base_miner.deepfake_detectors import DeepfakeDetector from base_miner import DETECTOR_REGISTRY from base_miner.NPR.config.constants import WEIGHTS_DIR @@ -25,24 +25,16 @@ class NPRDetector(DeepfakeDetector): def __init__(self, model_name: str = 'NPR', config: str = 'npr.yaml', device: str = 'cpu'): super().__init__(model_name, config, device) + self.transforms = get_base_transforms() def load_model(self): """ Load the ResNet50 model with the specified weights for deepfake detection. """ - self.ensure_weights_are_available(self.weights) + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) self.model = resnet50(num_classes=1) self.model.load_state_dict(torch.load(Path(WEIGHTS_DIR) / self.weights, map_location=self.device)) self.model.eval() - - def ensure_weights_are_available(self, weight_filename): - destination_path = Path(WEIGHTS_DIR) / Path(weight_filename) - if not destination_path.parent.exists(): - destination_path.parent.mkdir(parents=True, exist_ok=True) - if not destination_path.exists(): - model_path = hf_hub_download(self.hf_repo, weight_filename) - model = torch.load(model_path, map_location=torch.device(self.device)) - torch.save(model, destination_path) def preprocess(self, image: Image) -> torch.Tensor: """ @@ -54,7 +46,7 @@ def preprocess(self, image: Image) -> torch.Tensor: Returns: torch.Tensor: The preprocessed image tensor. 
""" - image_tensor = base_transforms(image).unsqueeze(0).float() + image_tensor = self.transforms(image).unsqueeze(0).float() return image_tensor def __call__(self, image: Image) -> float: diff --git a/base_miner/deepfake_detectors/tall_detector.py b/base_miner/deepfake_detectors/tall_detector.py new file mode 100644 index 00000000..cc8eb750 --- /dev/null +++ b/base_miner/deepfake_detectors/tall_detector.py @@ -0,0 +1,51 @@ +import torch +from pathlib import Path + +import bittensor as bt +from base_miner import DETECTOR_REGISTRY +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR +from base_miner.DFB.detectors import DETECTOR, TALLDetector +from base_miner.deepfake_detectors import DeepfakeDetector +from bitmind.utils.video_utils import pad_frames + + +@DETECTOR_REGISTRY.register_module(module_name="TALL") +class TALLVideoDetector(DeepfakeDetector): + def __init__( + self, + model_name: str = "TALL", + config: str = "tall.yaml", + device: str = "cpu", + ): + super().__init__(model_name, config, device) + + total_params = sum(p.numel() for p in self.tall.model.parameters()) + trainable_params = sum( + p.numel() for p in self.tall.model.parameters() if p.requires_grad + ) + bt.logging.info('device:', self.device) + bt.logging.info(total_params, "parameters") + bt.logging.info(trainable_params, "trainable parameters") + + def load_model(self): + # download weights from hf if not available locally + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) + bt.logging.info(f"Loaded config from training run: {self.config}") + self.tall = TALLDetector(self.config, self.device) + + # load weights + checkpoint_path = Path(WEIGHTS_DIR) / self.weights + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.tall.load_state_dict(checkpoint, strict=True) + self.tall.model.eval() + + def preprocess(self, frames_tensor): + """ Prepare input data dict for TALLDetector """ + frames_tensor = pad_frames(frames_tensor, 4) + return {'image': frames_tensor} + + def __call__(self, frames_tensor): + input_data = self.preprocess(frames_tensor) + with torch.no_grad(): + output_data = self.tall.forward(input_data, inference=True) + return output_data['prob'][0] diff --git a/base_miner/deepfake_detectors/ucf_detector.py b/base_miner/deepfake_detectors/ucf_detector.py index ce30ed4d..de3b87c6 100644 --- a/base_miner/deepfake_detectors/ucf_detector.py +++ b/base_miner/deepfake_detectors/ucf_detector.py @@ -4,28 +4,25 @@ import random import warnings warnings.filterwarnings("ignore", category=FutureWarning) +from huggingface_hub import hf_hub_download from pathlib import Path - +from PIL import Image +import torchvision.transforms as transforms +import torch.backends.cudnn as cudnn +import bittensor as bt import numpy as np import torch -import torch.backends.cudnn as cudnn -import torchvision.transforms as transforms import yaml -from PIL import Image -from huggingface_hub import hf_hub_download import gc -from base_miner.UCF.config.constants import CONFIGS_DIR, WEIGHTS_DIR -from base_miner.gating_mechanisms import FaceGate - -from base_miner.UCF.detectors import DETECTOR +from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR from base_miner.deepfake_detectors import DeepfakeDetector -from base_miner import DETECTOR_REGISTRY, GATE_REGISTRY +from base_miner.DFB.detectors import UCFDetector +from base_miner import DETECTOR_REGISTRY -import bittensor as bt @DETECTOR_REGISTRY.register_module(module_name='UCF') -class UCFDetector(DeepfakeDetector): +class 
UCFImageDetector(DeepfakeDetector): """ DeepfakeDetector subclass that initializes a pretrained UCF model for binary classification of fake and real images. @@ -39,33 +36,6 @@ class UCFDetector(DeepfakeDetector): def __init__(self, model_name: str = 'UCF', config: str = 'ucf.yaml', device: str = 'cpu'): super().__init__(model_name, config, device) - - def ensure_weights_are_available(self, weight_filename): - destination_path = Path(WEIGHTS_DIR) / Path(weight_filename) - if not destination_path.parent.exists(): - destination_path.parent.mkdir(parents=True, exist_ok=True) - if not destination_path.exists(): - model_path = hf_hub_download(self.hf_repo, weight_filename) - model = torch.load(model_path, map_location=self.device) - torch.save(model, destination_path) - - def load_train_config(self): - destination_path = Path(CONFIGS_DIR) / Path(self.train_config) - - if not destination_path.exists(): - local_config_path = hf_hub_download(self.hf_repo, self.train_config) - print(f"Downloaded {self.hf_repo}/{self.train_config} to {local_config_path}") - config_dict = {} - with open(local_config_path, 'r') as f: - config_dict = yaml.safe_load(f) - with open(destination_path, 'w') as f: - yaml.dump(config_dict, f, default_flow_style=False) - with destination_path.open('r') as f: - return yaml.safe_load(f) - else: - print(f"Loaded local config from {destination_path}") - with destination_path.open('r') as f: - return yaml.safe_load(f) def init_cudnn(self): if self.train_config.get('cudnn'): @@ -79,14 +49,13 @@ def init_seed(self): torch.cuda.manual_seed_all(seed_value) def load_model(self): - self.train_config = self.load_train_config() self.init_cudnn() self.init_seed() - self.ensure_weights_are_available(self.weights) - self.ensure_weights_are_available(self.train_config['pretrained'].split('/')[-1]) - model_class = DETECTOR[self.train_config['model_name']] + self.ensure_weights_are_available(WEIGHTS_DIR, self.weights) + #self.ensure_weights_are_available(WEIGHTS_DIR, self.train_config['pretrained'].split('/')[-1]) + #model_class = DETECTOR[self.train_config['model_name']] bt.logging.info(f"Loaded config from training run: {self.train_config}") - self.model = model_class(self.train_config).to(self.device) + self.model = UCFDetector(self.train_config).to(self.device) self.model.eval() weights_path = Path(WEIGHTS_DIR) / self.weights checkpoint = torch.load(weights_path, map_location=self.device) diff --git a/base_miner/gating_mechanisms/face_gate.py b/base_miner/gating_mechanisms/face_gate.py index b8854e8e..ce557b96 100644 --- a/base_miner/gating_mechanisms/face_gate.py +++ b/base_miner/gating_mechanisms/face_gate.py @@ -4,7 +4,7 @@ import dlib from base_miner.gating_mechanisms import Gate -from base_miner.UCF.config.constants import DLIB_FACE_PREDICTOR_PATH +from base_miner.DFB.config.constants import DLIB_FACE_PREDICTOR_PATH from base_miner import GATE_REGISTRY from base_miner.gating_mechanisms.utils import get_face_landmarks, align_and_crop_face diff --git a/bitmind/__init__.py b/bitmind/__init__.py index 68285ed0..8b0f87ca 100644 --- a/bitmind/__init__.py +++ b/bitmind/__init__.py @@ -18,7 +18,7 @@ # DEALINGS IN THE SOFTWARE. 
-__version__ = "1.2.9" +__version__ = "2.0.0" version_split = __version__.split(".") __spec_version__ = ( (1000 * int(version_split[0])) diff --git a/bitmind/base/miner.py b/bitmind/base/miner.py index e812eac4..e63d70a1 100644 --- a/bitmind/base/miner.py +++ b/bitmind/base/miner.py @@ -20,6 +20,7 @@ import threading import argparse import traceback +import typing import bittensor as bt @@ -53,17 +54,9 @@ def __init__(self, config=None): bt.logging.warning( "You are allowing non-registered entities to send requests to your miner. This is a security risk." ) - # The axon handles request processing, allowing validators to send this miner requests. - self.axon = bt.axon(wallet=self.wallet, config=self.config() if callable(self.config) else self.config) - # Attach determiners which functions are called when servicing a request. - bt.logging.info(f"Attaching forward function to miner axon.") - self.axon.attach( - forward_fn=self.forward, - blacklist_fn=self.blacklist, - priority_fn=self.priority, - ) - bt.logging.info(f"Axon created: {self.axon}") + # attach miner-specific functions in subclass __init__ + self.axon = bt.axon(wallet=self.wallet, config=self.config() if callable(self.config) else self.config) # Instantiate runners self.should_exit: bool = False @@ -192,3 +185,101 @@ def resync_metagraph(self): # Sync the metagraph. self.metagraph.sync(subtensor=self.subtensor) + + async def blacklist( + self, synapse: bt.Synapse + ) -> typing.Tuple[bool, str]: + """ + Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should + define the logic for blacklisting requests based on your needs and desired security parameters. + + Blacklist runs before the synapse data has been deserialized (i.e. before synapse.data is available). + The synapse is instead constructed via the headers of the request. It is important to blacklist + requests before they are deserialized to avoid wasting resources on requests that will be ignored. + + Args: + synapse (bt.Synapse): A synapse object constructed from the headers of the incoming request. + + Returns: + Tuple[bool, str]: A tuple containing a boolean indicating whether the synapse's hotkey is blacklisted, + and a string providing the reason for the decision. + + This function is a security measure to prevent resource wastage on undesired requests. It should be enhanced + to include checks against the metagraph for entity registration, validator status, and sufficient stake + before deserialization of synapse data to minimize processing overhead. + + Example blacklist logic: + - Reject if the hotkey is not a registered entity within the metagraph. + - Consider blacklisting entities that are not validators or have insufficient stake. + + In practice it would be wise to blacklist requests from entities that are not validators, or do not have + enough stake. This can be checked via metagraph.S and metagraph.validator_permit. You can always attain + the uid of the sender via a metagraph.hotkeys.index( synapse.dendrite.hotkey ) call. + + Otherwise, allow the request to be processed further. + """ + if synapse.dendrite is None or synapse.dendrite.hotkey is None: + bt.logging.warning("Received a request without a dendrite or hotkey.") + return True, "Missing dendrite or hotkey" + + # TODO(developer): Define how miners should blacklist requests.
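+        # Check registration before calling metagraph.hotkeys.index(), which
+        # raises ValueError for hotkeys that are not in the metagraph.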
+ if ( + not self.config.blacklist.allow_non_registered + and synapse.dendrite.hotkey not in self.metagraph.hotkeys + ): + # Ignore requests from un-registered entities. + bt.logging.trace( + f"Blacklisting un-registered hotkey {synapse.dendrite.hotkey}" + ) + return True, "Unrecognized hotkey" + + uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey) + + if self.config.blacklist.force_validator_permit: + # If the config is set to force validator permit, then we should only allow requests from validators. + if not self.metagraph.validator_permit[uid]: + bt.logging.warning( + f"Blacklisting a request from non-validator hotkey {synapse.dendrite.hotkey}" + ) + return True, "Non-validator hotkey" + + bt.logging.trace( + f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}" + ) + return False, "Hotkey recognized!" + + async def priority(self, synapse: bt.Synapse) -> float: + """ + The priority function determines the order in which requests are handled. More valuable or higher-priority + requests are processed before others. You should design your own priority mechanism with care. + + This implementation assigns priority to incoming requests based on the calling entity's stake in the metagraph. + + Args: + synapse (bt.Synapse): The synapse object that contains metadata about the incoming request. + + Returns: + float: A priority score derived from the stake of the calling entity. + + Miners may receive messages from multiple entities at once. This function determines which request should be + processed first. Higher values indicate that the request should be processed first. Lower values indicate + that the request should be processed later. + + Example priority logic: + - A higher stake results in a higher priority value. + """ + if synapse.dendrite is None or synapse.dendrite.hotkey is None: + bt.logging.warning("Received a request without a dendrite or hotkey.") + return 0.0 + + # TODO(developer): Define how miners should prioritize requests. + caller_uid = self.metagraph.hotkeys.index( + synapse.dendrite.hotkey + ) # Get the caller index. + + priority = float( + self.metagraph.S[caller_uid] + ) # Return the stake as the priority. + bt.logging.trace( + f"Prioritizing {synapse.dendrite.hotkey} with value: ", priority + ) + return priority diff --git a/bitmind/base/neuron.py b/bitmind/base/neuron.py index 15c0374f..247afa16 100644 --- a/bitmind/base/neuron.py +++ b/bitmind/base/neuron.py @@ -108,10 +108,6 @@ def __init__(self, config=None): ) self.step = 0 - @abstractmethod - async def forward(self, synapse: bt.Synapse) -> bt.Synapse: - ... - @abstractmethod def run(self): ... diff --git a/bitmind/base/validator.py b/bitmind/base/validator.py index 47b055b7..9db6a576 100644 --- a/bitmind/base/validator.py +++ b/bitmind/base/validator.py @@ -171,7 +171,6 @@ def run(self): self.sync() self.step += 1 - time.sleep(60) # If someone intentionally stops the validator, it'll safely terminate operations. 
except KeyboardInterrupt: @@ -382,7 +381,12 @@ def save_miner_history(self): def load_miner_history(self): if os.path.exists(self.history_cache_path): bt.logging.info(f"Loading miner performance history from {self.history_cache_path}") - self.performance_tracker = joblib.load(self.history_cache_path) + try: + self.performance_tracker = joblib.load(self.history_cache_path) + except Exception as e: + bt.logging.error(f'Error loading miner performance tracker: {e}') + self.performance_tracker = MinerPerformanceTracker() + pred_history = self.performance_tracker.prediction_history num_miners_history = len([ uid for uid in pred_history diff --git a/bitmind/constants.py b/bitmind/constants.py deleted file mode 100644 index d2dccdd4..00000000 --- a/bitmind/constants.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import torch - - -WANDB_PROJECT = 'bitmind-subnet' -WANDB_ENTITY = 'bitmindai' - -DATASET_META = { - "real": [ - {"path": "bitmind/bm-real"}, - {"path": "bitmind/open-images-v7"}, - {"path": "bitmind/celeb-a-hq"}, - {"path": "bitmind/ffhq-256"}, - {"path": "bitmind/MS-COCO-unique-256"} - ], - "fake": [ - {"path": "bitmind/bm-realvisxl"}, - {"path": "bitmind/bm-mobius"}, - {"path": "bitmind/bm-sdxl"} - ] -} - -FACE_TRAINING_DATASET_META = { - "real": [ - {"path": "bitmind/ffhq-256_training_faces", "name": "base_transforms"}, - {"path": "bitmind/celeb-a-hq_training_faces", "name": "base_transforms"} - - ], - "fake": [ - {"path": "bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces", "name": "base_transforms"}, - {"path": "bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces", "name": "base_transforms"} - ] -} - -VALIDATOR_DATASET_META = { - "real": [ - {"path": "bitmind/bm-real"}, - {"path": "bitmind/open-images-v7"}, - {"path": "bitmind/celeb-a-hq"}, - {"path": "bitmind/ffhq-256"}, - {"path": "bitmind/MS-COCO-unique-256"}, - {"path": "bitmind/AFHQ"}, - {"path": "bitmind/lfw"}, - {"path": "bitmind/caltech-256"}, - {"path": "bitmind/caltech-101"}, - {"path": "bitmind/dtd"} - ] -} - -VALIDATOR_MODEL_META = { - "diffusers": [ - { - "path": "stabilityai/stable-diffusion-xl-base-1.0", - "use_safetensors": True, - "torch_dtype": torch.float16, - "variant": "fp16", - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": "SG161222/RealVisXL_V4.0", - "use_safetensors": True, - "torch_dtype": torch.float16, - "variant": "fp16", - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": "Corcelio/mobius", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionXLPipeline" - }, - { - "path": 'black-forest-labs/FLUX.1-dev', - "use_safetensors": True, - "torch_dtype": torch.bfloat16, - "generate_args": { - "guidance_scale": 2, - "num_inference_steps": {"min": 50, "max": 125}, - "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"), - "height": [512, 768], - "width": [512, 768] - }, - "enable_cpu_offload": False, - "pipeline": "FluxPipeline" - }, - { - "path": "prompthero/openjourney-v4", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionPipeline" - }, - { - "path": "cagliostrolab/animagine-xl-3.1", - "use_safetensors": True, - "torch_dtype": torch.float16, - "pipeline": "StableDiffusionXLPipeline" - } - ] -} - -HUGGINGFACE_CACHE_DIR = os.path.expanduser('~/.cache/huggingface') - -TARGET_IMAGE_SIZE = (256, 256) - -PROMPT_TYPES = ('annotation', 'none') - -# args for .from_pretrained -DIFFUSER_ARGS = { - m['path']: { - k: v for k, v in m.items() - if k not in ('path', 
'pipeline', 'generate_args', 'enable_cpu_offload') - } for m in VALIDATOR_MODEL_META['diffusers'] -} - -GENERATE_ARGS = { - m['path']: m['generate_args'] - for m in VALIDATOR_MODEL_META['diffusers'] - if 'generate_args' in m -} - -DIFFUSER_CPU_OFFLOAD_ENABLED = { - m['path']: m.get('enable_cpu_offload', False) - for m in VALIDATOR_MODEL_META['diffusers'] -} - -DIFFUSER_PIPELINE = { - m['path']: m['pipeline'] for m in VALIDATOR_MODEL_META['diffusers'] if 'pipeline' in m -} - -DIFFUSER_NAMES = list(DIFFUSER_ARGS.keys()) - -IMAGE_ANNOTATION_MODEL = "Salesforce/blip2-opt-6.7b-coco" - -TEXT_MODERATION_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" diff --git a/bitmind/download_data.py b/bitmind/download_data.py index fcb68559..44b8eef7 100644 --- a/bitmind/download_data.py +++ b/bitmind/download_data.py @@ -11,35 +11,52 @@ import glob import requests -from bitmind.constants import DATASET_META, HUGGINGFACE_CACHE_DIR +from base_miner.constants import IMAGE_DATASETS, HUGGINGFACE_CACHE_DIR datasets.logging.set_verbosity_warning() datasets.disable_progress_bar() +from datasets import load_dataset, load_from_disk +from pathlib import Path +from typing import Optional, Union +import os + def load_huggingface_dataset( path: str, split: str = 'train', name: Optional[str] = None, download_mode: str = 'reuse_cache_if_exists', + local_data_path: Optional[str] = None ) -> datasets.Dataset: """ Load a dataset from Hugging Face or a local directory. - Args: - path (str): Path to the dataset or 'imagefolder:' for image folder. Can either be to a publicly - hosted huggingface datset with the format / or a local directory with the format - imagefolder: - split (str, optional): Name of the dataset split to load (default: None). - Make sure to check what splits are available for the datasets you're working with. - name (str, optional): Name of the dataset (if loading from Hugging Face, default: None). - Some huggingface datasets provide various subets of different sizes, which can be accessed via thi - parameter. - download_mode (str, optional): Download mode for the dataset (if loading from Hugging Face, default: None). - can be None or "force_redownload" + path (str): Path to dataset. Can be: + - A Hugging Face dataset path (/) + - An image folder path (imagefolder:) + - A local path to a saved dataset (for load_from_disk) + split (str, optional): Dataset split to load (default: 'train') + name (str, optional): Dataset configuration name (default: None) + download_mode (str, optional): Download mode for Hugging Face datasets + local_data_path (str, optional): Path for storing downloaded media files Returns: - Union[dict, load_dataset.Dataset]: The loaded dataset or a specific split of the dataset as requested. 
+ Dataset: The loaded dataset or requested split """ + # Check if it's a local path suitable for load_from_disk + if not path.startswith('imagefolder:') and os.path.exists(path): + try: + # Look for dataset artifacts that indicate this is a saved dataset + dataset_files = {'dataset_info.json', 'state.json', 'data'} + path_contents = set(os.listdir(path)) + if dataset_files.intersection(path_contents): + dataset = load_from_disk(path) + if split is None: + return dataset + return dataset[split] + except Exception as e: + print(f"Attempted load_from_disk but failed: {e}") + if 'imagefolder' in path: _, directory = path.split(':') if name: @@ -51,14 +68,13 @@ def load_huggingface_dataset( dataset_path=path, dataset_name=name, download_mode=download_mode, - cache_dir=HUGGINGFACE_CACHE_DIR) - + cache_dir=HUGGINGFACE_CACHE_DIR, + local_data_path=local_data_path) + if split is None: return dataset - return dataset[split] - def download_image(url: str) -> Image.Image: """ Download an image from a URL. @@ -80,36 +96,12 @@ def download_image(url: str) -> Image.Image: return None -def clear_cache(cache_dir): - """Clears lock files and incomplete downloads from the cache directory.""" - # Find lock and incomplete files - lock_files = glob.glob(cache_dir + "/*lock") - incomplete_files = glob.glob(cache_dir + "/downloads/**/*.incomplete", recursive=True) - try: - if lock_files: - subprocess.run(["rm", *lock_files], check=True) - if incomplete_files: - for file in incomplete_files: - os.remove(file) - print("Hugging Face cache lock files cleared successfully.") - except Exception as e: - print(f"Failed to clear Hugging Face cache lock files: {e}") - - -def fix_permissions(path): - """Attempts to fix permission issues on a given path.""" - try: - subprocess.run(["chmod", "-R", "775", path], check=True) - print(f"Fixed permissions for {path}") - except subprocess.CalledProcessError as e: - print(f"Failed to fix permissions for {path}: {e}") - - def download_dataset( dataset_path: str, dataset_name: str, download_mode: str, cache_dir: str, + local_data_path: str = None, max_wait: int = 300 ): """ Downloads the datasets present in datasets.json with exponential backoff @@ -121,7 +113,7 @@ def download_dataset( print(f"Downloading {dataset_path} (subset={dataset_name}) dataset...") while True: try: - if dataset_name: + if dataset_name is not None: dataset = load_dataset(dataset_path, name=dataset_name, #config/subset name cache_dir=cache_dir, @@ -132,6 +124,11 @@ def download_dataset( cache_dir=cache_dir, download_mode=download_mode, trust_remote_code=True) + + #if local_data_path is not None: + # print(f"Downloading media for {dataset_path} to {local_data_path}") + # download_media(dataset_path, local_data_path) + break except Exception as e: print(e) @@ -141,7 +138,7 @@ def download_dataset( file_path = str(e).split(": '")[1].rstrip("'") print(f"Permission error at {file_path}, attempting to fix...") fix_permissions(file_path) # Attempt to fix permissions directly - clear_cache(cache_dir) # Clear cache to remove any incomplete or locked files + clean_cache(cache_dir) # Clear cache to remove any incomplete or locked files else: print(f"Unexpected error, stopping retries for {dataset_path}") raise e @@ -158,9 +155,35 @@ def download_dataset( return dataset +def clean_cache(cache_dir): + """Clears lock files and incomplete downloads from the cache directory.""" + # Find lock and incomplete files + lock_files = glob.glob(cache_dir + "/*lock") + incomplete_files = glob.glob(cache_dir + 
"/downloads/**/*.incomplete", recursive=True) + try: + if lock_files: + subprocess.run(["rm", *lock_files], check=True) + if incomplete_files: + for file in incomplete_files: + os.remove(file) + print("Hugging Face cache lock files cleared successfully.") + except Exception as e: + print(f"Failed to clear Hugging Face cache lock files: {e}") + + +def fix_permissions(path): + """Attempts to fix permission issues on a given path.""" + try: + subprocess.run(["chmod", "-R", "775", path], check=True) + print(f"Fixed permissions for {path}") + except subprocess.CalledProcessError as e: + print(f"Failed to fix permissions for {path}: {e}") + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Download Hugging Face datasets for validator challenge generation and miner training.') parser.add_argument('--force_redownload', action='store_true', help='force redownload of datasets') + parser.add_argument('--modality', default='image', choices=['video', 'image'], help='download image or video datasets') parser.add_argument('--cache_dir', type=str, default=HUGGINGFACE_CACHE_DIR, help='huggingface cache directory') args = parser.parse_args() @@ -169,8 +192,18 @@ def download_dataset( download_mode = "force_redownload" os.makedirs(args.cache_dir, exist_ok=True) - clear_cache(args.cache_dir) # Clear the cache of lock and incomplete files. - - for dataset_type in DATASET_META: - for dataset in DATASET_META[dataset_type]: - download_dataset(dataset['path'], dataset.get('name', None), download_mode, args.cache_dir) + clean_cache(args.cache_dir) # Clear the cache of lock and incomplete files. + + if args.modality == 'image': + dataset_meta = IMAGE_DATASETS + #elif args.modality == 'video': + # dataset_meta = VIDEO_DATASET_META + + for dataset_type in dataset_meta: + for dataset in dataset_meta[dataset_type]: + download_dataset( + dataset_path=dataset['path'], + dataset_name=dataset.get('name', None), + download_mode=download_mode, + local_data_path=dataset.get('local_data_path', None), + cache_dir=args.cache_dir) diff --git a/bitmind/image_dataset.py b/bitmind/image_dataset.py deleted file mode 100644 index b37c894c..00000000 --- a/bitmind/image_dataset.py +++ /dev/null @@ -1,159 +0,0 @@ -from typing import List, Tuple -from datasets import Dataset -from PIL import Image -from io import BytesIO -import bittensor as bt -import numpy as np - -from bitmind.download_data import load_huggingface_dataset, download_image - - -class ImageDataset: - - def __init__( - self, - huggingface_dataset_path: str = None, - huggingface_dataset_split: str = 'train', - huggingface_dataset_name: str = None, - huggingface_dataset: Dataset = None, - download_mode: str = None - ): - """ - Args: - huggingface_dataset_path (str): Path to the Hugging Face dataset. Can either be to a publicly hosted - huggingface dataset (/) or a local directory (imagefolder:) - huggingface_dataset_split (str): Split of the dataset to load (default: 'train'). - Make sure to check what splits are available for the datasets you're working with. - huggingface_dataset_name (str): Name of the Hugging Face dataset (default: None). - Some huggingface datasets provide various subets of different sizes, which can be accessed via thi - parameter. - create_splits (bool): Whether to create dataset splits (default: False). - If the huggingface dataset hasn't been pre-split (i.e., it only contains "Train"), we split it here - randomly. - download_mode (str): Download mode for the dataset (default: None). 
- can be None or "force_redownload" - """ - assert huggingface_dataset_path is not None or huggingface_dataset is not None, \ - "Either huggingface_dataset_path or huggingface_dataset must be provided." - - if huggingface_dataset: - self.dataset = huggingface_dataset - self.huggingface_dataset_path = self.dataset.info.dataset_name - self.huggingface_dataset_split = list(self.dataset.info.splits.keys())[0] - self.huggingface_dataset_name = self.dataset.info.config_name - - else: - self.huggingface_dataset_path = huggingface_dataset_path - self.huggingface_dataset_name = huggingface_dataset_name - self.dataset = load_huggingface_dataset( - huggingface_dataset_path, - huggingface_dataset_split, - huggingface_dataset_name, - download_mode) - self.sampled_images_idx = [] - - def __getitem__(self, index: int) -> dict: - """ - Get an item (image and ID) from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: Dictionary containing 'image' (PIL image) and 'id' (str). - """ - return self._get_image(index) - - def __len__(self) -> int: - """ - Get the length of the dataset. - - Returns: - int: Length of the dataset. - """ - return len(self.dataset) - - def _get_image(self, index: int) -> dict: - """ - Load an image from self.dataset. Expects self.dataset[i] to be a dictionary containing either 'image' or 'url' - as a key. - - The value associated with the 'image' key should be either a PIL image or a b64 string encoding of - the image. - - The value associated with the 'url' key should be a url that hosts the image (as in - dalle-mini/open-images) - - Args: - index (int): Index of the image in the dataset. - - Returns: - dict: Dictionary containing 'image' (PIL image) and 'id' (str). - """ - sample = self.dataset[int(index)] - if 'url' in sample: - image = download_image(sample['url']) - image_id = sample['url'] - elif 'image_url' in sample: - image = download_image(sample['image_url']) - image_id = sample['image_url'] - elif 'image' in sample: - if isinstance(sample['image'], Image.Image): - image = sample['image'] - elif isinstance(sample['image'], bytes): - image = Image.open(BytesIO(sample['image'])) - else: - raise NotImplementedError - - image_id = '' - if 'name' in sample: - image_id = sample['name'] - elif 'filename' in sample: - image_id = sample['filename'] - - image_id = image_id if image_id != '' else index - - else: - raise NotImplementedError - - # remove alpha channel if download didnt 404 - if image is not None: - image = image.convert('RGB') - - return { - 'image': image, - 'id': image_id, - 'source': self.huggingface_dataset_path - } - - def sample(self, k: int = 1) -> Tuple[List[dict], List[int]]: - """ - Randomly sample k images from self.dataset. Includes retries for failed downloads, in the case that - self.dataset contains urls. - - Args: - k (int): Number of images to sample (default: 1). - - Returns: - Tuple[List[dict], List[int]]: A tuple containing a list of sampled images and their indices. 
- """ - sampled_images = [] - sampled_idx = [] - while k > 0: - attempts = len(self.dataset) // 2 - for i in range(attempts): - image_idx = np.random.randint(0, len(self.dataset)) - if image_idx not in self.sampled_images_idx: - break - if i >= attempts: - self.sampled_images_idx = [] - try: - image = self._get_image(image_idx) - if image['image'] is not None: - sampled_images.append(image) - sampled_idx.append(image_idx) - self.sampled_images_idx.append(image_idx) - k -= 1 - except Exception as e: - bt.logging.error(e) - continue - - return sampled_images, sampled_idx diff --git a/bitmind/miner/predict.py b/bitmind/miner/predict.py deleted file mode 100644 index 8beba11e..00000000 --- a/bitmind/miner/predict.py +++ /dev/null @@ -1,21 +0,0 @@ -from PIL import Image -import torch - -from bitmind.image_transforms import base_transforms - - -def predict(model: torch.nn.Module, image: Image.Image) -> float: - """ - Perform prediction using a given PyTorch model on an image. You may need to modify this - if you train a custom model. - - Args: - model (torch.nn.Module): The PyTorch model to use for prediction. - image (Image.Image): The input image as a PIL Image. - - Returns: - float: The predicted output value. - """ - image = base_transforms(image).unsqueeze(0).float() - out = model(image).sigmoid().flatten().tolist() - return out[0] \ No newline at end of file diff --git a/bitmind/protocol.py b/bitmind/protocol.py index 996e09dc..95377a6f 100644 --- a/bitmind/protocol.py +++ b/bitmind/protocol.py @@ -17,33 +17,20 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -from pydantic import root_validator, validator +from typing import List +from pydantic import BaseModel, Field from torchvision import transforms from io import BytesIO from PIL import Image import bittensor as bt -import pydantic import base64 +import pydantic import torch +import zlib - -def prepare_image_synapse(image: Image): - """ - Prepares an image for use with ImageSynapse object. - - Args: - image (Image): The input image to be prepared. - - Returns: - ImageSynapse: An instance of ImageSynapse containing the encoded image and a default prediction value. - """ - if isinstance(image, torch.Tensor): - image = transforms.ToPILImage()(image.cpu().detach()) - - image_bytes = BytesIO() - image.save(image_bytes, format="JPEG") - b64_encoded_image = base64.b64encode(image_bytes.getvalue()) - return ImageSynapse(image=b64_encoded_image) +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import get_base_transforms +base_transforms = get_base_transforms(TARGET_IMAGE_SIZE) # ---- miner ---- @@ -61,6 +48,36 @@ def prepare_image_synapse(image: Image): # predictions = dendrite.query( ImageSynapse( images = b64_images ) ) # assert len(predictions) == len(b64_images) +def prepare_synapse(input_data, modality): + if isinstance(input_data, torch.Tensor): + input_data = transforms.ToPILImage()(input_data.cpu().detach()) + if isinstance(input_data, list) and isinstance(input_data[0], torch.Tensor): + for i, img in enumerate(input_data): + input_data[i] = transforms.ToPILImage()(img.cpu().detach()) + + if modality == 'image': + return prepare_image_synapse(input_data) + elif modality == 'video': + return prepare_video_synapse(input_data) + else: + raise NotImplementedError(f"Unsupported modality: {modality}") + + +def prepare_image_synapse(image: Image): + """ + Prepares an image for use with ImageSynapse object. 
+ + Args: + image (Image): The input image to be prepared. + + Returns: + ImageSynapse: An instance of ImageSynapse containing the encoded image and a default prediction value. + """ + image_bytes = BytesIO() + image.save(image_bytes, format="JPEG") + b64_encoded_image = base64.b64encode(image_bytes.getvalue()) + return ImageSynapse(image=b64_encoded_image) + class ImageSynapse(bt.Synapse): """ @@ -73,6 +90,8 @@ class ImageSynapse(bt.Synapse): >.5 is considered generated/modified, <= 0.5 is considered real. """ + testnet_label: int = -1 # for easier miner eval on testnet + # Required request input, filled by sending dendrite caller. image: str = pydantic.Field( title="Image", @@ -99,3 +118,115 @@ def deserialize(self) -> float: prediction probabilities """ return self.prediction + + +def prepare_video_synapse(frames: List[Image.Image]): + """ + Prepare a list of PIL frames for transport in a VideoSynapse: each frame is + JPEG-encoded, the bytes are concatenated, zlib-compressed, and base85-encoded. + """ + frame_bytes = [] + for frame in frames: + buffer = BytesIO() + frame.save(buffer, format="JPEG") + frame_bytes.append(buffer.getvalue()) + + combined_bytes = b''.join(frame_bytes) + compressed_data = zlib.compress(combined_bytes) + encoded_data = base64.b85encode(compressed_data).decode('utf-8') + return VideoSynapse(video=encoded_data) + +class VideoSynapse(bt.Synapse): + """ + Naive initial VideoSynapse. + A better option would be to modify the Dendrite interface to allow multipart/form-data here: + https://github.com/opentensor/bittensor/blob/master/bittensor/core/dendrite.py#L533 + Another, higher-lift option would be to look into Epistula or Fiber. + """ + + testnet_label: int = -1 # for easier miner eval on testnet + + # Required request input, filled by sending dendrite caller. + video: str = pydantic.Field( + title="Video", + description="A wildly inefficient means of sending video data", + default="", + frozen=False + ) + + # Optional request output, filled by receiving axon. + prediction: float = pydantic.Field( + title="Prediction", + description="Probability that the video is AI generated/modified", + default=-1., + frozen=False + ) + + def deserialize(self) -> float: + """ + Deserialize the output. This method retrieves the response from + the miner, deserializes it and returns it as the output of the dendrite.query() call. + + Returns: + - float: The deserialized miner prediction + """ + return self.prediction + + +def decode_video_synapse(synapse: VideoSynapse) -> torch.Tensor: + """ + V1 of a function for decoding a VideoSynapse object back into a batched tensor of frames.
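+ Frames are recovered by scanning the decompressed byte stream for JPEG + start (FF D8) and end (FF D9) markers before being transformed and stacked.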
+ + Args: + synapse: VideoSynapse object containing the encoded video data + + Returns: + A torch tensor of transformed frames, stacked with a leading batch dimension + """ + compressed_data = base64.b85decode(synapse.video.encode('utf-8')) + combined_bytes = zlib.decompress(compressed_data) + + # Split the combined bytes into individual JPEG files + # Look for JPEG markers: FF D8 (start) and FF D9 (end) + frames = [] + current_pos = 0 + data_length = len(combined_bytes) + + while current_pos < data_length: + # Find start of JPEG (FF D8) + while current_pos < data_length - 1: + if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD8: + break + current_pos += 1 + + if current_pos >= data_length - 1: + break + + start_pos = current_pos + + # Find end of JPEG (FF D9) + while current_pos < data_length - 1: + if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD9: + current_pos += 2 + break + current_pos += 1 + + if current_pos > start_pos: + # Extract the JPEG data + jpeg_data = combined_bytes[start_pos:current_pos] + try: + # Convert to PIL Image + img = Image.open(BytesIO(jpeg_data)) + # Keep as a PIL Image; tensor conversion happens in base_transforms below + frames.append(img) + except Exception as e: + bt.logging.error(f"Error processing frame: {e}") + continue + + frames = frames[:32] # temp + bt.logging.info('transforming video inputs') + frames = base_transforms(frames) + + frames = torch.stack(frames, dim=0) + frames = frames.unsqueeze(0) + bt.logging.info(f'decoded video into tensor with shape {frames.shape}') + return frames diff --git a/bitmind/synthetic_image_generation/README.md b/bitmind/synthetic_data_generation/README.md similarity index 100% rename from bitmind/synthetic_image_generation/README.md rename to bitmind/synthetic_data_generation/README.md diff --git a/bitmind/synthetic_data_generation/__init__.py b/bitmind/synthetic_data_generation/__init__.py new file mode 100644 index 00000000..5c7fbce0 --- /dev/null +++ b/bitmind/synthetic_data_generation/__init__.py @@ -0,0 +1 @@ +from .synthetic_data_generator import SyntheticDataGenerator diff --git a/bitmind/synthetic_data_generation/image_annotation_generator.py b/bitmind/synthetic_data_generation/image_annotation_generator.py new file mode 100644 index 00000000..d2cbd11c --- /dev/null +++ b/bitmind/synthetic_data_generation/image_annotation_generator.py @@ -0,0 +1,244 @@ +import gc +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + Blip2ForConditionalGeneration, + Blip2Processor, + pipeline, + logging as transformers_logging, +) +from transformers.utils.logging import disable_progress_bar + +import bittensor as bt +from bitmind.validator.config import HUGGINGFACE_CACHE_DIR + +disable_progress_bar() + + +class ImageAnnotationGenerator: + """ + A class for generating and moderating image annotations using transformer models. + + This class provides functionality to generate descriptive captions for images + using BLIP2 models and optionally moderate the generated text using a separate + language model. + """ + + def __init__( + self, + model_name: str, + text_moderation_model_name: str, + device: str = 'cuda', + apply_moderation: bool = True + ) -> None: + """ + Initialize the ImageAnnotationGenerator with specific models and device settings. + + Args: + model_name: The name of the BLIP model for generating image captions. + text_moderation_model_name: The name of the model used for moderating + text descriptions.
+ device: The device to use. + apply_moderation: Flag to determine whether text moderation should be + applied to captions. + """ + self.model_name = model_name + self.processor = Blip2Processor.from_pretrained( + self.model_name, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + + self.apply_moderation = apply_moderation + self.text_moderation_model_name = text_moderation_model_name + self.text_moderation_pipeline = None + self.model = None + self.device = device + + def is_model_loaded(self) -> bool: + return self.model is not None + + def load_models(self) -> None: + """ + Load the necessary models for image annotation and text moderation onto + the specified device. + """ + if self.is_model_loaded(): + bt.logging.warning( + f"Image annotation model {self.model_name} is already loaded" + ) + return + + bt.logging.info(f"Loading image annotation model {self.model_name}") + self.model = Blip2ForConditionalGeneration.from_pretrained( + self.model_name, + torch_dtype=torch.float16, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + self.model.to(self.device) + bt.logging.info(f"Loaded image annotation model {self.model_name}") + if self.apply_moderation: + bt.logging.info( + f"Loading annotation moderation model {self.text_moderation_model_name}..." + ) + model = AutoModelForCausalLM.from_pretrained( + self.text_moderation_model_name, + torch_dtype=torch.bfloat16, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + + tokenizer = AutoTokenizer.from_pretrained( + self.text_moderation_model_name, + cache_dir=HUGGINGFACE_CACHE_DIR + ) + model = model.to(self.device) + self.text_moderation_pipeline = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer + ) + bt.logging.info( + f"Loaded annotation moderation model {self.text_moderation_model_name}." + ) + + def clear_gpu(self) -> None: + """ + Clear GPU memory by moving models back to CPU and deleting them, + followed by collecting garbage. + """ + bt.logging.info("Clearing GPU memory after generating image annotation") + if self.model is not None: + self.model.to('cpu') + del self.model + self.model = None + if self.text_moderation_pipeline: + self.text_moderation_pipeline.model.to('cpu') + del self.text_moderation_pipeline + self.text_moderation_pipeline = None + gc.collect() + torch.cuda.empty_cache() + + def moderate(self, description: str, max_new_tokens: int = 80) -> str: + """ + Use the text moderation pipeline to make the description more concise + and neutral. + + Args: + description: The text description to be moderated. + max_new_tokens: Maximum number of new tokens to generate in the + moderated text. + + Returns: + The moderated description text, or the original description if + moderation fails. + """ + messages = [ + { + "role": "system", + "content": ( + "[INST]You always concisely rephrase given descriptions, " + "eliminate redundancy, and remove all specific references to " + "individuals by name. 
You do not respond with anything other " + "than the revised description.[/INST]" + ) + }, + { + "role": "user", + "content": description + } + ] + try: + moderated_text = self.text_moderation_pipeline( + messages, + max_new_tokens=max_new_tokens, + pad_token_id=self.text_moderation_pipeline.tokenizer.eos_token_id, + return_full_text=False + ) + + if isinstance(moderated_text, list): + return moderated_text[0]['generated_text'] + + bt.logging.error("Moderated text did not return a list.") + return description + + except Exception as e: + bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True) + return description + + def generate( + self, + image: Image.Image, + max_new_tokens: int = 20, + verbose: bool = False + ) -> str: + """ + Generate a string description for a given image using prompt-based + captioning and building conversational context. + + Args: + image: The image for which the description is to be generated. + max_new_tokens: The maximum number of tokens to generate for each + prompt. + verbose: If True, additional logging information is printed. + + Returns: + A generated description of the image. + """ + if not verbose: + transformers_logging.set_verbosity_error() + + description = "" + prompts = [ + "An image of", + "The setting is", + "The background is", + "The image type/style is" + ] + + for i, prompt in enumerate(prompts): + description += prompt + ' ' + inputs = self.processor( + image, + text=description, + return_tensors="pt" + ).to(self.device, torch.float16) + + generated_ids = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens + ) + answer = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True + )[0].strip() + + if verbose: + bt.logging.info(f"{i}. Prompt: {prompt}") + bt.logging.info(f"{i}. Answer: {answer}") + + if answer: + answer = answer.rstrip(" ,;!?") + if not answer.endswith('.'): + answer += '.' + description += answer + ' ' + else: + description = description[:-len(prompt) - 1] + + if not verbose: + transformers_logging.set_verbosity_info() + + if description.startswith(prompts[0]): + description = description[len(prompts[0]):] + + description = description.strip() + if not description.endswith('.'): + description += '.' + + if self.apply_moderation: + moderated_description = self.moderate(description) + return moderated_description + + return description diff --git a/bitmind/synthetic_image_generation/utils/image_utils.py b/bitmind/synthetic_data_generation/image_utils.py similarity index 97% rename from bitmind/synthetic_image_generation/utils/image_utils.py rename to bitmind/synthetic_data_generation/image_utils.py index a01627b9..5c419537 100644 --- a/bitmind/synthetic_image_generation/utils/image_utils.py +++ b/bitmind/synthetic_data_generation/image_utils.py @@ -1,7 +1,8 @@ import PIL import os import json -from bitmind.constants import TARGET_IMAGE_SIZE +from bitmind.validator.config import TARGET_IMAGE_SIZE + def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL.Image.Image: """Resize the image to fit within specified dimensions while maintaining aspect ratio.""" @@ -20,6 +21,7 @@ def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS) return resized_image + def resize_images_in_directory(directory, target_width=TARGET_IMAGE_SIZE[0], target_height=TARGET_IMAGE_SIZE[1]): """ Resize all images in the specified directory to the target width and height. 
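For reference, a minimal usage sketch of the new ImageAnnotationGenerator (not part of the diff). It assumes a CUDA-capable host with enough memory for both models; the model names are taken from IMAGE_ANNOTATION_MODEL and TEXT_MODERATION_MODEL in the bitmind/constants.py deleted above, so check bitmind/validator/config.py for the values currently in use, and the image path is hypothetical.

from PIL import Image
from bitmind.synthetic_data_generation.image_annotation_generator import ImageAnnotationGenerator

# Model names mirror the deleted bitmind/constants.py; the live values are in
# bitmind/validator/config.py.
annotator = ImageAnnotationGenerator(
    model_name="Salesforce/blip2-opt-6.7b-coco",
    text_moderation_model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    device="cuda",
    apply_moderation=True,
)
annotator.load_models()                 # loads BLIP2 + the moderation LLM onto the GPU
image = Image.open("sample.jpg")        # hypothetical local image path
caption = annotator.generate(image, verbose=True)
annotator.clear_gpu()                   # release both models from VRAM
print(caption)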
diff --git a/bitmind/synthetic_data_generation/prompt_utils.py b/bitmind/synthetic_data_generation/prompt_utils.py new file mode 100644 index 00000000..7c5ce81e --- /dev/null +++ b/bitmind/synthetic_data_generation/prompt_utils.py @@ -0,0 +1,39 @@ + + +def get_tokenizer_with_min_len(model): + """ + Returns the tokenizer with the smallest maximum token length from the given `model` object. + + If a second tokenizer exists, it compares both and returns the one with the smaller + maximum token length. Otherwise, it returns the available tokenizer. + + Args: + model: A pipeline object exposing `tokenizer` and, optionally, `tokenizer_2`. + + Returns: + tuple: A tuple containing the tokenizer and its maximum token length. + """ + # Check if a second tokenizer is available on the model + if hasattr(model, 'tokenizer_2'): + if model.tokenizer.model_max_length > model.tokenizer_2.model_max_length: + return model.tokenizer_2, model.tokenizer_2.model_max_length + return model.tokenizer, model.tokenizer.model_max_length + + +def truncate_prompt_if_too_long(prompt: str, model): + """ + Truncates the input string if it exceeds the maximum token length when tokenized. + + Args: + prompt (str): The text prompt that may need to be truncated. + model: The model whose tokenizer determines the token limit. + + Returns: + str: The original prompt if within the token limit; otherwise, a truncated version of the prompt. + """ + tokenizer, max_token_len = get_tokenizer_with_min_len(model) + tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings + if len(tokens['input_ids']) < max_token_len: + return prompt + + # Truncate tokens if they exceed the maximum token length, decode the tokens back to a string + truncated_prompt = tokenizer.decode(token_ids=tokens['input_ids'][:max_token_len-1], + skip_special_tokens=True) + return truncated_prompt \ No newline at end of file diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py new file mode 100644 index 00000000..780b8400 --- /dev/null +++ b/bitmind/synthetic_data_generation/synthetic_data_generator.py @@ -0,0 +1,367 @@ +import gc +import json +import os +import time +import warnings +from pathlib import Path +from typing import Dict, Optional, Any, Union + +import bittensor as bt +import numpy as np +import torch +from diffusers.utils import export_to_video +from PIL import Image + +from bitmind.validator.config import ( + HUGGINGFACE_CACHE_DIR, + TEXT_MODERATION_MODEL, + IMAGE_ANNOTATION_MODEL, + T2VIS_MODELS, + T2VIS_MODEL_NAMES, + T2V_MODEL_NAMES, + TARGET_IMAGE_SIZE, + select_random_t2vis_model, + get_modality +) +from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long +from bitmind.synthetic_data_generation.image_annotation_generator import ImageAnnotationGenerator +from bitmind.validator.cache import ImageCache + + +future_warning_modules_to_ignore = [ + 'diffusers', + 'transformers.tokenization_utils_base' +] + +for module in future_warning_modules_to_ignore: + warnings.filterwarnings("ignore", category=FutureWarning, module=module) + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' + + +class SyntheticDataGenerator: + """ + A class for generating synthetic images and videos based on text prompts. + + This class supports different prompt generation strategies and can utilize + various text-to-video (t2v) and text-to-image (t2i) models.
+ + Attributes: + use_random_t2vis_model: Whether to randomly select a t2v or t2i for each + generation task. + prompt_type: The type of prompt generation strategy (currently only + 'annotation' is supported). + t2vis_model_name: Name of the t2v or t2i model. + image_annotation_generator: The generator object for annotating images if required. + output_dir: Directory to write generated data. + """ + + def __init__( + self, + t2vis_model_name: Optional[str] = None, + use_random_t2vis_model: bool = True, + prompt_type: str = 'annotation', + output_dir: Optional[Union[str, Path]] = None, + image_cache: Optional[ImageCache] = None, + device: str = 'cuda' + ) -> None: + """ + Initialize the SyntheticDataGenerator. + + Args: + t2vis_model_name: Name of the text-to-video or text-to-image model. + use_random_t2vis_model: Whether to randomly select models for generation. + prompt_type: The type of prompt generation strategy. + output_dir: Directory to write generated data. + image_cache: Optional image cache instance. + device: Device identifier. + + Raises: + ValueError: If an invalid model name is provided. + NotImplementedError: If an unsupported prompt type is specified. + """ + if not use_random_t2vis_model and t2vis_model_name not in T2VIS_MODEL_NAMES: + raise ValueError( + f"Invalid model name '{t2vis_model_name}'. " + f"Options are {T2VIS_MODEL_NAMES}" + ) + + self.use_random_t2vis_model = use_random_t2vis_model + self.t2vis_model_name = t2vis_model_name + self.t2vis_model = None + self.device = device + + if self.use_random_t2vis_model and t2vis_model_name is not None: + bt.logging.warning( + "t2vis_model_name will be ignored (use_random_t2vis_model=True)" + ) + self.t2vis_model_name = None + + self.prompt_type = prompt_type + if self.prompt_type == 'annotation': + self.image_annotation_generator = ImageAnnotationGenerator( + model_name=IMAGE_ANNOTATION_MODEL, + text_moderation_model_name=TEXT_MODERATION_MODEL + ) + else: + raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}") + + self.output_dir = Path(output_dir) if output_dir else None + if self.output_dir: + (self.output_dir / "video").mkdir(parents=True, exist_ok=True) + (self.output_dir / "image").mkdir(parents=True, exist_ok=True) + + self.image_cache = image_cache + + def batch_generate(self, batch_size: int = 5) -> None: + """ + Generate synthetic data in batches. + + Args: + batch_size: Number of prompts to generate in each batch.
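+ + Each prompt is produced by captioning a randomly sampled cached image; + every model in T2VIS_MODEL_NAMES then generates one output per prompt and + the results are written to self.output_dir.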
+ """ + prompts = [] + bt.logging.info(f"Generating {batch_size} prompts") + for i in range(batch_size): + image_sample = self.image_cache.sample() + bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}") + prompts.append(self.generate_prompt(image=image_sample['image'], clear_gpu=i==batch_size-1)) + bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}") + + for model_name in T2VIS_MODEL_NAMES: + modality = get_modality(model_name) + for i, prompt in enumerate(prompts): + bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}") + + # Generate image/video from current model and prompt + output = self.run_t2vis(prompt, modality, t2vis_model_name=model_name) + + base_path = self.output_dir / modality / str(output['time']) + metadata = {k: v for k, v in output.items() if k != 'gen_output'} + base_path.with_suffix('.json').write_text(json.dumps(metadata)) + + if modality == 'image': + output['gen_output'].images[0].save(base_path.with_suffix('.png')) + elif modality == 'video': + export_to_video( + output['gen_output'].frames[0], + str(base_path.with_suffix('.mp4')), + fps=30 + ) + + def generate( + self, + image: Optional[Image.Image] = None, + modality: str = 'image', + t2vis_model_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Generate synthetic data based on input parameters. + + Args: + image: Input image for annotation-based generation. + modality: Type of media to generate ('image' or 'video'). + + Returns: + Dictionary containing generated data information. + + Raises: + ValueError: If real_image is None when using annotation prompt type. + NotImplementedError: If prompt type is not supported. + """ + prompt = self.generate_prompt(image, clear_gpu=True) + bt.logging.info("Generating synthetic data...") + gen_data = self.run_t2vis(prompt, modality, t2vis_model_name) + self.clear_gpu() + return gen_data + + def generate_prompt( + self, + image: Optional[Image.Image] = None, + clear_gpu: bool = True + ) -> str: + """Generate a prompt based on the specified strategy.""" + bt.logging.info("Generating prompt") + if self.prompt_type == 'annotation': + if image is None: + raise ValueError( + "image can't be None if self.prompt_type is 'annotation'" + ) + self.image_annotation_generator.load_models() + prompt = self.image_annotation_generator.generate(image) + if clear_gpu: + self.image_annotation_generator.clear_gpu() + else: + raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}") + return prompt + + def run_t2vis( + self, + prompt: str, + modality: str, + t2vis_model_name: Optional[str] = None, + generate_at_target_size: bool = False, + + ) -> Dict[str, Any]: + """ + Generate synthetic data based on a text prompt. + + Args: + prompt: The text prompt used to inspire the generation. + generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions. + t2vis_model_name: Optional model name to use for generation. + + Returns: + Dictionary containing generated data and metadata. + + Raises: + RuntimeError: If generation fails. 
+ """ + self.load_t2vis_model(t2vis_model_name) + + bt.logging.info("Preparing generation arguments") + gen_args = T2VIS_MODELS[self.t2vis_model_name].get( + 'generate_args', {}).copy() + + # Process generation arguments + for k, v in gen_args.items(): + if isinstance(v, dict): + gen_args[k] = np.random.randint( + gen_args[k]['min'], + gen_args[k]['max'] + ) + for dim in ('height', 'width'): + if isinstance(gen_args.get(dim), list): + gen_args[dim] = np.random.choice(gen_args[dim]) + + try: + if generate_at_target_size: + gen_args['height'] = TARGET_IMAGE_SIZE[0] + gen_args['width'] = TARGET_IMAGE_SIZE[1] + + truncated_prompt = truncate_prompt_if_too_long( + prompt, + self.t2vis_model + ) + + torch_dtype = T2VIS_MODELS[self.t2vis_model_name].get( + 'from_pretrained_args', {}).get('torch_dtype', torch.bfloat16) + + bt.logging.info("Generating media from prompt") + start_time = time.time() + with torch.autocast(self.device, torch_dtype, cache_enabled=False): + gen_output = self.t2vis_model( + prompt=truncated_prompt, + **gen_args + ) + gen_time = time.time() - start_time + + except Exception as e: + if generate_at_target_size: + bt.logging.error( + f"Attempt with custom dimensions failed, falling back to " + f"default dimensions. Error: {e}" + ) + try: + gen_output = self.t2vis_model(prompt=truncated_prompt) + gen_time = time.time() - start_time + except Exception as fallback_error: + bt.logging.error( + f"Failed to generate image with default dimensions after " + f"initial failure: {fallback_error}" + ) + raise RuntimeError( + f"Both attempts to generate image failed: {fallback_error}" + ) + else: + bt.logging.error(f"Image generation error: {e}") + raise RuntimeError(f"Failed to generate image: {e}") + + return { + 'prompt': truncated_prompt, + 'prompt_long': prompt, + 'gen_output': gen_output, # image or video + 'time': time.time(), + 'model_name': self.t2vis_model_name, + 'gen_time': gen_time + } + + def load_t2vis_model(self, model_name: Optional[str] = None, modality: Optional[str] = None) -> None: + """Load a Hugging Face text-to-image or text-to-video model to a specific GPU.""" + if model_name is not None: + self.t2vis_model_name = model_name + elif self.use_random_t2vis_model or model_name == 'random': + model_name = select_random_t2vis_model(modality) + self.t2vis_model_name = model_name + + bt.logging.info(f"Loading {self.t2vis_model_name}") + + pipeline_cls = T2VIS_MODELS[model_name]['pipeline_cls'] + pipeline_args = T2VIS_MODELS[model_name]['from_pretrained_args'] + + self.t2vis_model = pipeline_cls.from_pretrained( + pipeline_args.get('base', model_name), + cache_dir=HUGGINGFACE_CACHE_DIR, + **pipeline_args, + add_watermarker=False + ) + + self.t2vis_model.set_progress_bar_config(disable=True) + + # Load scheduler if specified + if 'scheduler' in T2VIS_MODELS[model_name]: + sched_cls = T2VIS_MODELS[model_name]['scheduler']['cls'] + sched_args = T2VIS_MODELS[model_name]['scheduler']['from_config_args'] + self.t2vis_model.scheduler = sched_cls.from_config( + self.t2vis_model.scheduler.config, + **sched_args + ) + + # Configure model optimizations + model_config = T2VIS_MODELS[model_name] + if model_config.get('enable_model_cpu_offload', False): + bt.logging.info(f"Enabling cpu offload for {model_name}") + self.t2vis_model.enable_model_cpu_offload() + if model_config.get('enable_sequential_cpu_offload', False): + bt.logging.info(f"Enabling sequential cpu offload for {model_name}") + self.t2vis_model.enable_sequential_cpu_offload() + if model_config.get('vae_enable_slicing', 
False): + bt.logging.info(f"Enabling vae slicing for {model_name}") + try: + self.t2vis_model.vae.enable_slicing() + except Exception: + try: + self.t2vis_model.enable_vae_slicing() + except Exception: + bt.logging.warning(f"Could not enable vae slicing for {model_name}") + if model_config.get('vae_enable_tiling', False): + bt.logging.info(f"Enabling vae tiling for {model_name}") + try: + self.t2vis_model.vae.enable_tiling() + except Exception: + try: + self.t2vis_model.enable_vae_tiling() + except Exception: + bt.logging.warning(f"Could not enable vae tiling for {model_name}") + + if not ( + model_config.get('enable_model_cpu_offload', False) + or model_config.get('enable_sequential_cpu_offload', False) + ): + # cpu-offloaded pipelines manage their own device placement; calling + # .to(device) on them can error or undo the offload hooks + self.t2vis_model.to(self.device) + bt.logging.info(f"Loaded {model_name} using {pipeline_cls.__name__}.") + + def clear_gpu(self) -> None: + """Clear GPU memory by deleting models and running garbage collection.""" + if self.t2vis_model is not None: + bt.logging.info( + "Deleting previous text-to-image or text-to-video model, " + "freeing memory" + ) + del self.t2vis_model + self.t2vis_model = None + gc.collect() + torch.cuda.empty_cache() + diff --git a/bitmind/synthetic_image_generation/image_annotation_generator.py b/bitmind/synthetic_image_generation/image_annotation_generator.py deleted file mode 100644 index ced54e43..00000000 --- a/bitmind/synthetic_image_generation/image_annotation_generator.py +++ /dev/null @@ -1,344 +0,0 @@ -# Transformer models -from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, pipeline -import torch - -# Logging and progress handling -from transformers import logging as transformers_logging -from transformers.utils.logging import disable_progress_bar - -from typing import Any, Dict, List, Tuple -import bittensor as bt -import PIL -import time -import torch -import gc - -from bitmind.image_dataset import ImageDataset -from bitmind.synthetic_image_generation.utils import image_utils -from bitmind.constants import HUGGINGFACE_CACHE_DIR - -disable_progress_bar() - - -class ImageAnnotationGenerator: - """ - A class responsible for generating text annotations for images using a transformer-based image captioning model. - It integrates text moderation to ensure the descriptions are concise and neutral. - - Attributes: - device (torch.device): The device (CPU or GPU) on which the models are loaded. - model_name (str): The name of the BLIP model for generating image captions. - processor (Blip2Processor): The processor associated with the BLIP model. - model (Blip2ForConditionalGeneration): The BLIP model used for generating image captions. - apply_moderation (bool): Flag to determine whether text moderation should be applied to captions. - text_moderation_model_name (str): The name of the model used for moderating text descriptions. - text_moderation_pipeline (pipeline): A Hugging Face pipeline for text moderation. - - Methods: - __init__(self, model_name: str, text_moderation_model_name: str, device: str = cuda, apply_moderation: bool = True): - Initializes the ImageAnnotationGenerator with the specified model, device, and moderation settings. - - load_models(self): - Loads the image annotation and text moderation models into memory. - - clear_gpu(self): - Clears GPU memory to ensure that no residual data remains that could affect further operations. - - moderate_description(self, description: str, max_new_tokens: int = 80) -> str: - Moderates the given description to make it more concise and neutral, using the text moderation model. 
- - generate_description(self, image: PIL.Image.Image, verbose: bool = False, max_new_tokens: int = 20) -> str: - Generates a description for the provided image using the image captioning model. - - generate_annotation(self, image_id, dataset_name: str, image: PIL.Image.Image, original_dimensions: tuple, resize: bool, verbose: int) -> dict: - Generates a text annotation for a given image, including handling image resizing and verbose logging. - - process_image(self, image_info: dict, dataset_name: str, image_index: int, resize: bool, verbose: int) -> Tuple[Any, float]: - Processes a single image from a dataset to generate its annotation and measures the time taken. - - generate_annotations(self, real_image_datasets: List[ImageDataset], verbose: int = 0, max_images: int = None, resize_images: bool = False) -> Dict[str, Dict[str, Any]]: - Generates text annotations for a batch of images from the specified datasets and calculates the average processing latency. - """ - def __init__( - self, model_name: str, text_moderation_model_name: str, device: str = "cuda", - apply_moderation: bool = True - ): - """ - Initializes the ImageAnnotationGenerator with specific models and device settings. - - Args: - model_name (str): The name of the BLIP model for generating image captions. - text_moderation_model_name (str): The name of the model used for moderating text descriptions. - device (str): Device to use for model inference. Defaults to "cuda". - apply_moderation (bool): Flag to determine whether text moderation should be applied to captions. - """ - self.device = device - self.model_name = model_name - self.processor = Blip2Processor.from_pretrained( - self.model_name, cache_dir=HUGGINGFACE_CACHE_DIR - ) - self.model = None - - self.apply_moderation = apply_moderation - self.text_moderation_model_name = text_moderation_model_name - self.text_moderation_pipeline = None - - def load_models(self): - """ - Loads the necessary models for image annotation and text moderation onto the specified device. - """ - self.model = Blip2ForConditionalGeneration.from_pretrained( - self.model_name, - torch_dtype=torch.float16, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - self.model.to(self.device) - if self.apply_moderation: - model = AutoModelForCausalLM.from_pretrained( - self.text_moderation_model_name, - torch_dtype=torch.bfloat16, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - - tokenizer = AutoTokenizer.from_pretrained( - self.text_moderation_model_name, - cache_dir=HUGGINGFACE_CACHE_DIR - ) - model = model.to(self.device) - self.text_moderation_pipeline = pipeline( - "text-generation", - model=model, - tokenizer=tokenizer - ) - - def clear_gpu(self): - """ - Clears GPU memory by moving models back to CPU and deleting them, followed by collecting garbage. - """ - self.model.to('cpu') - del self.model - self.model = None - if self.text_moderation_pipeline: - self.text_moderation_pipeline.model.to('cpu') - del self.text_moderation_pipeline - self.text_moderation_pipeline = None - gc.collect() - torch.cuda.empty_cache() - - def moderate_description(self, description: str, max_new_tokens: int = 80) -> str: - """ - Uses the text moderation pipeline to make the description more concise and neutral. - """ - messages = [ - { - "role": "system", - "content": ("[INST]You always concisely rephrase given descriptions, eliminate redundancy, " - "and remove all specific references to individuals by name. 
You do not respond with" - "anything other than the revised description.[/INST]") - }, - { - "role": "user", - "content": description - } - ] - try: - moderated_text = self.text_moderation_pipeline(messages, max_new_tokens=max_new_tokens, - pad_token_id=self.text_moderation_pipeline.tokenizer.eos_token_id, - return_full_text=False) - - if isinstance(moderated_text, list): - return moderated_text[0]['generated_text'] - bt.logging.error("Failed to return moderated text.") - else: - bt.logging.error("Moderated text did not return a list.") - - return description # Fallback to the original description if no suitable entry is found - except Exception as e: - bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True) - return description # Return the original description as a fallback - - def generate_description(self, - image: PIL.Image.Image, - verbose: bool = False, - max_new_tokens: int = 20) -> str: - """ - Generates a string description for a given image by interfacing with a transformer - model using prompt-based captioning and building conversational context. - - Args: - image (PIL.Image.Image): The image for which the description is to be generated. - verbose (bool, optional): If True, additional logging information is printed. Defaults to False. - max_new_tokens (int, optional): The maximum number of tokens to generate for each prompt. Defaults to 20. - - Returns: - str: A generated description of the image. - """ - if not verbose: - transformers_logging.set_verbosity_error() - - description = "" - prompts = ["An image of", "The setting is", "The background is", "The image type/style is"] - for i, prompt in enumerate(prompts): - description += prompt + ' ' - inputs = self.processor(image, text=description, return_tensors="pt").to(self.device, torch.float16) - generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens) #GPT2Tokenizer - answer = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() - if verbose: - bt.logging.info(f"{i}. Prompt: {prompt}") - bt.logging.info(f"{i}. Answer: {answer}") - - if answer: - # Remove any ending spaces or punctuation that is not a period - answer = answer.rstrip(" ,;!?") - # Add a period at the end if it's not already there - if not answer.endswith('.'): - answer += '.' - - description += answer + ' ' - else: - description = description[:-len(prompt) - 1] - - if not verbose: - transformers_logging.set_verbosity_info() - - if description.startswith(prompts[0]): - description = description[len(prompts[0]):] - - # Remove any trailing spaces and ensure the description ends with a period - description = description.strip() - if not description.endswith('.'): - description += '.' - if self.apply_moderation: - moderated_description = self.moderate_description(description) - return moderated_description - return description - - def generate_annotation( - self, - image_id, - dataset_name: str, - image: PIL.Image.Image, - original_dimensions: tuple, - resize: bool, - verbose: int) -> dict: - """ - Generate a text annotation for a given image. - - Parameters: - image_id (int or str): The identifier for the image within the dataset. - dataset_name (str): The name of the dataset the image belongs to. - image (PIL.Image.Image): The image object that requires annotation. - original_dimensions (tuple): Original dimensions of the image as (width, height). - resize (bool): Allow image downsizing to maximum dimensions of (1280, 1280). - verbose (int): Verbosity level. 
- - Returns: - dict: Dictionary containing the annotation data. - """ - image_to_process = image.copy() - if resize: # Downsize if dimension(s) are greater than 1280 - image_to_process = image_utils.resize_image(image_to_process, 1280, 1280) - if verbose > 1 and image_to_process.size != image.size: - bt.logging.info(f"Resized {image_id}: {image.size} to {image_to_process.size}") - try: - description = self.generate_description(image_to_process, verbose > 2) - annotation = { - 'description': description, - 'original_dataset': dataset_name, - 'original_dimensions': f"{original_dimensions[0]}x{original_dimensions[1]}", - 'id': image_id - } - return annotation - except Exception as e: - if verbose > 1: - bt.logging.error(f"Error processing image {image_id} in {dataset_name}: {e}") - return None - - def process_image( - self, - image_info: dict, - dataset_name: str, - image_index: int, - resize: bool, - verbose: int) -> Tuple[Any, float]: - """ - Processes an individual image for annotation, including resizing and verbosity controls, - and calculates the time taken to process the image. - - Args: - image_info (dict): Dictionary containing image data and metadata. - dataset_name (str): The name of the dataset containing the image. - image_index (int): The index of the image within the dataset. - resize (bool): Whether to resize the image before processing. - verbose (int): Verbosity level for logging outputs. - - Returns: - Tuple[Any, float]: A tuple containing the generated annotation (or None if failed) and the time taken to process. - """ - - if image_info['image'] is None: - if verbose > 1: - bt.logging.debug(f"Skipping image {image_index} in dataset {dataset_name} due to missing image data.") - return None, 0 - - original_dimensions = image_info['image'].size - start_time = time.time() - annotation = self.generate_annotation(image_index, - dataset_name, - image_info['image'], - original_dimensions, - resize, - verbose) - time_elapsed = time.time() - start_time - - if annotation is None: - if verbose > 1: - bt.logging.debug(f"Failed to generate annotation for image {image_index} in dataset {dataset_name}") - return None, time_elapsed - - return annotation, time_elapsed - - def generate_annotations( - self, - real_image_datasets: - List[ImageDataset], - verbose: int = 0, - max_images: int = None, - resize_images: bool = False) -> Dict[str, Dict[str, Any]]: - """ - Generates text annotations for images in the given datasets, saves them in a specified directory, - and computes the average per image latency. Returns a dictionary of new annotations and the average latency. - - Parameters: - real_image_datasets (List[Any]): Datasets containing images. - verbose (int): Verbosity level for process messages (Most verbose = 3). - max_images (int): Maximum number of images to annotate. - resize_images (bool) : Allow image downsizing before captioning. - Sets max dimensions to (1280, 1280), maintaining aspect ratio. - - Returns: - Tuple[Dict[str, Dict[str, Any]], float]: A tuple containing the annotations dictionary and average latency. 
- """ - annotations = {} - total_time = 0 - total_processed_images = 0 - for dataset in real_image_datasets: - dataset_name = dataset.huggingface_dataset_path - processed_images = 0 - dataset_time = 0 - for j, image_info in enumerate(dataset): - annotation, time_elapsed = self.process_image(image_info, - dataset_name, - j, - resize_images, - verbose) - if annotation is not None: - annotations.setdefault(dataset_name, {})[image_info['id']] = annotation - total_time += time_elapsed - dataset_time += time_elapsed - processed_images += 1 - if max_images is not None and len(annotations[dataset_name]) >= max_images: - break - total_processed_images += processed_images - overall_average_latency = total_time / total_processed_images if total_processed_images else 0 - return annotations, overall_average_latency diff --git a/bitmind/synthetic_image_generation/synthetic_image_generator.py b/bitmind/synthetic_image_generation/synthetic_image_generator.py deleted file mode 100644 index 80a2a919..00000000 --- a/bitmind/synthetic_image_generation/synthetic_image_generator.py +++ /dev/null @@ -1,295 +0,0 @@ -from transformers import pipeline -from transformers import set_seed -from diffusers import StableDiffusionXLPipeline, FluxPipeline, StableDiffusionPipeline -import bittensor as bt -import numpy as np -import torch -import random -import time -import re -import gc -import os -import warnings - -from bitmind.constants import ( - TEXT_MODERATION_MODEL, - DIFFUSER_NAMES, - DIFFUSER_ARGS, - DIFFUSER_PIPELINE, - DIFFUSER_CPU_OFFLOAD_ENABLED, - GENERATE_ARGS, - PROMPT_TYPES, - IMAGE_ANNOTATION_MODEL, - TARGET_IMAGE_SIZE -) - -future_warning_modules_to_ignore = [ - 'diffusers', - 'transformers.tokenization_utils_base' -] - -for module in future_warning_modules_to_ignore: - warnings.filterwarnings("ignore", category=FutureWarning, module=module) - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' - -from transformers import pipeline, set_seed -import bittensor as bt - -from bitmind.synthetic_image_generation.image_annotation_generator import ImageAnnotationGenerator -from bitmind.constants import HUGGINGFACE_CACHE_DIR - - -class SyntheticImageGenerator: - """ - A class for generating synthetic images based on text prompts. Supports different prompt generation strategies - and can utilize various image diffuser models to create images. - - Attributes: - use_random_diffuser (bool): Whether to randomly select a diffuser for each generation task. - prompt_type (str): The type of prompt generation strategy (currently only supports 'annotation') - diffuser_name (str): Name of the image diffuser model. - image_annotation_generator (ImageAnnotationGenerator): The generator object for annotating images if required. - image_cache_dir (str): Directory to cache generated images. - device (str): Device to use for model inference. Defaults to "cuda". - """ - def __init__( - self, - prompt_type='annotation', - diffuser_name=DIFFUSER_NAMES[0], - use_random_diffuser=False, - image_cache_dir=None, - device="cuda" - ): - if prompt_type not in PROMPT_TYPES: - raise ValueError(f"Invalid prompt type '{prompt_type}'. Options are {PROMPT_TYPES}") - if not use_random_diffuser and diffuser_name not in DIFFUSER_NAMES: - raise ValueError(f"Invalid diffuser name '{diffuser_name}'. 
Options are {DIFFUSER_NAMES}") - - self.use_random_diffuser = use_random_diffuser - self.prompt_type = prompt_type - self.device = device - - self.diffuser = None - if self.use_random_diffuser and diffuser_name is not None: - bt.logging.warning("Warning: diffuser_name will be ignored (use_random_diffuser=True)") - self.diffuser_name = None - else: - self.diffuser_name = diffuser_name - - self.image_annotation_generator = None - if self.prompt_type == 'annotation': - self.image_annotation_generator = ImageAnnotationGenerator(model_name=IMAGE_ANNOTATION_MODEL, - text_moderation_model_name=TEXT_MODERATION_MODEL, - device = self.device) - else: - raise NotImplementedError(f"Unsupported prompt_type: {self.prompt_type}") - - self.image_cache_dir = image_cache_dir - if image_cache_dir is not None: - os.makedirs(self.image_cache_dir, exist_ok=True) - - def generate(self, k: int = 1, real_images=None) -> list: - """ - Generates k synthetic images. If self.prompt_type is 'annotation', a BLIP2 captioning pipeline is used - to produce prompts by captioning real images. If self.prompt_type is 'random', an LLM is used to generate - prompts. - - Args: - k (int): Number of images to generate. - - Returns: - list: List of dictionaries containing 'prompt', 'image', and 'id'. - """ - if self.prompt_type == 'annotation': - if real_images is None: - raise ValueError(f"real_images can't be None if self.prompt_type is 'annotation'") - prompts = [ - self.generate_image_caption(real_images[i]) - for i in range(k) - ] - else: - raise NotImplementedError - - if self.use_random_diffuser: - self.load_diffuser('random') - else: - self.load_diffuser(self.diffuser_name) - - gen_data = [] - for prompt in prompts: - image_data = self.generate_image(prompt) - if self.image_cache_dir is not None: - path = os.path.join(self.image_cache_dir, image_data['id']) - image_data['image'].save(path) - gen_data.append(image_data) - self.clear_gpu() # remove diffuser from gpu - - return gen_data - - def clear_gpu(self): - """ - Clears GPU memory by deleting the loaded diffuser and performing garbage collection. - """ - if self.diffuser is not None: - del self.diffuser - gc.collect() - torch.cuda.empty_cache() - self.diffuser = None - - def load_diffuser(self, diffuser_name) -> None: - """ - Loads a Hugging Face diffuser model to a specific GPU. - - Parameters: - diffuser_name (str): Name of the diffuser to load. - """ - if diffuser_name == 'random': - diffuser_name = np.random.choice(DIFFUSER_NAMES, 1)[0] - - self.diffuser_name = diffuser_name - pipeline_class = globals()[DIFFUSER_PIPELINE[diffuser_name]] - self.diffuser = pipeline_class.from_pretrained(diffuser_name, - cache_dir=HUGGINGFACE_CACHE_DIR, - **DIFFUSER_ARGS[diffuser_name], - add_watermarker=False) - self.diffuser.set_progress_bar_config(disable=True) - self.diffuser.to(self.device) - if DIFFUSER_CPU_OFFLOAD_ENABLED[diffuser_name]: - self.diffuser.enable_model_cpu_offload() - - def generate_image_caption(self, image_sample) -> str: - """ - Generates a descriptive caption for a given image sample. - - This function takes an image sample as input, processes the image using a pre-trained - model, and returns a generated caption describing the content of the image. - - Args: - image_sample (dict): A dictionary containing information about the image to be processed. - It includes: - - 'source' (str): The dataset or source name of the image. - - 'id' (int/str): The unique identifier of the image. - - Returns: - str: A descriptive caption generated for the input image. 
- """ - self.image_annotation_generator.load_models() - annotation = self.image_annotation_generator.process_image( - image_info=image_sample, - dataset_name=image_sample['source'], - image_index=image_sample['id'], - resize=False, - verbose=0 - )[0] - self.image_annotation_generator.clear_gpu() - return annotation['description'] - - def get_tokenizer_with_min_len(self): - """ - Returns the tokenizer with the smallest maximum token length from the 'diffuser` object. - - If a second tokenizer exists, it compares both and returns the one with the smaller - maximum token length. Otherwise, it returns the available tokenizer. - - Returns: - tuple: A tuple containing the tokenizer and its maximum token length. - """ - # Check if a second tokenizer is available in the diffuser - if hasattr(self.diffuser, 'tokenizer_2'): - if self.diffuser.tokenizer.model_max_length > self.diffuser.tokenizer_2.model_max_length: - return self.diffuser.tokenizer_2, self.diffuser.tokenizer_2.model_max_length - return self.diffuser.tokenizer, self.diffuser.tokenizer.model_max_length - - def truncate_prompt_if_too_long(self, prompt: str): - """ - Truncates the input string if it exceeds the maximum token length when tokenized. - - Args: - prompt (str): The text prompt that may need to be truncated. - - Returns: - str: The original prompt if within the token limit; otherwise, a truncated version of the prompt. - """ - tokenizer, max_token_len = self.get_tokenizer_with_min_len() - tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings - if len(tokens['input_ids']) < max_token_len: - return prompt - # Truncate tokens if they exceed the maximum token length, decode the tokens back to a string - truncated_prompt = tokenizer.decode(token_ids=tokens['input_ids'][:max_token_len-1], - skip_special_tokens=True) - tokens = tokenizer(truncated_prompt) - bt.logging.info("Truncated prompt to abide by token limit.") - return truncated_prompt - - def generate_image(self, prompt, name = None, generate_at_target_size = False) -> list: - """ - Generates a synthetic image based on a text prompt. This function can optionally adjust the generation args of the - diffusion model, such as dimensions and the number of inference steps. - - Args: - prompt (str): The text prompt used to inspire the image generation. - name (str, optional): An optional identifier for the generated image. If not provided, a timestamp-based - identifier is used. - generate_at_target_size (bool, optional): If True, the image is generated at the dimensions specified by the - TARGET_IMAGE_SIZE constant. Otherwise, dimensions are selected based on the diffuser's default or random settings. - - Returns: - dict: A dictionary containing: - - 'prompt': The possibly truncated version of the input prompt. - - 'image': The generated image object. - - 'id': The identifier of the generated image. - - 'gen_time': The time taken to generate the image, measured from the start of the process. 
- """ - # Generate a unique image name based on current time if not provided - image_name = name if name else f"{time.time():.0f}.jpg" - # Check if the prompt is too long - truncated_prompt = self.truncate_prompt_if_too_long(prompt) - gen_args = {} - - # Load generation arguments based on diffuser settings - if self.diffuser_name in GENERATE_ARGS: - gen_args = GENERATE_ARGS[self.diffuser_name].copy() - - if isinstance(gen_args.get('num_inference_steps'), dict): - gen_args['num_inference_steps'] = np.random.randint( - gen_args['num_inference_steps']['min'], - gen_args['num_inference_steps']['max']) - - for dim in ('height', 'width'): - if isinstance(gen_args.get(dim), list): - gen_args[dim] = np.random.choice(gen_args[dim]) - - try: - if generate_at_target_size: - #Attempt to generate an image with specified dimensions - gen_args['height'] = TARGET_IMAGE_SIZE[0] - gen_args['width'] = TARGET_IMAGE_SIZE[1] - # Record the time taken to generate the image - start_time = time.time() - # Generate image using the diffuser with appropriate arguments - gen_image = self.diffuser(prompt=truncated_prompt, num_images_per_prompt=1, **gen_args).images[0] - # Calculate generation time - gen_time = time.time() - start_time - except Exception as e: - if generate_at_target_size: - bt.logging.error(f"Attempt with custom dimensions failed, falling back to default dimensions. Error: {e}") - try: - # Fallback to generating an image without specifying dimensions - gen_image = self.diffuser(prompt=truncated_prompt).images[0] - gen_time = time.time() - start_time - except Exception as fallback_error: - bt.logging.error(f"Failed to generate image with default dimensions after initial failure: {fallback_error}") - raise RuntimeError(f"Both attempts to generate image failed: {fallback_error}") - else: - bt.logging.error(f"Image generation error: {e}") - raise RuntimeError(f"Failed to generate image: {e}") - - image_data = { - 'prompt': truncated_prompt, - 'image': gen_image, - 'id': image_name, - 'gen_time': gen_time - } - return image_data diff --git a/bitmind/synthetic_image_generation/utils/annotation_utils.py b/bitmind/synthetic_image_generation/utils/annotation_utils.py deleted file mode 100644 index 279009ab..00000000 --- a/bitmind/synthetic_image_generation/utils/annotation_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import bittensor as bt -import json -import os - -def ensure_save_path(path: str) -> str: - """Ensure that a directory exists; if it does not, create it.""" - if not os.path.exists(path): - os.makedirs(path) - return path - -def create_annotation_dataset_directory(base_path: str, dataset_name: str) -> str: - """Create a directory for a dataset with a safe name, replacing any invalid characters.""" - safe_name = dataset_name.replace("/", "_") - full_path = os.path.join(base_path, safe_name) - if not os.path.exists(full_path): - os.makedirs(full_path) - return full_path - - -def save_annotation(dataset_dir: str, image_id, annotation: dict, verbose: int): - """Save a text annotation to a JSON file if it doesn't already exist.""" - file_path = os.path.join(dataset_dir, f"{image_id}.json") - if os.path.exists(file_path): - if verbose > 0: - bt.logging.info(f"Annotation for {image_id} already exists - Skipping") - return -1 # Skip this image as it already has an annotation - - with open(file_path, 'w') as f: - json.dump(annotation, f, indent=4) - if verbose > 0: - bt.logging.info(f"Created {file_path}") - - return 0 - - -def compute_annotation_latency(self, processed_images: int, dataset_time: float, 
dataset_name: str) -> float: - if processed_images > 0: - average_latency = dataset_time / processed_images - bt.logging.info(f'Average annotation latency for {dataset_name}: {average_latency:.4f} seconds') - return average_latency - return 0.0 - - -def list_datasets(base_dir: str) -> list[str]: - """List all subdirectories in the base directory.""" - return [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))] - - -def load_annotations(base_dir: str, dataset: str) -> list[dict]: - """Load annotations from JSON files within a specified directory.""" - annotations = [] - path = os.path.join(base_dir, dataset) - for filename in os.listdir(path): - if filename.endswith(".json"): - with open(os.path.join(path, filename), 'r') as file: - data = json.load(file) - annotations.append(data) - return annotations diff --git a/bitmind/synthetic_image_generation/utils/hugging_face_utils.py b/bitmind/synthetic_image_generation/utils/hugging_face_utils.py deleted file mode 100644 index 8507fead..00000000 --- a/bitmind/synthetic_image_generation/utils/hugging_face_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import json -from datasets import load_dataset -from huggingface_hub import HfApi - -def dataset_exists_on_hf(hf_dataset_name, token): - """Check if the dataset exists on Hugging Face.""" - api = HfApi() - try: - dataset_info = api.dataset_info(hf_dataset_name, token=token) - return True - except Exception as e: - return False - -def numerical_sort(value): - return int(os.path.splitext(os.path.basename(value))[0]) - -def load_and_sort_dataset(data_dir, file_type): - # Get list of filenames in the directory with the given extension - try: - if file_type == 'image': - # List image filenames with common image extensions - valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif') - filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) - if f.lower().endswith(valid_extensions)] - elif file_type == 'json': - # List json filenames - filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) - if f.lower().endswith('.json')] - else: - raise ValueError(f"Unsupported file type: {file_type}") - - if not filenames: - raise FileNotFoundError(f"No files with the extension '{file_type}' \ - found in directory '{data_dir}'") - - # Sort filenames numerically (0, 1, 2, 3, 4). Necessary because - # HF datasets are ordered by string (0, 1, 10, 11, 12). - sorted_filenames = sorted(filenames, key=numerical_sort) - - # Load the dataset with sorted filenames - if file_type == 'image': - return load_dataset("imagefolder", data_files=sorted_filenames) - elif file_type == 'json': - return load_dataset("json", data_files=sorted_filenames) - - except Exception as e: - print(f"Error loading dataset: {e}") - return None - -def upload_to_huggingface(dataset, repo_name, token): - """Uploads the dataset dictionary to Hugging Face.""" - api = HfApi() - api.create_repo(repo_name, repo_type="dataset", private=False, token=token) - dataset.push_to_hub(repo_name) - -def slice_dataset(dataset, start_index, end_index=None): - """ - Slice the dataset according to provided start and end indices. - - Parameters: - dataset (Dataset): The dataset to be sliced. - start_index (int): The index of the first element to include in the slice. - end_index (int, optional): The index of the last element to include in the slice. If None, slices to the end of the dataset. - - Returns: - Dataset: The sliced dataset. 
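    Example (note that in this implementation `end_index` behaves like the
    stop of Python's `range`, i.e. it is excluded from the slice):
        >>> subset = slice_dataset(dataset, 100, 200)
        >>> len(subset)
        100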
- """ - if end_index is not None and end_index < len(dataset): - return dataset.select(range(start_index, end_index)) - else: - return dataset.select(range(start_index, len(dataset))) - -def save_as_json(df, output_dir): - os.makedirs(output_dir, exist_ok=True) # Ensure the directory exists - # Iterate through rows in dataframe - for index, row in df.iterrows(): - file_path = os.path.join(output_dir, f"{row['id']}.json") - # Convert the row to a dictionary and save it as JSON - with open(file_path, 'w', encoding='utf-8') as f: - json.dump(row.to_dict(), f, ensure_ascii=False, indent=4) diff --git a/bitmind/synthetic_image_generation/utils/stress_test.py b/bitmind/synthetic_image_generation/utils/stress_test.py deleted file mode 100644 index 70eb2e0d..00000000 --- a/bitmind/synthetic_image_generation/utils/stress_test.py +++ /dev/null @@ -1,66 +0,0 @@ -import logging -import os -import time -import time - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' -logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') - -from synthetic_image_generator import SyntheticImageGenerator -from bitmind.image_dataset import ImageDataset -from bitmind.utils.data import sample_dataset_index_name - -from bitmind.constants import DATASET_META - - -def slice_dataset(dataset, start_index, end_index=None): - """ - Slice the dataset according to provided start and end indices. - - Parameters: - dataset (Dataset): The dataset to be sliced. - start_index (int): The index of the first element to include in the slice. - end_index (int, optional): The index of the last element to include in the slice. If None, slices to the end of the dataset. - - Returns: - Dataset: The sliced dataset. - """ - if end_index is not None and end_index < len(dataset): - return dataset.select(range(start_index, end_index)) - else: - return dataset.select(range(start_index, len(dataset))) - - -def main(): - synthetic_image_generator = SyntheticImageGenerator(prompt_type='annotation', - use_random_diffuser=False, - diffuser_name='stabilityai/stable-diffusion-xl-base-1.0') - - # Load the datasets specified in DATASET_META - real_image_datasets = [ - ImageDataset(ds['path'], 'train', ds.get('name', None), ds['create_splits']) - for ds in DATASET_META['real'] - ] - DIFFUSER_NAMES = ['black-forest-labs/FLUX.1-dev'] - for model_name in DIFFUSER_NAMES: - synthetic_image_generator.diffuser_name = model_name # Set the diffuser model - print(f"Testing {model_name}") - for _ in range(11): - # Sample an image from real datasets - real_dataset_index, source_dataset = sample_dataset_index_name(real_image_datasets) - real_dataset = real_image_datasets[real_dataset_index] - images_to_caption, image_indexes = real_dataset.sample(k=1) - - start = time.time() - # Generate synthetic images from sampled real images - sample = synthetic_image_generator.generate(k=1, real_images=images_to_caption)[0] - end = time.time() - - # Logging the results - time_elapsed = end - start - print(f"Model: {model_name}, Time elapsed: {time_elapsed}") - print(sample) # You may want to store these samples differently depending on your needs. 
-
-if __name__ == "__main__":
-    main()
diff --git a/bitmind/utils/config.py b/bitmind/utils/config.py
index baa06ec7..c3f9baae 100644
--- a/bitmind/utils/config.py
+++ b/bitmind/utils/config.py
@@ -87,13 +87,6 @@ def add_args(cls, parser):
 
     parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1)
 
-    parser.add_argument(
-        "--neuron.device",
-        type=str,
-        help="Device to run on.",
-        default=get_device(),
-    )
-
     parser.add_argument(
         "--neuron.epoch_length",
         type=int,
@@ -148,19 +141,47 @@ def add_miner_args(cls, parser):
     """Add miner specific arguments to the parser."""
 
     parser.add_argument(
-        "--neuron.detector_config",
+        "--neuron.image_detector_config",
         type=str,
         help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.",
         default="camo.yaml",
     )
 
     parser.add_argument(
-        "--neuron.detector",
+        "--neuron.image_detector",
         type=str,
         help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.",
         default="CAMO",
     )
-
+
+    parser.add_argument(
+        "--neuron.image_detector_device",
+        type=str,
+        help="Device to run image detection model on.",
+        default=get_device(),
+    )
+
+    parser.add_argument(
+        "--neuron.video_detector_config",
+        type=str,
+        help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.",
+        default="tall.yaml",
+    )
+
+    parser.add_argument(
+        "--neuron.video_detector",
+        type=str,
+        help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.",
+        default="TALL",
+    )
+
+    parser.add_argument(
+        "--neuron.video_detector_device",
+        type=str,
+        help="Device to run video detection model on.",
+        default=get_device(),
+    )
+
     parser.add_argument(
         "--neuron.name",
         type=str,
@@ -200,6 +221,13 @@ def add_miner_args(cls, parser):
 def add_validator_args(cls, parser):
     """Add validator specific arguments to the parser."""
 
+    parser.add_argument(
+        "--neuron.device",
+        type=str,
+        help="Device to run on.",
+        default=get_device(),
+    )
+
     parser.add_argument(
         "--neuron.prompt_type",
         type=str,
@@ -207,6 +235,34 @@ def add_validator_args(cls, parser):
         default='annotation',
     )
 
+    parser.add_argument(
+        "--neuron.clip_length_min",
+        type=int,
+        help="Min length in seconds for video challenge",
+        default=2,
+    )
+
+    parser.add_argument(
+        "--neuron.clip_length_max",
+        type=int,
+        help="Max length in seconds for video challenge",
+        default=8,
+    )
+
+    parser.add_argument(
+        "--neuron.video_cache_refresh_interval",
+        type=int,
+        help="Interval at which to refresh video cache (hours)",
+        default=4,
+    )
+
+    parser.add_argument(
+        "--neuron.zip_cache_refresh_interval",
+        type=int,
+        help="Interval at which to refresh zipped video cache (hours)",
+        default=12,
+    )
+
     parser.add_argument(
         "--neuron.name",
         type=str,
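The new validator flags compose into a launch line like the following sketch (the entry-point path and netuid are illustrative, not taken from this diff):

    python neurons/validator.py --netuid 34 \
        --neuron.clip_length_min 2 --neuron.clip_length_max 8 \
        --neuron.video_cache_refresh_interval 4 \
        --neuron.zip_cache_refresh_interval 12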
diff --git a/bitmind/image_transforms.py b/bitmind/utils/image_transforms.py
similarity index 69%
rename from bitmind/image_transforms.py
rename to bitmind/utils/image_transforms.py
index 8642e910..e32d3c4c 100644
--- a/bitmind/image_transforms.py
+++ b/bitmind/utils/image_transforms.py
@@ -2,17 +2,18 @@
 import random
 from PIL import Image
 import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
 import numpy as np
 import torch
 import cv2
 
-from bitmind.constants import TARGET_IMAGE_SIZE
+from bitmind.validator.config import TARGET_IMAGE_SIZE
+
 
 def center_crop():
     def fn(img):
         m = min(img.size)
         return transforms.CenterCrop(m)(img)
-
     return fn
 
@@ -21,10 +22,13 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.params = None
 
-    def get_params(self, img, scale, ratio):
-        params = super().get_params(img, scale, ratio)
-        self.params = params
-        return params
+    def forward(self, img, crop_params=None):
+        if crop_params is None:
+            i, j, h, w = super().get_params(img, self.scale, self.ratio)
+        else:
+            i, j, h, w = crop_params
+        self.params = {'crop_params': (i, j, h, w)}
+        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
 
 
 class RandomHorizontalFlipWithParams(transforms.RandomHorizontalFlip):
@@ -32,12 +36,12 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.params = None
 
-    def forward(self, img):
-        if torch.rand(1) < self.p:
-            self.params = True
+    def forward(self, img, do_flip=False):
+        if do_flip or (torch.rand(1) < self.p):
+            self.params = {'do_flip': True}
             return transforms.functional.hflip(img)
         else:
-            self.params = False
+            self.params = {'do_flip': False}
             return img
 
@@ -46,12 +50,12 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.params = None
 
-    def forward(self, img):
-        if torch.rand(1) < self.p:
-            self.params = True
+    def forward(self, img, do_flip=False):
+        if do_flip or (torch.rand(1) < self.p):
+            self.params = {'do_flip': True}
             return transforms.functional.vflip(img)
         else:
-            self.params = False
+            self.params = {'do_flip': False}
             return img
 
@@ -60,9 +64,10 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.params = None
 
-    def forward(self, img):
-        angle = self.get_params(self.degrees)
-        self.params = angle
+    def forward(self, img, angle=None):
+        if angle is None:
+            angle = self.get_params(self.degrees)
+        self.params = {'angle': angle}
         return transforms.functional.rotate(img, angle)
 
@@ -303,83 +308,107 @@ def __call__(self, tensor):
 
 class ComposeWithParams:
-    """Compose multiple transforms with parameter tracking."""
-
     def __init__(self, transforms):
         self.transforms = transforms
         self.params = {}
 
-    def __call__(self, img):
+    def __call__(self, input_data):
         transform_params = {
             RandomResizedCropWithParams: 'RandomResizedCrop',
             RandomHorizontalFlipWithParams: 'RandomHorizontalFlip',
             RandomVerticalFlipWithParams: 'RandomVerticalFlip',
             RandomRotationWithParams: 'RandomRotation'
         }
-
-        for transform in self.transforms:
-            img = transform(img)
-            if type(transform) in transform_params:
-                self.params[transform_params[type(transform)]] = transform.params
-        return img
+        output_data = []
+        list_input = True
+        if not isinstance(input_data, list):
+            input_data = [input_data]
+            list_input = False
+
+        for img in input_data:
+            for t in self.transforms:
+                if type(t) in transform_params and transform_params[type(t)] in self.params:
+                    params = self.params[transform_params[type(t)]]
+                    img = t(img, **params)
+                else:
+                    img = t(img)
+                    if type(t) in transform_params:
+                        self.params[transform_params[type(t)]] = t.params
+            output_data.append(img)
+
+        if list_input:
+            return output_data
+        return output_data[0]
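ComposeWithParams now accepts a list of frames so that one set of sampled parameters is replayed across a whole clip. A minimal sketch of that behavior (frame paths illustrative; uses get_random_augmentations defined below):

    frames = [Image.open(f"frame_{i}.jpg") for i in range(8)]
    aug = get_random_augmentations((256, 256))
    aug_frames = aug(frames)                # params sampled on frame 0, replayed on 1..7
    crop = aug.params['RandomResizedCrop']  # e.g. {'crop_params': (i, j, h, w)}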
 
 # Transform configurations
-base_transforms = transforms.Compose([
-    ConvertToRGB(),
-    center_crop(),
-    transforms.Resize(TARGET_IMAGE_SIZE),
-    transforms.ToTensor()
-])
-
-random_aug_transforms = ComposeWithParams([
-    ConvertToRGB(),
-    transforms.ToTensor(),
-    RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
-    RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
-    RandomHorizontalFlipWithParams(),
-    RandomVerticalFlipWithParams()
-])
-
-ucf_transforms = transforms.Compose([
-    ConvertToRGB(),
-    center_crop(),
-    transforms.Resize(TARGET_IMAGE_SIZE),
-    CLAHE(),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
+def get_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
+    return ComposeWithParams([
+        ConvertToRGB(),
+        center_crop(),
+        transforms.Resize(target_image_size),
+        transforms.ToTensor()
+    ])
+
+
+def get_random_augmentations(target_image_size=TARGET_IMAGE_SIZE):
+    return ComposeWithParams([
+        ConvertToRGB(),
+        transforms.ToTensor(),
+        RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
+        RandomResizedCropWithParams(target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
+        RandomHorizontalFlipWithParams(),
+        RandomVerticalFlipWithParams()
+    ])
+
+def get_ucf_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
+    return transforms.Compose([
+        ConvertToRGB(),
+        center_crop(),
+        transforms.Resize(target_image_size),
+        CLAHE(),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+
+def get_tall_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
+    return ComposeWithParams([
+        transforms.Resize(target_image_size),
+        transforms.ToTensor()
+    ])
 
 # Medium difficulty transforms with mild distortions
-random_aug_transforms_medium = ComposeWithParams([
-    ConvertToRGB(),
-    transforms.ToTensor(),
-    RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
-    RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
-    RandomHorizontalFlipWithParams(),
-    RandomVerticalFlipWithParams(),
-    ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=1),
-    ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=1),
-    ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=1)
-])
+def get_random_augmentations_medium(target_image_size=TARGET_IMAGE_SIZE):
+    return ComposeWithParams([
+        ConvertToRGB(),
+        transforms.ToTensor(),
+        RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
+        RandomResizedCropWithParams(target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
+        RandomHorizontalFlipWithParams(),
+        RandomVerticalFlipWithParams(),
+        ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=1),
+        ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=1),
+        ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=1)
+    ])
 
 # Hard difficulty transforms with more severe distortions
-random_aug_transforms_hard = ComposeWithParams([
-    ConvertToRGB(),
-    transforms.ToTensor(),
-    RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
-    RandomResizedCropWithParams(TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
-    RandomHorizontalFlipWithParams(),
-    RandomVerticalFlipWithParams(),
-    ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=2),
-    ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=2),
-    ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=2),
-    ApplyDeeperForensicsDistortion('GNC', level_min=0, level_max=2),
-    ApplyDeeperForensicsDistortion('GB', level_min=0, level_max=2)
-])
-
-
-def apply_augmentation_by_level(image, level_probs={
+def get_random_augmentations_hard(target_image_size=TARGET_IMAGE_SIZE):
+    return ComposeWithParams([
+        ConvertToRGB(),
+        transforms.ToTensor(),
+        RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
+        RandomResizedCropWithParams(target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
+        RandomHorizontalFlipWithParams(),
+        RandomVerticalFlipWithParams(),
+        ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=2),
+        ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=2),
+        ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=2),
+        ApplyDeeperForensicsDistortion('GNC', level_min=0, level_max=2),
+        ApplyDeeperForensicsDistortion('GB', level_min=0, level_max=2)
+    ])
+
+
+def apply_augmentation_by_level(image, target_image_size, level_probs={
     0: 0.25, # No augmentations (base transforms)
     1: 0.45, # Basic augmentations
     2: 0.15, # Medium distortions
@@ -423,16 +452,14 @@ def apply_augmentation_by_level(image, level_probs={
 
     # Apply appropriate transform
     if level == 0:
-        transformed = base_transforms(image)
-        params = {}
+        tforms = get_base_transforms(target_image_size)
     elif level == 1:
-        transformed = random_aug_transforms(image)
-        params = random_aug_transforms.params
+        tforms = get_random_augmentations(target_image_size)
     elif level == 2:
-        transformed = random_aug_transforms_medium(image)
-        params = random_aug_transforms_medium.params
+        tforms = get_random_augmentations_medium(target_image_size)
     else:  # level == 3
-        transformed = random_aug_transforms_hard(image)
-        params = random_aug_transforms_hard.params
+        tforms = get_random_augmentations_hard(target_image_size)
+
+    transformed = tforms(image)
 
-    return transformed, level, params
\ No newline at end of file
+    return transformed, level, tforms.params
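A minimal usage sketch for the reworked entry point (input path and target size illustrative):

    from PIL import Image
    img = Image.open("example.jpg")
    tensor, level, params = apply_augmentation_by_level(img, (256, 256))
    # level is the sampled difficulty in {0, 1, 2, 3}; params echoes the
    # parameters recorded by ComposeWithParams for reproducibility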
""" - if diffuser_name == 'random': - diffuser_name = np.random.choice(DIFFUSER_NAMES, 1)[0] - self.diffuser_name = diffuser_name + if t2v_model_name == 'random': + t2v_model_name = np.random.choice(MODEL_NAMES, 1)[0] + self.t2v_model_name = t2v_model_name class MockValidator: @@ -90,7 +90,7 @@ def __init__(self, config): False) for i in range(3) ] - self.synthetic_image_generator = MockSyntheticImageGenerator( + self.synthetic_data_generator = MockSyntheticDataGenerator( prompt_type='annotation', use_random_diffuser=True, diffuser_name=None) self.total_real_images = sum([len(ds) for ds in self.real_image_datasets]) self.scores = np.zeros(self.metagraph.n, dtype=np.float32) diff --git a/bitmind/utils/video_utils.py b/bitmind/utils/video_utils.py new file mode 100644 index 00000000..ffb2a7ff --- /dev/null +++ b/bitmind/utils/video_utils.py @@ -0,0 +1,26 @@ +import torch + + +def pad_frames(x, divisible_by): + """ + Pads the tensor `x` along the frame dimension (1) until the number of frames is divisible by `divisible_by`. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_frames, channels, height, width). + divisible_by (int): The divisor to make the number of frames divisible by. + + Returns: + torch.Tensor: Padded tensor of shape (batch_size, adjusted_num_frames, channels, height, width). + """ + num_frames = x.shape[1] + frame_padding = (divisible_by - (num_frames % divisible_by)) % divisible_by + + if frame_padding > 0: + padding_shape = (x.shape[0], frame_padding, x.shape[2], x.shape[3], x.shape[4]) + x_padding = torch.zeros(padding_shape, device=x.device) # Ensure padding is on the same device + x = torch.cat((x, x_padding), dim=1) + + assert x.shape[1] % divisible_by == 0, ( + f'Frame number mismatch: got {x.shape[1]} frames, not divisible by {divisible_by}.' + ) + return x \ No newline at end of file diff --git a/bitmind/validator/__init__.py b/bitmind/validator/__init__.py index 0b7ddf1a..e69de29b 100644 --- a/bitmind/validator/__init__.py +++ b/bitmind/validator/__init__.py @@ -1,2 +0,0 @@ -from .forward import forward -from .reward import get_rewards diff --git a/bitmind/validator/cache/__init__.py b/bitmind/validator/cache/__init__.py new file mode 100644 index 00000000..8858fff1 --- /dev/null +++ b/bitmind/validator/cache/__init__.py @@ -0,0 +1,3 @@ +from .base_cache import BaseCache +from .image_cache import ImageCache +from .video_cache import VideoCache diff --git a/bitmind/validator/cache/base_cache.py b/bitmind/validator/cache/base_cache.py new file mode 100644 index 00000000..6ee91662 --- /dev/null +++ b/bitmind/validator/cache/base_cache.py @@ -0,0 +1,206 @@ +from abc import ABC, abstractmethod +import asyncio +from datetime import datetime +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Union + +import bittensor as bt +import huggingface_hub as hf_hub +import numpy as np + +from .util import get_most_recent_update_time, seconds_to_str +from .download import download_files, list_hf_files + + +class BaseCache(ABC): + """ + Abstract base class for managing file caches with compressed sources. + + This class provides the basic infrastructure for maintaining both a compressed + source cache and an extracted cache, with automatic refresh intervals and + background update tasks. 
+ """ + + def __init__( + self, + cache_dir: Union[str, Path], + file_extensions: List[str], + compressed_file_extension: str, + run_updater: bool = False, + datasets: dict = None, + extracted_update_interval: int = 4, + compressed_update_interval: int = 24, + num_samples_per_source: int = 10, + ) -> None: + """ + Initialize the base cache infrastructure. + + Args: + cache_dir: Path to store extracted files + extracted_update_interval: Hours between extracted cache updates + compressed_update_interval: Hours between compressed cache updates + num_samples_per_source: Number of items to extract per source + file_extensions: List of valid file extensions for this cache type + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True, parents=True) + + self.compressed_dir = self.cache_dir / 'sources' + self.compressed_dir.mkdir(exist_ok=True, parents=True) + + self.datasets = datasets + + self.extracted_update_interval = extracted_update_interval * 60 * 60 + self.compressed_update_interval = compressed_update_interval * 60 * 60 + self.num_samples_per_source = num_samples_per_source + self.file_extensions = file_extensions + self.compressed_file_extension = compressed_file_extension + + if run_updater: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.get_event_loop() + + # Initialize caches, blocking to ensure data are available for validator + bt.logging.info(f"Setting up cache at {self.cache_dir}") + bt.logging.info(f"Clearing incomplete sources in {self.compressed_dir}") + self._clear_incomplete_sources() + + if self._compressed_cache_empty(): + bt.logging.info(f"Compressed cache {self.compressed_dir} empty; populating") + # grab 1 zip per source to get started, download more later + self._refresh_compressed_cache(n_per_source=1) + + if self._extracted_cache_empty(): + bt.logging.info(f"Extracted cache {self.cache_dir} empty; populating") + self._refresh_extracted_cache() + + # Start background tasks + bt.logging.info(f"Starting background tasks") + self._compressed_updater_task = self.loop.create_task( + self._run_compressed_updater() + ) + self._extracted_updater_task = self.loop.create_task( + self._run_extracted_updater() + ) + + def _get_cached_files(self) -> List[Path]: + """Get list of all extracted files in cache directory.""" + return [ + f for f in self.cache_dir.iterdir() + if f.suffix.lower() in self.file_extensions + ] + + def _get_compressed_files(self) -> List[Path]: + """Get list of all compressed files in compressed directory.""" + return list(self.compressed_dir.glob(f'*{self.compressed_file_extension}')) + + def _extracted_cache_empty(self) -> bool: + """Check if extracted cache directory is empty.""" + return len(self._get_cached_files()) == 0 + + def _compressed_cache_empty(self) -> bool: + """Check if compressed cache directory is empty.""" + return len(self._get_compressed_files()) == 0 + + async def _run_extracted_updater(self) -> None: + """Asynchronously refresh extracted files according to update interval.""" + while True: + try: + last_update = get_most_recent_update_time(self.cache_dir) + time_elapsed = time.time() - last_update + + if time_elapsed >= self.extracted_update_interval: + bt.logging.info(f"Refreshing cache [{self.cache_dir}]") + self._refresh_extracted_cache() + bt.logging.info(f"Cache refresh complete [{self.cache_dir}]") + + sleep_time = max(0, self.extracted_update_interval - time_elapsed) + bt.logging.info(f"Next cache refresh in {seconds_to_str(sleep_time)} [{self.compressed_dir}]") + 
+    async def _run_compressed_updater(self) -> None:
+        """Asynchronously refresh compressed files according to update interval."""
+        while True:
+            try:
+                self._clear_incomplete_sources()
+                last_update = get_most_recent_update_time(self.compressed_dir)
+                time_elapsed = time.time() - last_update
+
+                if time_elapsed >= self.compressed_update_interval:
+                    bt.logging.info(f"Refreshing cache [{self.compressed_dir}]")
+                    self._refresh_compressed_cache(n_per_source=1)
+                    bt.logging.info(f"Cache refresh complete [{self.compressed_dir}]")
+
+                sleep_time = max(0, self.compressed_update_interval - time_elapsed)
+                bt.logging.info(f"Next cache refresh in {seconds_to_str(sleep_time)} [{self.compressed_dir}]")
+                await asyncio.sleep(sleep_time)
+            except Exception as e:
+                bt.logging.error(f"Error in compressed cache update: {e}")
+                await asyncio.sleep(60)
+
+    def _refresh_compressed_cache(self, n_per_source) -> None:
+        """
+        Refresh the compressed file cache with new downloads.
+        """
+        try:
+            bt.logging.info(f"{len(self._get_compressed_files())} compressed sources currently cached")
+
+            new_files: List[Path] = []
+            for source in self.datasets:
+                filenames = list_hf_files(
+                    repo_id=source['path'],
+                    extension=self.compressed_file_extension)
+                remote_paths = [
+                    f"https://huggingface.co/datasets/{source['path']}/resolve/main/{f}"
+                    for f in filenames
+                ]
+                bt.logging.info(f"Downloading {n_per_source} from {source['path']} to {self.compressed_dir}")
+                new_files += download_files(
+                    urls=np.random.choice(remote_paths, n_per_source),
+                    output_dir=self.compressed_dir)
+
+            if new_files:
+                bt.logging.info(f"{len(new_files)} new files added to {self.compressed_dir}")
+            else:
+                bt.logging.error(f"No new files were added to {self.compressed_dir}")
+
+        except Exception as e:
+            bt.logging.error(f"Error during compressed refresh for {self.compressed_dir}: {e}")
+            raise
+
+    def _refresh_extracted_cache(self) -> None:
+        """Refresh the extracted cache with new selections."""
+        bt.logging.info(f"{len(self._get_cached_files())} extracted files currently cached")
+        new_files = self._extract_random_items()
+        if new_files:
+            bt.logging.info(f"{len(new_files)} new files added to {self.cache_dir}")
+        else:
+            bt.logging.error(f"No new files were added to {self.cache_dir}")
+
+    @abstractmethod
+    def _extract_random_items(self) -> List[Path]:
+        """Extract a random selection of items from compressed sources into the cache."""
+        pass
+
+    @abstractmethod
+    def _clear_incomplete_sources(self) -> None:
+        """Remove any incomplete or corrupted source files from cache."""
+        pass
+
+    @abstractmethod
+    def sample(self, num_samples: int) -> Optional[Dict[str, Any]]:
+        """Sample random items from the cache."""
+        pass
+
+    def __del__(self) -> None:
+        """Cleanup background tasks on deletion."""
+        if hasattr(self, '_extracted_updater_task'):
+            self._extracted_updater_task.cancel()
+        if hasattr(self, '_compressed_updater_task'):
+            self._compressed_updater_task.cancel()
diff --git a/bitmind/validator/cache/download.py b/bitmind/validator/cache/download.py
new file mode 100644
index 00000000..b5d45978
--- /dev/null
+++ b/bitmind/validator/cache/download.py
@@ -0,0 +1,164 @@
+import requests
+import os
+from pathlib import Path
+from requests.exceptions import RequestException
+from typing import List, Union, Dict, Optional
+
+import bittensor as bt
+import huggingface_hub as hf_hub
+
+
+def download_files(
+    urls: List[str],
+    output_dir:
Union[str, Path], + chunk_size: int = 8192 +) -> List[Path]: + """ + Downloads multiple files synchronously. + + Args: + urls: List of URLs to download + output_dir: Directory to save the files + chunk_size: Size of chunks to download at a time + + Returns: + List of successfully downloaded file paths + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + downloaded_files = [] + + for url in urls: + try: + bt.logging.info(f'Downloading {url}') + response = requests.get(url, stream=True) + if response.status_code != 200: + bt.logging.error(f'Failed to download {url}: Status {response.status_code}') + continue + + filename = os.path.basename(url) + filepath = output_dir / filename + + bt.logging.info(f'Writing to {filepath}') + with open(filepath, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: # filter out keep-alive chunks + f.write(chunk) + + downloaded_files.append(filepath) + bt.logging.info(f'Successfully downloaded {filename}') + + except Exception as e: + bt.logging.error(f'Error downloading {url}: {str(e)}') + continue + + return downloaded_files + + +def list_hf_files(repo_id, repo_type='dataset', extension=None): + files = [] + try: + files = list(hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type)) + if extension: + files = [f for f in files if f.endswith(extension)] + except Exception as e: + bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}") + return files + + +def openvid1m_err_handler( + base_zip_url: str, + output_path: Path, + part_index: int, + chunk_size: int = 8192, + timeout: int = 300 +) -> Optional[Path]: + """ + Synchronous error handler for OpenVid1M downloads that handles split files. + + Args: + base_zip_url: Base URL for the zip parts + output_path: Directory to save files + part_index: Index of the part to download + chunk_size: Size of download chunks + timeout: Download timeout in seconds + + Returns: + Path to combined file if successful, None otherwise + """ + part_urls = [ + f"{base_zip_url}{part_index}_partaa", + f"{base_zip_url}{part_index}_partab" + ] + error_log_path = output_path / "download_log.txt" + downloaded_parts = [] + + # Download each part + for part_url in part_urls: + part_file_path = output_path / Path(part_url).name + + if part_file_path.exists(): + bt.logging.warning(f"File {part_file_path} exists.") + downloaded_parts.append(part_file_path) + continue + + try: + response = requests.get(part_url, stream=True, timeout=timeout) + if response.status_code != 200: + raise RequestException( + f"HTTP {response.status_code}: {response.reason}" + ) + + with open(part_file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: # filter out keep-alive chunks + f.write(chunk) + + bt.logging.info(f"File {part_url} saved to {part_file_path}") + downloaded_parts.append(part_file_path) + + except Exception as e: + error_message = f"File {part_url} download failed: {str(e)}\n" + bt.logging.error(error_message) + with open(error_log_path, "a") as error_log_file: + error_log_file.write(error_message) + return None + + if len(downloaded_parts) == len(part_urls): + try: + combined_file = output_path / f"OpenVid_part{part_index}.zip" + combined_data = bytearray() + for part_path in downloaded_parts: + with open(part_path, 'rb') as part_file: + combined_data.extend(part_file.read()) + + with open(combined_file, 'wb') as out_file: + out_file.write(combined_data) + + for part_path in downloaded_parts: + part_path.unlink() 
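+                # Combining then unlinking mirrors `cat <base>_parta* > combined.zip && rm <parts>`,
+                # assuming the hosted pieces are plain byte-range splits of a single zip.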
+ + bt.logging.info(f"Successfully combined parts into {combined_file}") + return combined_file + + except Exception as e: + error_message = f"Failed to combine parts for index {part_index}: {str(e)}\n" + bt.logging.error(error_message) + with open(error_log_path, "a") as error_log_file: + error_log_file.write(error_message) + return None + + return None + + """ +data_folder = output_path / "data" / "train" +data_folder.mkdir(parents=True, exist_ok=True) +data_urls = [ + "https://huggingface.co/datasets/nkp37/OpenVid-1M/resolve/main/data/train/OpenVid-1M.csv", + "https://huggingface.co/datasets/nkp37/OpenVid-1M/resolve/main/data/train/OpenVidHD.csv" +] +for data_url in data_urls: + data_path = data_folder / Path(data_url).name + command = ["wget", "-O", str(data_path), data_url] + subprocess.run(command, check=True) +""" diff --git a/bitmind/validator/cache/extract.py b/bitmind/validator/cache/extract.py new file mode 100644 index 00000000..1b34958b --- /dev/null +++ b/bitmind/validator/cache/extract.py @@ -0,0 +1,197 @@ +import base64 +import hashlib +import json +import logging +import mimetypes +import os +import random +import warnings +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple +from zipfile import ZipFile + +from PIL import Image +import pyarrow.parquet as pq +import bittensor as bt + + +def extract_videos_from_zip( + zip_path: Path, + dest_dir: Path, + num_videos: int, + file_extensions: Set[str] = {'.mp4', '.avi', '.mov', '.mkv', '.wmv'}, + include_checksums: bool = True +) -> List[Tuple[str, str]]: + """ + Extract random videos and their metadata from a zip file and save them to disk. +q + Args: + zip_path: Path to the zip file + dest_dir: Directory to save videos and metadata + num_videos: Number of videos to extract + file_extensions: Set of valid video file extensions + include_checksums: Whether to calculate and include file checksums in metadata + + Returns: + List of tuples containing (video_path, metadata_path) + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + + extracted_files = [] + try: + with ZipFile(zip_path) as zip_file: + video_files = [ + f for f in zip_file.namelist() + if any(f.lower().endswith(ext) for ext in file_extensions) + ] + if not video_files: + bt.logging.warning(f"No video files found in {zip_path}") + return extracted_files + + bt.logging.info(f"{len(video_files)} video files found in {zip_path}") + selected_videos = random.sample( + video_files, + min(num_videos, len(video_files)) + ) + + bt.logging.info(f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}") + for idx, video in enumerate(selected_videos): + try: + zip_basename = zip_path.name.split('.zip')[0] + original_filename = Path(video).name + base_filename = f"{zip_basename}_{idx}_{original_filename}" + + # extract video and get metadata + video_path = dest_dir / base_filename + temp_path = Path(zip_file.extract(video, path=dest_dir)) + temp_path.rename(video_path) + + video_info = zip_file.getinfo(video) + metadata = { + 'source_zip': str(zip_path), + 'original_filename': original_filename, + 'original_path_in_zip': video, + 'extraction_date': datetime.now().isoformat(), + 'file_size': os.path.getsize(video_path), + 'mime_type': mimetypes.guess_type(video_path)[0], + 'zip_metadata': { + 'compress_size': video_info.compress_size, + 'file_size': video_info.file_size, + 'compress_type': video_info.compress_type, + 'date_time': datetime.strftime( + 
datetime(*video_info.date_time), + '%Y-%m-%d %H:%M:%S' + ), + } + } + + if include_checksums: + with open(video_path, 'rb') as f: + file_data = f.read() + metadata['checksums'] = { + 'md5': hashlib.md5(file_data).hexdigest(), + 'sha256': hashlib.sha256(file_data).hexdigest() + } + + metadata_filename = f"{video_path.stem}_metadata.json" + metadata_path = dest_dir / metadata_filename + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + extracted_files.append((str(video_path), str(metadata_path))) + logging.info(f"Extracted {original_filename} from {zip_path}") + + except Exception as e: + bt.logging.warning(f"Error extracting {video}: {e}") + if 'temp_path' in locals() and temp_path.exists(): + temp_path.unlink() + continue + + except Exception as e: + bt.logging.warning(f"Error processing zip file {zip_path}: {e}") + + return extracted_files + + +def extract_images_from_parquet( + parquet_path: Path, + dest_dir: Path, + num_images: int, + columns: Optional[List[str]] = None, + seed: Optional[int] = None +) -> List[Tuple[str, str]]: + """ + Extract random images and their metadata from a parquet file and save them to disk. + + Args: + parquet_path: Path to the parquet file + dest_dir: Directory to save images and metadata + num_images: Number of images to extract + columns: Specific columns to include in metadata + seed: Random seed for sampling + + Returns: + List of tuples containing (image_path, metadata_path) + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + + # read parquet file, sample random image rows + table = pq.read_table(parquet_path) + df = table.to_pandas() + sample_df = df.sample(n=min(num_images, len(df)), random_state=seed) + image_col = next((col for col in sample_df.columns if 'image' in col.lower()), None) + metadata_cols = [c for c in sample_df.columns if c != image_col] + + saved_files = [] + for idx, row in sample_df.iterrows(): + try: + img_data = row[image_col] + if isinstance(img_data, dict): + key = next((k for k in img_data if 'bytes' in k.lower() or 'image' in k.lower()), None) + img_data = img_data[key] + + try: + img = Image.open(BytesIO(img_data)) + except Exception as e: + img_data = base64.b64decode(img_data) + img = Image.open(BytesIO(img_data)) + + base_filename = f"image_{idx}" + image_format = img.format.lower() if img.format else 'png' + img_filename = f"{base_filename}.{image_format}" + img_path = dest_dir / img_filename + img.save(img_path) + + metadata = { + 'source_parquet': str(parquet_path), + 'original_index': str(idx), + 'image_format': image_format, + 'image_size': img.size, + 'image_mode': img.mode + } + + for col in metadata_cols: + # Convert any non-serializable types to strings + try: + json.dumps({col: row[col]}) + metadata[col] = row[col] + except (TypeError, OverflowError): + metadata[col] = str(row[col]) + + metadata_filename = f"{base_filename}.json" + metadata_path = dest_dir / metadata_filename + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + saved_files.append(str(img_path)) + + except Exception as e: + warnings.warn(f"Failed to extract/save image {idx}: {e}") + continue + + return saved_files \ No newline at end of file diff --git a/bitmind/validator/cache/image_cache.py b/bitmind/validator/cache/image_cache.py new file mode 100644 index 00000000..4920500e --- /dev/null +++ b/bitmind/validator/cache/image_cache.py @@ -0,0 +1,123 @@ +import json +import random +from 
pathlib import Path
+from typing import Dict, List, Optional, Union, Any
+
+import bittensor as bt
+from PIL import Image
+
+from .base_cache import BaseCache
+from .extract import extract_images_from_parquet
+from .util import is_parquet_complete
+
+
+class ImageCache(BaseCache):
+    """
+    A class to manage image caching from parquet files.
+
+    This class handles the caching, updating, and sampling of images stored
+    in parquet files. It maintains both a compressed cache of parquet files
+    and an extracted cache of images ready for processing.
+    """
+
+    def __init__(
+        self,
+        cache_dir: Union[str, Path],
+        run_updater: bool = False,
+        datasets: Optional[dict] = None,
+        parquet_update_interval: int = 24,
+        image_update_interval: int = 2,
+        num_images_per_source: int = 100,
+    ) -> None:
+        """
+        Args:
+            cache_dir: Path to store extracted images
+            run_updater: Whether to populate the cache and start background refresh tasks
+            datasets: Dataset source specs passed through to BaseCache
+            parquet_update_interval: Hours between parquet cache updates
+            image_update_interval: Hours between image cache updates
+            num_images_per_source: Number of images to extract per parquet
+        """
+        super().__init__(
+            cache_dir=cache_dir,
+            datasets=datasets,
+            extracted_update_interval=image_update_interval,
+            compressed_update_interval=parquet_update_interval,
+            num_samples_per_source=num_images_per_source,
+            file_extensions=['.jpg', '.jpeg', '.png'],
+            compressed_file_extension='.parquet',
+            run_updater=run_updater
+        )
+
+    def _clear_incomplete_sources(self) -> None:
+        """Remove any incomplete or corrupted parquet files."""
+        for path in self._get_compressed_files():
+            if path.suffix == '.parquet' and not is_parquet_complete(path):
+                try:
+                    path.unlink()
+                    bt.logging.warning(f"Removed incomplete parquet file {path}")
+                except Exception as e:
+                    bt.logging.error(f"Error removing incomplete parquet {path}: {e}")
+
+    def _extract_random_items(self) -> List[Path]:
+        """
+        Extract random images from parquet files in the compressed directory.
+
+        Returns:
+            List of paths to extracted image files.
+        """
+        extracted_files = []
+        parquet_files = self._get_compressed_files()
+        if not parquet_files:
+            bt.logging.warning(f"No parquet files found in {self.compressed_dir}")
+            return extracted_files
+
+        for parquet_file in parquet_files:
+            try:
+                extracted_files += extract_images_from_parquet(
+                    parquet_file,
+                    self.cache_dir,
+                    self.num_samples_per_source
+                )
+            except Exception as e:
+                bt.logging.error(f"Error processing parquet file {parquet_file}: {e}")
+        return extracted_files
+
+    def sample(self) -> Optional[Dict[str, Any]]:
+        """
+        Sample a random image and its metadata from the cache.
+
+        Returns:
+            Dictionary containing:
+            - image: PIL Image
+            - path: Path to source file
+            - dataset: Source dataset name
+            - index: Original index of the image, if recorded in its metadata
+            Returns None if no valid image is available.
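+
+        Example (hypothetical cache directory; returns None until images
+        have been extracted into the cache):
+            >>> cache = ImageCache('/tmp/image_cache', run_updater=False)
+            >>> item = cache.sample()
+            >>> item['path'] if item else None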
+ + def sample(self) -> Optional[Dict[str, Any]]: + """ + Sample a random image and its metadata from the cache. + + Returns: + Dictionary containing: + - image: PIL Image + - path: Path to source file + - dataset: Source dataset name + - index: Original dataset index from the metadata file + Returns None if no valid image is available. + """ + cached_files = self._get_cached_files() + if not cached_files: + bt.logging.warning("No images available in cache") + return None + + attempts = 0 + max_attempts = len(cached_files) * 2 + + while attempts < max_attempts: + attempts += 1 + image_path = random.choice(cached_files) + + try: + image = Image.open(image_path) + metadata = json.loads(image_path.with_suffix('.json').read_text()) + return { + 'image': image, + 'path': str(image_path), + 'dataset': metadata.get('dataset', None), + 'index': metadata.get('index', None) + } + + except Exception as e: + bt.logging.warning(f"Failed to load image {image_path}: {e}") + continue + + bt.logging.warning(f"Failed to find valid image after {attempts} attempts") + return None diff --git a/bitmind/validator/cache/util.py b/bitmind/validator/cache/util.py new file mode 100644 index 00000000..d429db48 --- /dev/null +++ b/bitmind/validator/cache/util.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Union, Callable +from zipfile import ZipFile +from enum import Enum, auto +import pyarrow.parquet as pq +import bittensor as bt + + +def seconds_to_str(seconds): + seconds = int(float(seconds)) + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + seconds = seconds % 60 + return f"{hours:02}:{minutes:02}:{seconds:02}" + + +def get_most_recent_update_time(directory: Path) -> float: + """Get the most recent modification time of any file in directory.""" + try: + mtimes = [f.stat().st_mtime for f in directory.iterdir()] + return max(mtimes) if mtimes else 0 + except Exception as e: + bt.logging.error(f"Error getting modification times: {e}") + return 0 + + +class FileType(Enum): + PARQUET = auto() + ZIP = auto() + + +def get_integrity_check(file_type: FileType) -> Callable[[Path], bool]: + """Returns the appropriate validation function for the file type.""" + if file_type == FileType.PARQUET: + return is_parquet_complete + elif file_type == FileType.ZIP: + return is_zip_complete + raise ValueError(f"Unsupported file type: {file_type}") + + +def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool: + """ + Args: + zip_path: Path to zip file + testzip: More thorough, less efficient + Returns: + bool: True if zip is valid, False otherwise + """ + try: + with ZipFile(zip_path) as zf: + if testzip: + zf.testzip() + else: + zf.namelist() + return True + except Exception as e: + bt.logging.error(f"Zip file {zip_path} is invalid: {e}") + return False + + +def is_parquet_complete(path: Path) -> bool: + """ + Args: + path: Path to the parquet file + + Returns: + bool: True if file is valid, False otherwise + """ + try: + with open(path, 'rb') as f: + pq.read_metadata(f) + return True + except Exception as e: + bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}") + return False +
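The helpers above gate cache ingestion; a short sketch of the intended call pattern, assuming a freshly downloaded archive at a hypothetical path:

from pathlib import Path
from bitmind.validator.cache.util import FileType, get_integrity_check

archive = Path('openvid_part_00.zip')  # hypothetical download
check = get_integrity_check(FileType.ZIP)
if not check(archive):
    archive.unlink()  # mirrors _clear_incomplete_sources: drop partial downloads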
diff --git a/bitmind/validator/cache/video_cache.py b/bitmind/validator/cache/video_cache.py new file mode 100644 index 00000000..affc3bac --- /dev/null +++ b/bitmind/validator/cache/video_cache.py @@ -0,0 +1,166 @@ +import random +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Union + +import bittensor as bt +import ffmpeg +from PIL import Image + +from .base_cache import BaseCache +from .extract import extract_videos_from_zip +from .util import is_zip_complete +from bitmind.validator.video_utils import get_video_duration + + +class VideoCache(BaseCache): + """ + A class to manage video caching and processing operations. + + This class handles the caching, updating, and sampling of video files from + compressed archives. It maintains both a compressed cache of source files + and an extracted cache of video files ready for processing. + """ + + def __init__( + self, + cache_dir: Union[str, Path], + run_updater: bool = False, + datasets: Optional[dict] = None, + video_update_interval: int = 2, + zip_update_interval: int = 24, + num_videos_per_source: int = 10 + ) -> None: + """ + Initialize the VideoCache. + + Args: + cache_dir: Path to store extracted video files + video_update_interval: Hours between video cache updates + zip_update_interval: Hours between zip cache updates + num_videos_per_source: Number of videos to extract per source + """ + super().__init__( + cache_dir=cache_dir, + datasets=datasets, + extracted_update_interval=video_update_interval, + compressed_update_interval=zip_update_interval, + num_samples_per_source=num_videos_per_source, + file_extensions=['.mp4', '.avi', '.mov', '.mkv'], + compressed_file_extension='.zip', + run_updater=run_updater + ) + + def _clear_incomplete_sources(self) -> None: + """Remove any incomplete or corrupted zip files from cache.""" + for path in self._get_compressed_files(): + if path.suffix == '.zip' and not is_zip_complete(path): + try: + path.unlink() + bt.logging.warning(f"Removed incomplete zip file {path}") + except Exception as e: + bt.logging.error(f"Error removing incomplete zip {path}: {e}") + + def _extract_random_items(self) -> List[Path]: + """ + Extract random videos from zip files in compressed directory. + + Returns: + List of paths to extracted video files. + """ + extracted_files = [] + zip_files = self._get_compressed_files() + if not zip_files: + bt.logging.warning(f"No zip files found in {self.compressed_dir}") + return extracted_files + + for zip_file in zip_files: + try: + extracted_files += extract_videos_from_zip( + zip_file, + self.cache_dir, + self.num_samples_per_source) + except Exception as e: + bt.logging.error(f"Error processing zip file {zip_file}: {e}") + + return extracted_files
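A sketch of driving the sample() method defined just below (path and clip bounds are illustrative; the validator draws its clip length from config):

import random
from pathlib import Path
from bitmind.validator.cache.video_cache import VideoCache

cache = VideoCache(Path.home() / '.cache' / 'sn34' / 'real' / 'video', run_updater=False)
item = cache.sample(num_seconds=random.randint(3, 6))
if item is not None:
    frames = item['video']  # one PIL.Image per sampled second
    print(len(frames), item['total_duration'], item['dataset'])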
+ + def sample( + self, + num_seconds: int = 6 + ) -> Optional[Dict[str, Union[List[Image.Image], str, float]]]: + """ + Sample random frames from a random video in the cache. + + Args: + num_seconds: Number of seconds to sample (one frame is extracted per second) + + Returns: + Dictionary containing: + - video: List of sampled video frames as PIL Images + - path: Path to source video file + - dataset: Name of source dataset + - total_duration: Total video duration in seconds + - sampled_length: Number of seconds sampled + Returns None if no videos are available or extraction fails. + """ + video_files = self._get_cached_files() + if not video_files: + bt.logging.warning("No videos available in cache") + return None + + video_path = random.choice(video_files) + if not Path(video_path).exists(): + bt.logging.error(f"Selected video {video_path} not found") + return None + + duration = get_video_duration(str(video_path)) + start_time = random.uniform(0, max(0, duration - num_seconds)) + frames: List[Image.Image] = [] + + bt.logging.info(f'Extracting frames starting at {start_time:.2f}s') + + for second in range(num_seconds): + timestamp = start_time + second + + try: + # extract frames + out_bytes, err = ( + ffmpeg + .input(str(video_path), ss=str(timestamp)) + .filter('select', 'eq(n,0)') + .output('pipe:', + vframes=1, + format='image2', + vcodec='mjpeg', + loglevel='error', # silence ffmpeg output + **{'qscale:v': 2} # Better quality JPEG + ) + .run(capture_stdout=True, capture_stderr=True) + ) + + if not out_bytes: + bt.logging.error(f'No data received for frame at {timestamp}s') + continue + + try: + frame = Image.open(BytesIO(out_bytes)) + frame.load() # Verify image can be loaded + frames.append(frame) + bt.logging.debug(f'Successfully extracted frame at {timestamp}s') + except Exception as e: + bt.logging.error(f'Failed to process frame at {timestamp}s: {e}') + continue + + except ffmpeg.Error as e: + bt.logging.error(f'FFmpeg error at {timestamp}s: {e.stderr.decode()}') + continue + + bt.logging.success(f"Sampled {num_seconds}s of video") + return { + 'video': frames, + 'path': str(video_path), + 'dataset': str(Path(video_path).name.split('_')[0]), + 'total_duration': duration, + 'sampled_length': num_seconds + } diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py new file mode 100644 index 00000000..6112c3a2 --- /dev/null +++ b/bitmind/validator/config.py @@ -0,0 +1,226 @@ +from pathlib import Path +from typing import Dict, List, Union, Optional, Any + +import numpy as np +import torch +from diffusers import ( + StableDiffusionPipeline, + StableDiffusionXLPipeline, + FluxPipeline, + CogVideoXPipeline, + MochiPipeline, + AnimateDiffPipeline, + EulerDiscreteScheduler +) + +from .model_utils import load_annimatediff_motion_adapter + + +TARGET_IMAGE_SIZE: tuple[int, int] = (256, 256) + +MAINNET_UID = 34 +TESTNET_UID = 168 + +# Project constants +MAINNET_WANDB_PROJECT: str = 'bitmind-subnet' +TESTNET_WANDB_PROJECT: str = 'bitmind' +WANDB_ENTITY: str = 'bitmindai' + +# Cache directories +HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface' +SN34_CACHE_DIR: Path = Path.home() / '.cache' / 'sn34' +REAL_CACHE_DIR: Path = SN34_CACHE_DIR / 'real' +SYNTH_CACHE_DIR: Path = SN34_CACHE_DIR / 'synthetic' +REAL_VIDEO_CACHE_DIR: Path = REAL_CACHE_DIR / 'video' +REAL_IMAGE_CACHE_DIR: Path = REAL_CACHE_DIR / 'image' +SYNTH_VIDEO_CACHE_DIR: Path = SYNTH_CACHE_DIR / 'video' +SYNTH_IMAGE_CACHE_DIR: Path = SYNTH_CACHE_DIR / 'image' +VALIDATOR_INFO_PATH: Path = SN34_CACHE_DIR / 'validator.yaml' +SN34_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + +CHALLENGE_TYPE = { + 0: 'real', + 1: 'synthetic' +} + +# Image datasets configuration +IMAGE_DATASETS: Dict[str, List[Dict[str, str]]] = { + "real": [ + {"path": "bitmind/bm-real"}, + {"path": "bitmind/open-images-v7"}, + {"path": "bitmind/celeb-a-hq"}, + {"path": "bitmind/ffhq-256"}, + {"path": "bitmind/MS-COCO-unique-256"}, + {"path": "bitmind/AFHQ"}, + {"path": "bitmind/lfw"}, + {"path": "bitmind/caltech-256"}, + {"path": 
"bitmind/caltech-101"}, + {"path": "bitmind/dtd"} + ] +} + +VIDEO_DATASETS = { + "real": [ + { + "path": "nkp37/OpenVid-1M", + "filetype": "zip" + }, + { + "path": "shangxd/imagenet-vidvrd", + "filetype": "zip" + } + ] +} + +# Prompt generation model configurations +IMAGE_ANNOTATION_MODEL: str = "Salesforce/blip2-opt-6.7b-coco" +TEXT_MODERATION_MODEL: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" + +# Text-to-image model configurations +T2I_MODELS: Dict[str, Dict[str, Any]] = { + "stabilityai/stable-diffusion-xl-base-1.0": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + "variant": "fp16" + } + }, + "SG161222/RealVisXL_V4.0": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + "variant": "fp16" + } + }, + "Corcelio/mobius": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16 + } + }, + "black-forest-labs/FLUX.1-dev": { + "pipeline_cls": FluxPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.bfloat16, + }, + "generate_args": { + "guidance_scale": 2, + "num_inference_steps": {"min": 50, "max": 125}, + "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"), + "height": [512, 768], + "width": [512, 768] + }, + "enable_model_cpu_offload": False + }, + "prompthero/openjourney-v4" : { + "pipeline_cls": StableDiffusionPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + } + }, + "cagliostrolab/animagine-xl-3.1": { + "pipeline_cls": StableDiffusionXLPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.float16, + } + } +} +T2I_MODEL_NAMES: List[str] = list(T2I_MODELS.keys()) + + +# Text-to-video model configurations +T2V_MODELS: Dict[str, Dict[str, Any]] = { + "genmo/mochi-1-preview": { + "pipeline_cls": MochiPipeline, + "from_pretrained_args": { + "variant": "bf16", + "torch_dtype": torch.bfloat16 + }, + "generate_args": { + "num_frames": 84 + }, + #"enable_model_cpu_offload": True, + "vae_enable_tiling": True + }, + 'THUDM/CogVideoX-5b': { + "pipeline_cls": CogVideoXPipeline, + "from_pretrained_args": { + "use_safetensors": True, + "torch_dtype": torch.bfloat16 + }, + "generate_args": { + "guidance_scale": 2, + "num_videos_per_prompt": 1, + "num_inference_steps": {"min": 50, "max": 125}, + "num_frames": 48, + }, + "enable_model_cpu_offload": True, + #"enable_sequential_cpu_offload": True, + "vae_enable_slicing": True, + "vae_enable_tiling": True + }, + 'ByteDance/AnimateDiff-Lightning': { + "pipeline_cls": AnimateDiffPipeline, + "from_pretrained_args": { + "base": "emilianJR/epiCRealism", + "torch_dtype": torch.bfloat16, + "motion_adapter": load_annimatediff_motion_adapter() + }, + "generate_args": { + "guidance_scale": 2, + "num_inference_steps": {"min": 50, "max": 125}, + }, + "scheduler": { + "cls": EulerDiscreteScheduler, + "from_config_args": { + "timestep_spacing": "trailing", + "beta_schedule": "linear" + } + } + } +} +T2V_MODEL_NAMES: List[str] = list(T2V_MODELS.keys()) + +# Combined model configurations +T2VIS_MODELS: Dict[str, Dict[str, Any]] = {**T2I_MODELS, **T2V_MODELS} +T2VIS_MODEL_NAMES: List[str] = list(T2VIS_MODELS.keys()) + + +def get_modality(model_name): + if model_name in T2V_MODEL_NAMES: + return 'video' + elif model_name in T2I_MODEL_NAMES: + return 'image' + + +def 
select_random_t2vis_model(modality: Optional[str] = None) -> str: + """ + Select a random text-to-image or text-to-video model based on the specified + modality. + + Args: + modality: The type of model to select ('image', 'video', or 'random'). + If None or 'random', randomly chooses between image and video. + + Returns: + The name of the selected model. + + Raises: + NotImplementedError: If the specified modality is not supported. + """ + if modality is None or modality == 'random': + modality = np.random.choice(['image', 'video']) + + if modality == 'image': + return np.random.choice(T2I_MODEL_NAMES) + elif modality == 'video': + return np.random.choice(T2V_MODEL_NAMES) + else: + raise NotImplementedError(f"Unsupported modality: {modality}") diff --git a/bitmind/validator/forward.py b/bitmind/validator/forward.py index 0a3fa984..a1118796 100644 --- a/bitmind/validator/forward.py +++ b/bitmind/validator/forward.py @@ -16,39 +16,19 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. - -from PIL import Image -from io import BytesIO -from datetime import datetime import bittensor as bt import pandas as pd import numpy as np -import os +import random import wandb +import time + +from bitmind.validator.config import CHALLENGE_TYPE, MAINNET_UID from bitmind.utils.uids import get_random_uids -from bitmind.utils.data import sample_dataset_index_name -from bitmind.protocol import prepare_image_synapse +from bitmind.protocol import prepare_synapse from bitmind.validator.reward import get_rewards -from bitmind.image_transforms import apply_augmentation_by_level - - -def sample_random_real_image(datasets, total_images, retries=10): - random_idx = np.random.randint(0, total_images) - source, idx = sample_real_image(datasets, random_idx) - if source[idx]['image'] is None: - if retries: - return sample_random_real_image(datasets, total_images, retries-1) - return None, None - return source, idx - - -def sample_real_image(datasets, index): - cumulative_sizes = np.cumsum([len(ds) for ds in datasets]) - source_index = np.searchsorted(cumulative_sizes - 1, index % (cumulative_sizes[-1])) - source = datasets[source_index] - valid_index = index - (cumulative_sizes[source_index - 1] if source_index > 0 else 0) - return source, valid_index +from bitmind.utils.image_transforms import apply_augmentation_by_level async def forward(self): @@ -58,115 +38,102 @@ async def forward(self): Steps are: 1. Sample miner UIDs - 2. Get an image. 50/50 chance of: - A. REAL (label = 0): Randomly sample a real image from self.real_image_datasets - B. FAKE (label = 1): Generate a synthetic image with self.random_image_generator + 2. Sample synthetic/real image/video (50/50 chance for each choice) 3. Apply random data augmentation to the image - 4. Base64 encode the image and prepare an ImageSynapse + 4. Encode data and prepare Synapse 5. Query miner axons - 6. Log results, including image and miner responses (soon to be W&B) - 7. Compute rewards and update scores + 6. Compute rewards and update scores Args: self (:obj:`bittensor.neuron.Neuron`): The neuron object which contains all the necessary state for the validator. 
""" - wandb_data = {} - - miner_uids = get_random_uids(self, k=self.config.neuron.sample_size) - bt.logging.info("Generating challenge") - if np.random.rand() > self._fake_prob: - label = 0 - source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images) - wandb_data['source_dataset'] = source_dataset.huggingface_dataset_name - wandb_data['source_image_index'] = local_index - sample = source_dataset[local_index] - - else: - label = 1 - if self.config.neuron.prompt_type == 'annotation': - retries = 10 - while retries > 0: - retries -= 1 - source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images) - source_sample = source_dataset[local_index] - source_image = source_sample['image'] - if source_image is None: - continue - - # generate captions for the real images, then synthetic images from these captions - sample = self.synthetic_image_generator.generate( - k=1, real_images=[source_sample])[0] # {'prompt': str, 'image': PIL Image ,'id': int} - - wandb_data['model'] = self.synthetic_image_generator.diffuser_name - wandb_data['source_dataset'] = source_dataset.huggingface_dataset_name - wandb_data['source_image_index'] = local_index - wandb_data['image'] = wandb.Image(sample['image']) - wandb_data['prompt'] = sample['prompt'] - if not np.any(np.isnan(sample['image'])): - break - else: - raise NotImplementedError(f'unsupported neuron.prompt_type: {self.config.neuron.prompt_type}') - - image = sample['image'] - image, level, data_aug_params = apply_augmentation_by_level(image) - - bt.logging.info(f"Querying {len(miner_uids)} miners...") + challenge_metadata = {} # for bookkeeping + challenge = {} # for querying miners + + modality = 'video' if np.random.rand() > 0.5 else 'image' + label = 0 if np.random.rand() > self._fake_prob else 1 + challenge_metadata['label'] = label + challenge_metadata['modality'] = modality + + bt.logging.info(f"Sampling data from {modality} cache") + cache = self.media_cache[CHALLENGE_TYPE[label]][modality] + + if modality == 'video': + clip_length = random.randint( + self.config.neuron.clip_length_min, + self.config.neuron.clip_length_max) + challenge = cache.sample(clip_length) + challenge_metadata['clip_length_s'] = clip_length + #np_video = np.stack([np.array(img) for img in gen_output], axis=0) + #challenge_data['video'] = wandb.Video(np_video) # TODO format video for w&b + + elif modality == 'image': + challenge = cache.sample() + #challenge_data['image'] = wandb.Image(challenge['image']) + + if challenge is None: + bt.logging.warning("Waiting for cache to populate. 
Challenge skipped.") + return + + # update logging dict with everything except image/video data + challenge_metadata.update({k: v for k, v in challenge.items() if k != modality}) + input_data = challenge[modality] # extract video or image + + # apply data augmentation pipeline + try: + input_data, level, data_aug_params = apply_augmentation_by_level(input_data) + except Exception as e: + level, data_aug_params = -1, {} + bt.logging.error(f"Unable to apply augmentations: {e}") + challenge_metadata['data_aug_params'] = data_aug_params + challenge_metadata['data_aug_level'] = level + + # sample miner uids for challenge + miner_uids = get_random_uids(self, k=self.metagraph.n) # self.config.neuron.sample_size) + axons = [self.metagraph.axons[uid] for uid in miner_uids] + challenge_metadata['miner_uids'] = list(miner_uids) + challenge_metadata['miner_hotkeys'] = [axon.hotkey for axon in axons] + + # prepare synapse + synapse = prepare_synapse(input_data, modality=modality) + if self.metagraph.netuid != MAINNET_UID: + synapse.testnet_label = label + + bt.logging.info(f"Sending {modality} challenge to {len(miner_uids)} miners") + start = time.time() responses = await self.dendrite( axons=axons, - synapse=prepare_image_synapse(image=image), + synapse=synapse, deserialize=True, timeout=9 ) + bt.logging.info(f"Responses received in {time.time() - start:.2f}s") + bt.logging.success(f"{CHALLENGE_TYPE[label]} {modality} challenge complete!") + bt.logging.info({k: v for k, v in challenge_metadata.items() if k not in ('miner_uids', 'miner_hotkeys')}) + bt.logging.info("Scoring responses") rewards, metrics = get_rewards( label=label, responses=responses, uids=miner_uids, axons=axons, performance_tracker=self.performance_tracker) - - # Logging image source (model for synthetic, dataset for real) and verification details - source_name = wandb_data['model'] if 'model' in wandb_data else wandb_data['source_dataset'] - bt.logging.info(f'{"real" if label == 0 else "fake"} image | source: {source_name}: {sample["id"]}') - - # Logging responses and rewards - bt.logging.info(f"Received responses: {responses}") - bt.logging.info(f"Scored responses: {rewards}") - - # Update the scores based on the rewards. 
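prepare_synapse is imported above but its body is not part of this diff; presumably it dispatches on modality roughly as sketched here (prepare_image_synapse exists elsewhere in the codebase; prepare_video_synapse is a hypothetical name):

def prepare_synapse(input_data, modality):
    if modality == 'image':
        return prepare_image_synapse(image=input_data)   # existing helper
    if modality == 'video':
        return prepare_video_synapse(frames=input_data)  # hypothetical name
    raise ValueError(f"Unsupported modality: {modality}")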
+ self.update_scores(rewards, miner_uids) - # update logging data - wandb_data['data_aug_params'] = data_aug_params - wandb_data['label'] = label - wandb_data['miner_uids'] = list(miner_uids) - wandb_data['miner_hotkeys'] = list([axon.hotkey for axon in axons]) - wandb_data['predictions'] = responses - wandb_data['data_aug_level'] = level - wandb_data['correct'] = [ - np.round(y_hat) == y - for y_hat, y in zip(responses, [label] * len(responses)) - ] - wandb_data['rewards'] = list(rewards) - wandb_data['scores'] = list(self.scores) - - metric_names = list(metrics[0].keys()) - for metric_name in metric_names: - wandb_data[f'miner_{metric_name}'] = [m[metric_name] for m in metrics] + for metric_name in list(metrics[0].keys()): + challenge_metadata[f'miner_{metric_name}'] = [m[metric_name] for m in metrics] + challenge_metadata['predictions'] = responses + challenge_metadata['rewards'] = rewards + challenge_metadata['scores'] = list(self.scores) + + for uid, pred, reward in zip(miner_uids, responses, rewards): + bt.logging.success(f"UID: {uid} | Prediction: {pred} | Reward: {reward}") # W&B logging if enabled if not self.config.wandb.off: - wandb.log(wandb_data) + wandb.log(challenge_metadata) # ensure state is saved after each challenge self.save_miner_history() - - # Track miners who have responded - self.last_responding_miner_uids = [] - for i, pred in enumerate(responses): - # Logging specific prediction details - if pred != -1: - bt.logging.info(f'Miner uid: {miner_uids[i]} | prediction: {pred} | correct: {np.round(pred) == label} | reward: {rewards[i]}') - self.last_responding_miner_uids.append(miner_uids[i])
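The model_utils.py file added below supplies the motion adapter consumed by the AnimateDiff-Lightning entry in T2V_MODELS; a sketch of how the pieces compose, with the base model and scheduler arguments taken from that config entry:

import torch
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from bitmind.validator.model_utils import load_annimatediff_motion_adapter

adapter = load_annimatediff_motion_adapter(step=4)
pipe = AnimateDiffPipeline.from_pretrained(
    'emilianJR/epiCRealism',  # the "base" checkpoint from T2V_MODELS
    motion_adapter=adapter,
    torch_dtype=torch.bfloat16
)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing='trailing', beta_schedule='linear'
)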
diff --git a/bitmind/validator/model_utils.py b/bitmind/validator/model_utils.py new file mode 100644 index 00000000..36b90ad0 --- /dev/null +++ b/bitmind/validator/model_utils.py @@ -0,0 +1,35 @@ +import torch +from diffusers import MotionAdapter +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + + +def load_annimatediff_motion_adapter( + step: int = 4 +) -> MotionAdapter: + """ + Load a motion adapter model for AnimateDiff. + + Args: + step: The step size for the motion adapter. Options: [1, 2, 4, 8]. + + Returns: + A loaded MotionAdapter model. + + Raises: + ValueError: If step is not one of [1, 2, 4, 8]. + """ + if step not in [1, 2, 4, 8]: + raise ValueError("Step must be one of [1, 2, 4, 8]") + + device = "cuda" if torch.cuda.is_available() else "cpu" + adapter = MotionAdapter().to(device, torch.float16) + + repo = "ByteDance/AnimateDiff-Lightning" + ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" + adapter.load_state_dict( + load_file( + hf_hub_download(repo, ckpt), + device=device + ) + ) + return adapter diff --git a/bitmind/validator/run_data_generator.py b/bitmind/validator/run_data_generator.py new file mode 100644 index 00000000..c19572d8 --- /dev/null +++ b/bitmind/validator/run_data_generator.py @@ -0,0 +1,114 @@ +import time +import yaml + +import wandb +import bittensor as bt + +import bitmind +from bitmind.synthetic_data_generation import SyntheticDataGenerator +from bitmind.validator.cache import ImageCache +from bitmind.validator.config import ( + REAL_IMAGE_CACHE_DIR, + SYNTH_CACHE_DIR, + WANDB_ENTITY, + TESTNET_WANDB_PROJECT, + MAINNET_WANDB_PROJECT, + MAINNET_UID, + TESTNET_UID, + VALIDATOR_INFO_PATH +) + +def load_validator_info(): + try: + with open(VALIDATOR_INFO_PATH, 'r') as f: + validator_info = yaml.safe_load(f) + bt.logging.info(f"Loaded validator info from {VALIDATOR_INFO_PATH}") + except FileNotFoundError: + bt.logging.error(f"Could not find validator info at {VALIDATOR_INFO_PATH}") + validator_info = { + 'uid': 'NotFound', + 'hotkey': 'NotFound', + 'full_path': 'NotFound', + 'netuid': TESTNET_UID + } + except yaml.YAMLError: + bt.logging.error(f"Could not parse validator info at {VALIDATOR_INFO_PATH}") + validator_info = { + 'uid': 'ParseError', + 'hotkey': 'ParseError', + 'full_path': 'ParseError', + 'netuid': TESTNET_UID + } + return validator_info + + +def init_wandb_run(uid: str, hotkey: str, netuid: int, full_path: str) -> None: + """ + Initialize a Weights & Biases run for tracking the validator. + + Args: + uid: The validator's uid + hotkey: The validator's hotkey address + netuid: The network ID (mainnet or testnet) + full_path: Validator's bittensor directory + + Returns: + None + """ + run_name = f'data-generator-{uid}-{bitmind.__version__}' + + config = { + 'run_name': run_name, + 'uid': uid, + 'hotkey': hotkey, + 'version': bitmind.__version__ + } + + wandb_project = TESTNET_WANDB_PROJECT + if netuid == MAINNET_UID: + wandb_project = MAINNET_WANDB_PROJECT + + # Initialize the wandb run for the single project + bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'") + try: + run = wandb.init( + name=run_name, + project=wandb_project, + entity=WANDB_ENTITY, + config=config, + dir=full_path, + reinit=True + ) + except wandb.UsageError as e: + bt.logging.warning(e) + bt.logging.warning("Did you run wandb login?") + return + +if __name__ == '__main__': + + init_wandb_run(**load_validator_info()) + + image_cache = ImageCache(REAL_IMAGE_CACHE_DIR, datasets=None, run_updater=False) + while True: + if image_cache._extracted_cache_empty(): + bt.logging.info("SyntheticDataGenerator waiting for real image cache to populate") + time.sleep(5) + continue + bt.logging.info("Image cache was populated! 
Proceeding to data generation") + break + + sgd = SyntheticDataGenerator( + prompt_type='annotation', + use_random_t2vis_model=True, + device='cuda', + image_cache=image_cache, + output_dir=SYNTH_CACHE_DIR) + + bt.logging.info("Starting standalone data generator service") + while True: + try: + sgd.batch_generate(batch_size=1) + time.sleep(1) + except Exception as e: + bt.logging.error(f"Error in batch generation: {str(e)}") + time.sleep(5) diff --git a/bitmind/validator/verify_models.py b/bitmind/validator/verify_models.py index 8aa8b61e..278a0ff3 100644 --- a/bitmind/validator/verify_models.py +++ b/bitmind/validator/verify_models.py @@ -1,6 +1,6 @@ import os -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.constants import DIFFUSER_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL +from bitmind.synthetic_data_generation import SyntheticDataGenerator +from bitmind.validator.config import T2VIS_MODEL_NAMES as MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL import bittensor as bt @@ -38,10 +38,9 @@ def main(): It also initializes and loads diffusers for uncached models. """ bt.logging.info("Verifying validator model downloads....") - synthetic_image_generator = SyntheticImageGenerator( + synthetic_image_generator = SyntheticDataGenerator( prompt_type='annotation', - use_random_diffuser=True, - diffuser_name=None + use_random_t2vis_model=True ) # Check and load annotation and moderation models if not cached @@ -50,14 +49,14 @@ def main(): synthetic_image_generator.image_annotation_generator.clear_gpu() # Initialize and load diffusers if not cached - for model_name in DIFFUSER_NAMES: + for model_name in MODEL_NAMES: if not is_model_cached(model_name): - synthetic_image_generator = SyntheticImageGenerator( + synthetic_image_generator = SyntheticDataGenerator( prompt_type='annotation', - use_random_diffuser=False, - diffuser_name=model_name + use_random_t2vis_model=False, + t2vis_model_name=model_name ) - synthetic_image_generator.load_diffuser(model_name) + synthetic_image_generator.load_t2vis_model(model_name) synthetic_image_generator.clear_gpu() diff --git a/bitmind/validator/video_utils.py b/bitmind/validator/video_utils.py new file mode 100644 index 00000000..3c1e7e04 --- /dev/null +++ b/bitmind/validator/video_utils.py @@ -0,0 +1,105 @@ +import tempfile +from pathlib import Path +from typing import BinaryIO, List, Union + +import bittensor as bt +import ffmpeg +import numpy as np +from moviepy.editor import VideoFileClip +from PIL import Image + +from .cache.util import seconds_to_str + + +def video_to_pil(video_path: Union[str, Path]) -> List[Image.Image]: + """Load video file and convert it to a list of PIL images. + + Args: + video_path: Path to the input video file. + + Returns: + List of PIL Image objects representing each frame of the video. + """ + clip = VideoFileClip(str(video_path)) + frames = [Image.fromarray(np.array(frame)) for frame in clip.iter_frames()] + clip.close() + return frames + + +def clip_video( + video_path: str, + start: int, + num_seconds: int +) -> BinaryIO: + """Extract a clip from a video file. + + Args: + video_path: Path to the input video file. + start: Start time in seconds. + num_seconds: Duration of the clip in seconds. + + Returns: + A temporary file object containing the clipped video. + + Raises: + ffmpeg.Error: If FFmpeg encounters an error during processing. 
+ """ + temp_fileobj = tempfile.NamedTemporaryFile(suffix=".mp4") + try: + ( + ffmpeg + .input(video_path, ss=seconds_to_str(start), t=str(num_seconds)) + .output(temp_fileobj.name, vf='fps=1') + .overwrite_output() + .run(capture_stderr=True) + ) + return temp_fileobj + except ffmpeg.Error as e: + bt.logging.error(f"FFmpeg error: {e.stderr.decode()}") + raise + + +def get_video_duration(filename: str) -> int: + """Get the duration of a video file in seconds. + + Args: + filename: Path to the video file. + + Returns: + Duration of the video in seconds. + + Raises: + KeyError: If video stream information cannot be found. + """ + metadata = ffmpeg.probe(filename) + video_stream = next( + (stream for stream in metadata['streams'] + if stream['codec_type'] == 'video'), + None + ) + if not video_stream: + raise KeyError("No video stream found in the file") + return int(float(video_stream['duration'])) + + +def copy_audio(video_path: str) -> BinaryIO: + """Extract the audio stream from a video file. + + Args: + video_path: Path to the input video file. + + Returns: + A temporary file object containing the extracted audio stream. + + Raises: + ffmpeg.Error: If FFmpeg encounters an error during processing. + """ + temp_audiofile = tempfile.NamedTemporaryFile(suffix=".aac") + ( + ffmpeg + .input(video_path) + .output(temp_audiofile.name, vn=None, acodec='copy') + .overwrite_output() + .run(quiet=True) + ) + return temp_audiofile \ No newline at end of file diff --git a/neurons/miner.py b/neurons/miner.py index 25c1fa36..794e3955 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -30,7 +30,7 @@ from base_miner import DETECTOR_REGISTRY from bitmind.base.miner import BaseMinerNeuron -from bitmind.protocol import ImageSynapse +from bitmind.protocol import ImageSynapse, VideoSynapse, decode_video_synapse from bitmind.utils.config import get_device @@ -38,145 +38,122 @@ class Miner(BaseMinerNeuron): def __init__(self, config=None): super(Miner, self).__init__(config=config) - if self.config.neuron.device == 'auto': - self.config.neuron.device = get_device() - self.load_detector() - - def load_detector(self): - self.deepfake_detector = DETECTOR_REGISTRY[self.config.neuron.detector]( - config=self.config.neuron.detector_config, - device=self.config.neuron.device + bt.logging.info("Attaching forward function to miner axon.") + self.axon.attach( + forward_fn=self.forward_image, + blacklist_fn=self.blacklist_image, + priority_fn=self.priority_image, + ).attach( + forward_fn=self.forward_video, + blacklist_fn=self.blacklist_video, + priority_fn=self.priority_video, ) - - async def forward( + bt.logging.info(f"Axon created: {self.axon}") + + bt.logging.info("Loading image detection model if configured") + self.load_image_detector() + bt.logging.info("Loading video detection model if configured") + self.load_video_detector() + + def load_image_detector(self): + if (str(self.config.neuron.image_detector).lower() == 'none' or + str(self.config.neuron.image_detector_config).lower() == 'none'): + bt.logging.warning("No image detector configuration provided, skipping.") + self.image_detector = None + return + + if self.config.neuron.image_detector_device == 'auto': + bt.logging.warning("Automatic device configuration enabled for image detector") + self.config.neuron.image_detector_device = get_device() + + self.image_detector = DETECTOR_REGISTRY[self.config.neuron.image_detector]( + config=self.config.neuron.image_detector_config, + device=self.config.neuron.image_detector_device + ) + bt.logging.info(f"Loaded 
image detection model: {self.config.neuron.image_detector}") + + def load_video_detector(self): + if (str(self.config.neuron.video_detector).lower() == 'none' or + str(self.config.neuron.video_detector_config).lower() == 'none'): + bt.logging.warning("No video detector configuration provided, skipping.") + self.video_detector = None + return + + if self.config.neuron.video_detector_device == 'auto': + bt.logging.warning("Automatic device configuration enabled for video detector") + self.config.neuron.video_detector_device = get_device() + + self.video_detector = DETECTOR_REGISTRY[self.config.neuron.video_detector]( + config=self.config.neuron.video_detector_config, + device=self.config.neuron.video_detector_device + ) + bt.logging.info(f"Loaded video detection model: {self.config.neuron.video_detector}") + + async def forward_image( self, synapse: ImageSynapse ) -> ImageSynapse: """ - Loads the deepfake detection model (a PyTorch binary classifier) from the path specified in --neuron.model_path. - Processes the incoming ImageSynapse and passes the image to the loaded model for classification. - The model is loaded here, rather than in __init__, so that miners may (backup) and overwrite - their model file as a means of updating their miner's predictor. + Perform inference on image Args: - synapse (ImageSynapse): The synapse object containing the list of b64 encoded images in the + synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the 'images' field. Returns: - ImageSynapse: The synapse object with the 'predictions' field populated with a list of probabilities + bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities """ - try: - image_bytes = base64.b64decode(synapse.image) - image = Image.open(io.BytesIO(image_bytes)) - - pred = self.deepfake_detector(image) - - synapse.prediction = pred - - except Exception as e: - bt.logging.error("Error performing inference") - bt.logging.error(e) - - bt.logging.info(f"PREDICTION: {synapse.prediction}") + if self.image_detector is None: + bt.logging.info("Image detection model not configured; skipping image challenge") + else: + bt.logging.info("Received image challenge!") + try: + image_bytes = base64.b64decode(synapse.image) + image = Image.open(io.BytesIO(image_bytes)) + synapse.prediction = self.image_detector(image) + except Exception as e: + bt.logging.error("Error performing inference") + bt.logging.error(e) + bt.logging.info(f"PREDICTION: {synapse.prediction}") return synapse - async def blacklist( - self, synapse: ImageSynapse - ) -> typing.Tuple[bool, str]: + async def forward_video( + self, synapse: VideoSynapse + ) -> VideoSynapse: """ - Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should - define the logic for blacklisting requests based on your needs and desired security parameters. - - Blacklist runs before the synapse data has been deserialized (i.e. before synapse.data is available). - The synapse is instead contructed via the headers of the request. It is important to blacklist - requests before they are deserialized to avoid wasting resources on requests that will be ignored. - + Perform inference on video Args: - synapse (ImageSynapse): A synapse object constructed from the headers of the incoming request. + synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the + 'images' field. 
Returns: - Tuple[bool, str]: A tuple containing a boolean indicating whether the synapse's hotkey is blacklisted, - and a string providing the reason for the decision. - - This function is a security measure to prevent resource wastage on undesired requests. It should be enhanced - to include checks against the metagraph for entity registration, validator status, and sufficient stake - before deserialization of synapse data to minimize processing overhead. + bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities - Example blacklist logic: - - Reject if the hotkey is not a registered entity within the metagraph. - - Consider blacklisting entities that are not validators or have insufficient stake. - - In practice it would be wise to blacklist requests from entities that are not validators, or do not have - enough stake. This can be checked via metagraph.S and metagraph.validator_permit. You can always attain - the uid of the sender via a metagraph.hotkeys.index( synapse.dendrite.hotkey ) call. - - Otherwise, allow the request to be processed further. - """ - if synapse.dendrite is None or synapse.dendrite.hotkey is None: - bt.logging.warning("Received a request without a dendrite or hotkey.") - return True, "Missing dendrite or hotkey" - - # TODO(developer): Define how miners should blacklist requests. - uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey) - if ( - not self.config.blacklist.allow_non_registered - and synapse.dendrite.hotkey not in self.metagraph.hotkeys - ): - # Ignore requests from un-registered entities. - bt.logging.trace( - f"Blacklisting un-registered hotkey {synapse.dendrite.hotkey}" - ) - return True, "Unrecognized hotkey" - - if self.config.blacklist.force_validator_permit: - # If the config is set to force validator permit, then we should only allow requests from validators. - if not self.metagraph.validator_permit[uid]: - bt.logging.warning( - f"Blacklisting a request from non-validator hotkey {synapse.dendrite.hotkey}" - ) - return True, "Non-validator hotkey" - - bt.logging.trace( - f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}" - ) - return False, "Hotkey recognized!" - - async def priority(self, synapse: ImageSynapse) -> float: """ - The priority function determines the order in which requests are handled. More valuable or higher-priority - requests are processed before others. You should design your own priority mechanism with care. - - This implementation assigns priority to incoming requests based on the calling entity's stake in the metagraph. + if self.video_detector is None: + bt.logging.info("Video detection model not configured; skipping video challenge") + else: + bt.logging.info("Received video challenge!") + try: + frames_tensor = decode_video_synapse(synapse) + synapse.prediction = self.video_detector(frames_tensor) + except Exception as e: + bt.logging.error("Error performing inference") + bt.logging.error(e) + bt.logging.info(f"PREDICTION: {synapse.prediction}") + return synapse - Args: - synapse (ImageSynapse): The synapse object that contains metadata about the incoming request. + async def blacklist_image(self, synapse: ImageSynapse) -> typing.Tuple[bool, str]: + return await self.blacklist(synapse) - Returns: - float: A priority score derived from the stake of the calling entity. + async def blacklist_video(self, synapse: VideoSynapse) -> typing.Tuple[bool, str]: + return await self.blacklist(synapse) - Miners may recieve messages from multiple entities at once. 
This function determines which request should be - processed first. Higher values indicate that the request should be processed first. Lower values indicate - that the request should be processed later. + async def priority_image(self, synapse: ImageSynapse) -> float: + return await self.priority(synapse) - Example priority logic: - - A higher stake results in a higher priority value. - """ - if synapse.dendrite is None or synapse.dendrite.hotkey is None: - bt.logging.warning("Received a request without a dendrite or hotkey.") - return 0.0 - - # TODO(developer): Define how miners should prioritize requests. - caller_uid = self.metagraph.hotkeys.index( - synapse.dendrite.hotkey - ) # Get the caller index. - - prirority = float( - self.metagraph.S[caller_uid] - ) # Return the stake as the priority. - bt.logging.trace( - f"Prioritizing {synapse.dendrite.hotkey} with value: ", prirority - ) - return prirority + async def priority_video(self, synapse: VideoSynapse) -> float: + return await self.priority(synapse) def save_state(self): pass diff --git a/neurons/validator.py b/neurons/validator.py index 15c38fbd..8b9a4c2c 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -18,15 +18,28 @@ # DEALINGS IN THE SOFTWARE. import bittensor as bt +import yaml import wandb import time from neurons.validator_proxy import ValidatorProxy -from bitmind.validator import forward +from bitmind.validator.forward import forward +from bitmind.validator.cache import VideoCache, ImageCache from bitmind.base.validator import BaseValidatorNeuron -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.image_dataset import ImageDataset -from bitmind.constants import VALIDATOR_DATASET_META, WANDB_PROJECT, WANDB_ENTITY +from bitmind.validator.config import ( + MAINNET_UID, + MAINNET_WANDB_PROJECT, + TESTNET_WANDB_PROJECT, + IMAGE_DATASETS, + VIDEO_DATASETS, + WANDB_ENTITY, + REAL_VIDEO_CACHE_DIR, + REAL_IMAGE_CACHE_DIR, + SYNTH_IMAGE_CACHE_DIR, + SYNTH_VIDEO_CACHE_DIR, + VALIDATOR_INFO_PATH +) + import bitmind @@ -50,25 +63,26 @@ def __init__(self, config=None): self.last_responding_miner_uids = [] self.validator_proxy = ValidatorProxy(self) - - bt.logging.info("init_wandb()") - self.init_wandb() - - bt.logging.info("Loading real datasets") - self.real_image_datasets = [ - ImageDataset(ds['path'], 'train', ds.get('name', None)) - for ds in VALIDATOR_DATASET_META['real'] - ] - self.total_real_images = sum([ - len(ds) for ds in self.real_image_datasets - ]) - - self.synthetic_image_generator = SyntheticImageGenerator( - prompt_type='annotation', - use_random_diffuser=True, - diffuser_name=None, - device=self.config.neuron.device) + # real media caches run async update tasks to download and unpack subsets of datasets + self.real_media_cache = { + 'image': ImageCache(REAL_IMAGE_CACHE_DIR, run_updater=True, datasets=IMAGE_DATASETS['real']), + 'video': VideoCache(REAL_VIDEO_CACHE_DIR, run_updater=True, datasets=VIDEO_DATASETS['real']) + } + + # synthetic media caches are populated by the SyntheticDataGenerator process (started by start_validator.sh) + self.synthetic_media_cache = { + 'image': ImageCache(SYNTH_IMAGE_CACHE_DIR, run_updater=False), + 'video': VideoCache(SYNTH_VIDEO_CACHE_DIR, run_updater=False) + } + + self.media_cache = { + 'real': self.real_media_cache, + 'synthetic': self.synthetic_media_cache, + } + + self.init_wandb() + self.store_vali_info() self._fake_prob = self.config.get('fake_prob', 0.5) async def forward(self): @@ -93,12 +107,16 @@ 
def init_wandb(self): self.config.version = bitmind.__version__ self.config.type = self.neuron_type + wandb_project = TESTNET_WANDB_PROJECT + if self.config.netuid == MAINNET_UID: + wandb_project = MAINNET_WANDB_PROJECT + # Initialize the wandb run for the single project - print("Initializing W&B") + bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'") try: run = wandb.init( name=run_name, - project=WANDB_PROJECT, + project=wandb_project, entity=WANDB_ENTITY, config=self.config, dir=self.config.full_path, @@ -114,7 +132,23 @@ def init_wandb(self): self.config.signature = signature wandb.config.update(self.config, allow_val_change=True) - bt.logging.success(f"Started wandb run for project '{WANDB_PROJECT}'") + bt.logging.success(f"Started wandb run {run_name}") + + def store_vali_info(self): + """ + Stores the uid, hotkey and netuid of the currently running vali instance. + The SyntheticDataGenerator process reads this to name its w&b run + """ + validator_info = { + 'uid': self.uid, + 'hotkey': self.wallet.hotkey.ss58_address, + 'netuid': self.config.netuid, + 'full_path': self.config.neuron.full_path + } + with open(VALIDATOR_INFO_PATH, 'w') as f: + yaml.safe_dump(validator_info, f, indent=4) + + bt.logging.info(f"Wrote validator info to {VALIDATOR_INFO_PATH}") # The main function parses the configuration and runs the validator. @@ -124,4 +158,4 @@ def init_wandb(self): with Validator() as validator: while True: bt.logging.info(f"Validator running | uid {validator.uid} | {time.time()}") - time.sleep(5) + time.sleep(45) diff --git a/neurons/validator_proxy.py b/neurons/validator_proxy.py index 1cc9c53c..a4bc478d 100644 --- a/neurons/validator_proxy.py +++ b/neurons/validator_proxy.py @@ -20,12 +20,15 @@ import socket import base64 -from bitmind.image_transforms import base_transforms +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import get_base_transforms from bitmind.protocol import ImageSynapse, prepare_image_synapse from bitmind.utils.uids import get_random_uids from bitmind.validator.proxy import ProxyCounter import bitmind +base_transforms = get_base_transforms(TARGET_IMAGE_SIZE) + def preprocess_image(b64_image): image_bytes = base64.b64decode(b64_image) diff --git a/requirements-miner.txt b/requirements-miner.txt index 11da3b9e..6105cf20 100644 --- a/requirements-miner.txt +++ b/requirements-miner.txt @@ -1,5 +1,7 @@ tensorboardx==2.6.2.2 -dlib==19.24.6 +#dlib==19.24.6 imutils==0.5.4 scikit-image==0.24.0 ultralytics==8.2.86 +timm==1.0.11 +einops==0.8.0 diff --git a/requirements-validator.txt b/requirements-validator.txt index 45bad1fa..99577b94 100644 --- a/requirements-validator.txt +++ b/requirements-validator.txt @@ -1,5 +1,13 @@ httpx==0.27.0 -diffusers==0.30.0 transformers==4.46.3 sentencepiece==0.2.0 bitsandbytes==0.43.3 +imageio==2.35.1 +imageio-ffmpeg==0.5.1 +moviepy==1.0.3 +av==13.1.0 +yt-dlp==2024.11.4 +ffmpeg-python==0.2.0 +pyffmpeg==2.4.2.18.1 +#diffusers==0.30.0 +git+https://github.com/huggingface/diffusers.git diff --git a/setup_miner_env.sh b/setup_miner_env.sh index 004479d2..e6dd2457 100755 --- a/setup_miner_env.sh +++ b/setup_miner_env.sh @@ -10,16 +10,23 @@ sudo npm install pm2@latest -g sudo apt install build-essential cmake -y sudo apt install libopenblas-dev liblapack-dev -y sudo apt install libx11-dev libgtk-3-dev -y +sudo apt install unzip -y # Install Python dependencies pip install -e . 
pip install -r requirements-miner.txt echo "# Default options: -DETECTOR=CAMO # Options: CAMO, UCF, NPR -DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs +IMAGE_DETECTOR=CAMO # Options: CAMO, UCF, NPR, None +IMAGE_DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs # Supply a filename or relative path -DEVICE=cpu # Options: cpu, cuda + +VIDEO_DETECTOR=TALL # Options: TALL, None +VIDEO_DETECTOR_CONFIG=tall.yaml # Configs live in base_miner/deepfake_detectors/configs + # Supply a filename or relative path + +IMAGE_DETECTOR_DEVICE=cpu # Options: cpu, cuda +VIDEO_DETECTOR_DEVICE=cpu # Subtensor Network Configuration: NETUID=34 # Network User ID options: 34, 168 diff --git a/setup_validator_env.sh b/setup_validator_env.sh index 225bb195..74c89829 100755 --- a/setup_validator_env.sh +++ b/setup_validator_env.sh @@ -2,11 +2,9 @@ # Update system and install required packages sudo apt update -y -sudo apt install python3-pip -y -sudo apt install nano -y -sudo apt install libgl1 -y -sudo apt install npm -y +sudo apt install python3-pip nano libgl1 npm ffmpeg -y sudo npm install pm2@latest -g +sudo apt install -y unzip # Install Python dependencies pip install -e . diff --git a/start_miner.sh b/start_miner.sh index 42b9c6a5..dd3c94f7 100755 --- a/start_miner.sh +++ b/start_miner.sh @@ -1,21 +1,21 @@ #!/bin/bash -# Load environment variables from .env file set -a source miner.env set +a -# Check if the process is already running if pm2 list | grep -q "bitmind_miner"; then echo "Process 'bitmind_miner' is already running. Deleting it..." pm2 delete bitmind_miner fi -# Start the process with arguments from environment variables pm2 start neurons/miner.py --name bitmind_miner -- \ - --neuron.detector $DETECTOR \ - --neuron.detector_config $DETECTOR_CONFIG \ - --neuron.device $DEVICE \ + --neuron.image_detector ${IMAGE_DETECTOR:-None} \ + --neuron.image_detector_config ${IMAGE_DETECTOR_CONFIG:-None} \ + --neuron.image_detector_device ${IMAGE_DETECTOR_DEVICE:-None} \ + --neuron.video_detector ${VIDEO_DETECTOR:-None} \ + --neuron.video_detector_config ${VIDEO_DETECTOR_CONFIG:-None} \ + --neuron.video_detector_device ${VIDEO_DETECTOR_DEVICE:-None} \ --netuid $NETUID \ --subtensor.network $SUBTENSOR_NETWORK \ --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \ diff --git a/start_validator.sh b/start_validator.sh index c1997ce6..d3dc7ce2 100755 --- a/start_validator.sh +++ b/start_validator.sh @@ -1,14 +1,16 @@ #!/bin/bash -# Load environment variables from .env file +# Load environment variables from .env file & set defaults set -a source validator.env set +a -# Set default values for environment variables : ${VALIDATOR_PROXY_PORT:=10913} : ${DEVICE:=cuda} +VALIDATOR_PROCESS_NAME="bitmind_validator" +DATA_GEN_PROCESS_NAME="bitmind_data_generator" + # Login to Weights & Biases if ! wandb login $WANDB_API_KEY; then echo "Failed to login to Weights & Biases with the provided API key." @@ -21,20 +23,20 @@ if ! huggingface-cli login --token $HUGGING_FACE_TOKEN; then exit 1 fi -# Check if the process is already running -if pm2 list | grep -q "bitmind_validator"; then - echo "Process 'bitmind_validator' is already running. Deleting it..." - pm2 delete bitmind_validator +# VALIDATOR PROCESS +if pm2 list | grep -q "$VALIDATOR_PROCESS_NAME"; then + echo "Process '$VALIDATOR_PROCESS_NAME' is already running. Deleting it..." + pm2 delete $VALIDATOR_PROCESS_NAME fi -echo "Verifying access to synthetic image generation models. This may take a few minutes." 
-if ! python3 bitmind/validator/verify_models.py; then - echo "Failed to verify diffusion models. Please check the configurations or model access permissions." - exit 1 -fi +#echo "Verifying access to synthetic image generation models. This may take a few minutes." +#if ! python3 bitmind/validator/verify_models.py; then +# echo "Failed to verify diffusion models. Please check the configurations or model access permissions." +# exit 1 +#fi -# Start the process with arguments from environment variables -pm2 start neurons/validator.py --name bitmind_validator -- \ +echo "Starting validator process" +pm2 start neurons/validator.py --name $VALIDATOR_PROCESS_NAME -- \ --netuid $NETUID \ --subtensor.network $SUBTENSOR_NETWORK \ --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \ @@ -43,3 +45,12 @@ pm2 start neurons/validator.py --name bitmind_validator -- \ --axon.port $VALIDATOR_AXON_PORT \ --proxy.port $VALIDATOR_PROXY_PORT \ --neuron.device $DEVICE + +# SYNTHETIC DATA GENERATOR PROCESS +if pm2 list | grep -q "$DATA_GEN_PROCESS_NAME"; then + echo "Process '$DATA_GEN_PROCESS_NAME' is already running. Deleting it..." + pm2 delete $DATA_GEN_PROCESS_NAME +fi + +echo "Starting SyntheticDataGenerator process" +pm2 start bitmind/validator/run_data_generator.py --name $DATA_GEN_PROCESS_NAME diff --git a/tests/fixtures/image_transforms.py b/tests/fixtures/image_transforms.py index 985dce77..e4f35a71 100644 --- a/tests/fixtures/image_transforms.py +++ b/tests/fixtures/image_transforms.py @@ -1,8 +1,8 @@ from functools import partial import torchvision.transforms as transforms -from bitmind.constants import TARGET_IMAGE_SIZE -from bitmind.image_transforms import ( +from bitmind.validator.config import TARGET_IMAGE_SIZE +from bitmind.utils.image_transforms import ( center_crop, RandomResizedCropWithParams, RandomHorizontalFlipWithParams, @@ -10,8 +10,8 @@ RandomRotationWithParams, ConvertToRGB, ComposeWithParams, - base_transforms, - random_aug_transforms + get_base_transforms, + get_random_augmentations ) @@ -25,6 +25,6 @@ ] TRANSFORM_PIPELINES = [ - base_transforms, - random_aug_transforms + get_base_transforms(TARGET_IMAGE_SIZE), + get_random_augmentations(TARGET_IMAGE_SIZE) ] \ No newline at end of file diff --git a/tests/validator/test_generate_image.py b/tests/validator/test_generate_image.py index 4cc8728e..f4cd1705 100644 --- a/tests/validator/test_generate_image.py +++ b/tests/validator/test_generate_image.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, MagicMock -from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator -from bitmind.constants import DIFFUSER_NAMES +from bitmind.synthetic_data_generation.synthetic_data_generator import SyntheticDataGenerator +from bitmind.validator.config import T2I_MODEL_NAMES from PIL import Image @@ -38,7 +38,7 @@ def mock_image_annotation_generator(): yield instance -@pytest.mark.parametrize("diffuser_name", DIFFUSER_NAMES) +@pytest.mark.parametrize("diffuser_name", T2I_MODEL_NAMES) def test_generate_image_with_diffusers(mock_diffuser, mock_image_annotation_generator, diffuser_name): """ Test the image generation process using different diffusion models. 
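These renames track the model registry in bitmind/validator/config.py; a small invariant of that registry worth pinning in a test (sketch, assuming the image and video registries stay disjoint):

from bitmind.validator.config import get_modality, T2I_MODEL_NAMES, T2V_MODEL_NAMES

def test_registry_modalities():
    assert all(get_modality(name) == 'image' for name in T2I_MODEL_NAMES)
    assert all(get_modality(name) == 'video' for name in T2V_MODEL_NAMES)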
@@ -64,7 +64,7 @@ def test_generate_image_with_diffusers(mock_diffuser, mock_image_annotation_gene - Validating the image generation process - Integration testing with different diffuser models """ - generator = SyntheticImageGenerator( + generator = SyntheticDataGenerator( prompt_type='annotation', use_random_diffuser=False, diffuser_name=diffuser_name diff --git a/tests/validator/test_verify_models.py b/tests/validator/test_verify_models.py index ef14cbce..8e669692 100644 --- a/tests/validator/test_verify_models.py +++ b/tests/validator/test_verify_models.py @@ -2,7 +2,7 @@ import os from unittest.mock import patch, MagicMock, call from bitmind.validator.verify_models import is_model_cached, main -from bitmind.constants import DIFFUSER_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL +from bitmind.validator.config import T2I_MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL @pytest.fixture def mock_expanduser(): @@ -88,7 +88,7 @@ def test_main(mock_is_model_cached, MockSyntheticImageGenerator): # Expected calls with varying parameters based on model type expected_calls = [ call(prompt_type='annotation', use_random_diffuser=True, diffuser_name=None), # For IMAGE_ANNOTATION_MODEL and TEXT_MODERATION_MODEL - *[call(prompt_type='annotation', use_random_diffuser=False, diffuser_name=name) for name in DIFFUSER_NAMES] # For each name in DIFFUSER_NAMES + *[call(prompt_type='annotation', use_random_diffuser=False, diffuser_name=name) for name in T2I_MODEL_NAMES] # For each name in T2I_MODEL_NAMES ] # Verify all calls to SyntheticImageGenerator with the correct parameters
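For reference, is_model_cached (exercised above) presumably inspects the HuggingFace hub cache; a hypothetical re-implementation under the standard hub layout (models--org--name directories), not the actual helper from bitmind/validator/verify_models.py:

import os

def is_model_cached_sketch(model_name: str) -> bool:
    # assumes the default HF hub cache location and folder naming scheme
    hub_dir = os.path.expanduser('~/.cache/huggingface/hub')
    folder = 'models--' + model_name.replace('/', '--')
    return os.path.isdir(os.path.join(hub_dir, folder))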