From f1d01af2908bc4caae06b11d30e447c697ea52d2 Mon Sep 17 00:00:00 2001 From: olafhappy Date: Tue, 18 Mar 2025 12:56:22 +0000 Subject: [PATCH 1/3] fix: refactoring project structure --- autoencoding.ipynb => notebook/autoencoding.ipynb | 0 interpolate.ipynb => notebook/interpolate.ipynb | 0 manipulate.ipynb => notebook/manipulate.ipynb | 0 manipulate_note.ipynb => notebook/manipulate_note.ipynb | 0 sample.ipynb => notebook/sample.ipynb | 0 run_bedroom128.py => scripts/run_bedroom128.py | 0 run_bedroom128_ddim.py => scripts/run_bedroom128_ddim.py | 0 run_celeba64.py => scripts/run_celeba64.py | 0 run_ffhq128.py => scripts/run_ffhq128.py | 0 run_ffhq128_cls.py => scripts/run_ffhq128_cls.py | 0 run_ffhq128_ddim.py => scripts/run_ffhq128_ddim.py | 0 run_ffhq256.py => scripts/run_ffhq256.py | 0 run_ffhq256_cls.py => scripts/run_ffhq256_cls.py | 0 run_ffhq256_latent.py => scripts/run_ffhq256_latent.py | 0 run_horse128.py => scripts/run_horse128.py | 0 run_horse128_ddim.py => scripts/run_horse128_ddim.py | 0 align.py => source/align.py | 0 choices.py => source/choices.py | 0 config.py => source/config.py | 0 config_base.py => source/config_base.py | 0 data_resize_bedroom.py => source/data_resize_bedroom.py | 0 data_resize_celeba.py => source/data_resize_celeba.py | 0 data_resize_celebahq.py => source/data_resize_celebahq.py | 0 data_resize_ffhq.py => source/data_resize_ffhq.py | 0 data_resize_horse.py => source/data_resize_horse.py | 0 dataset.py => source/dataset.py | 0 dataset_util.py => source/dataset_util.py | 0 dist_utils.py => source/dist_utils.py | 0 experiment.py => source/experiment.py | 0 experiment_classifier.py => source/experiment_classifier.py | 0 lmdb_writer.py => source/lmdb_writer.py | 0 metrics.py => source/metrics.py | 0 predict.py => source/predict.py | 0 renderer.py => source/renderer.py | 0 ssim.py => source/ssim.py | 0 templates.py => source/templates.py | 0 templates_cls.py => source/templates_cls.py | 0 templates_latent.py => source/templates_latent.py | 0 38 files changed, 0 insertions(+), 0 deletions(-) rename autoencoding.ipynb => notebook/autoencoding.ipynb (100%) rename interpolate.ipynb => notebook/interpolate.ipynb (100%) rename manipulate.ipynb => notebook/manipulate.ipynb (100%) rename manipulate_note.ipynb => notebook/manipulate_note.ipynb (100%) rename sample.ipynb => notebook/sample.ipynb (100%) rename run_bedroom128.py => scripts/run_bedroom128.py (100%) rename run_bedroom128_ddim.py => scripts/run_bedroom128_ddim.py (100%) rename run_celeba64.py => scripts/run_celeba64.py (100%) rename run_ffhq128.py => scripts/run_ffhq128.py (100%) rename run_ffhq128_cls.py => scripts/run_ffhq128_cls.py (100%) rename run_ffhq128_ddim.py => scripts/run_ffhq128_ddim.py (100%) rename run_ffhq256.py => scripts/run_ffhq256.py (100%) rename run_ffhq256_cls.py => scripts/run_ffhq256_cls.py (100%) rename run_ffhq256_latent.py => scripts/run_ffhq256_latent.py (100%) rename run_horse128.py => scripts/run_horse128.py (100%) rename run_horse128_ddim.py => scripts/run_horse128_ddim.py (100%) rename align.py => source/align.py (100%) rename choices.py => source/choices.py (100%) mode change 100755 => 100644 rename config.py => source/config.py (100%) rename config_base.py => source/config_base.py (100%) mode change 100755 => 100644 rename data_resize_bedroom.py => source/data_resize_bedroom.py (100%) rename data_resize_celeba.py => source/data_resize_celeba.py (100%) rename data_resize_celebahq.py => source/data_resize_celebahq.py (100%) rename data_resize_ffhq.py => 
source/data_resize_ffhq.py (100%) rename data_resize_horse.py => source/data_resize_horse.py (100%) rename dataset.py => source/dataset.py (100%) mode change 100755 => 100644 rename dataset_util.py => source/dataset_util.py (100%) mode change 100755 => 100644 rename dist_utils.py => source/dist_utils.py (100%) mode change 100755 => 100644 rename experiment.py => source/experiment.py (100%) mode change 100755 => 100644 rename experiment_classifier.py => source/experiment_classifier.py (100%) rename lmdb_writer.py => source/lmdb_writer.py (100%) mode change 100755 => 100644 rename metrics.py => source/metrics.py (100%) mode change 100755 => 100644 rename predict.py => source/predict.py (100%) rename renderer.py => source/renderer.py (100%) mode change 100755 => 100644 rename ssim.py => source/ssim.py (100%) mode change 100755 => 100644 rename templates.py => source/templates.py (100%) rename templates_cls.py => source/templates_cls.py (100%) rename templates_latent.py => source/templates_latent.py (100%) diff --git a/autoencoding.ipynb b/notebook/autoencoding.ipynb similarity index 100% rename from autoencoding.ipynb rename to notebook/autoencoding.ipynb diff --git a/interpolate.ipynb b/notebook/interpolate.ipynb similarity index 100% rename from interpolate.ipynb rename to notebook/interpolate.ipynb diff --git a/manipulate.ipynb b/notebook/manipulate.ipynb similarity index 100% rename from manipulate.ipynb rename to notebook/manipulate.ipynb diff --git a/manipulate_note.ipynb b/notebook/manipulate_note.ipynb similarity index 100% rename from manipulate_note.ipynb rename to notebook/manipulate_note.ipynb diff --git a/sample.ipynb b/notebook/sample.ipynb similarity index 100% rename from sample.ipynb rename to notebook/sample.ipynb diff --git a/run_bedroom128.py b/scripts/run_bedroom128.py similarity index 100% rename from run_bedroom128.py rename to scripts/run_bedroom128.py diff --git a/run_bedroom128_ddim.py b/scripts/run_bedroom128_ddim.py similarity index 100% rename from run_bedroom128_ddim.py rename to scripts/run_bedroom128_ddim.py diff --git a/run_celeba64.py b/scripts/run_celeba64.py similarity index 100% rename from run_celeba64.py rename to scripts/run_celeba64.py diff --git a/run_ffhq128.py b/scripts/run_ffhq128.py similarity index 100% rename from run_ffhq128.py rename to scripts/run_ffhq128.py diff --git a/run_ffhq128_cls.py b/scripts/run_ffhq128_cls.py similarity index 100% rename from run_ffhq128_cls.py rename to scripts/run_ffhq128_cls.py diff --git a/run_ffhq128_ddim.py b/scripts/run_ffhq128_ddim.py similarity index 100% rename from run_ffhq128_ddim.py rename to scripts/run_ffhq128_ddim.py diff --git a/run_ffhq256.py b/scripts/run_ffhq256.py similarity index 100% rename from run_ffhq256.py rename to scripts/run_ffhq256.py diff --git a/run_ffhq256_cls.py b/scripts/run_ffhq256_cls.py similarity index 100% rename from run_ffhq256_cls.py rename to scripts/run_ffhq256_cls.py diff --git a/run_ffhq256_latent.py b/scripts/run_ffhq256_latent.py similarity index 100% rename from run_ffhq256_latent.py rename to scripts/run_ffhq256_latent.py diff --git a/run_horse128.py b/scripts/run_horse128.py similarity index 100% rename from run_horse128.py rename to scripts/run_horse128.py diff --git a/run_horse128_ddim.py b/scripts/run_horse128_ddim.py similarity index 100% rename from run_horse128_ddim.py rename to scripts/run_horse128_ddim.py diff --git a/align.py b/source/align.py similarity index 100% rename from align.py rename to source/align.py diff --git a/choices.py b/source/choices.py 
old mode 100755 new mode 100644 similarity index 100% rename from choices.py rename to source/choices.py diff --git a/config.py b/source/config.py similarity index 100% rename from config.py rename to source/config.py diff --git a/config_base.py b/source/config_base.py old mode 100755 new mode 100644 similarity index 100% rename from config_base.py rename to source/config_base.py diff --git a/data_resize_bedroom.py b/source/data_resize_bedroom.py similarity index 100% rename from data_resize_bedroom.py rename to source/data_resize_bedroom.py diff --git a/data_resize_celeba.py b/source/data_resize_celeba.py similarity index 100% rename from data_resize_celeba.py rename to source/data_resize_celeba.py diff --git a/data_resize_celebahq.py b/source/data_resize_celebahq.py similarity index 100% rename from data_resize_celebahq.py rename to source/data_resize_celebahq.py diff --git a/data_resize_ffhq.py b/source/data_resize_ffhq.py similarity index 100% rename from data_resize_ffhq.py rename to source/data_resize_ffhq.py diff --git a/data_resize_horse.py b/source/data_resize_horse.py similarity index 100% rename from data_resize_horse.py rename to source/data_resize_horse.py diff --git a/dataset.py b/source/dataset.py old mode 100755 new mode 100644 similarity index 100% rename from dataset.py rename to source/dataset.py diff --git a/dataset_util.py b/source/dataset_util.py old mode 100755 new mode 100644 similarity index 100% rename from dataset_util.py rename to source/dataset_util.py diff --git a/dist_utils.py b/source/dist_utils.py old mode 100755 new mode 100644 similarity index 100% rename from dist_utils.py rename to source/dist_utils.py diff --git a/experiment.py b/source/experiment.py old mode 100755 new mode 100644 similarity index 100% rename from experiment.py rename to source/experiment.py diff --git a/experiment_classifier.py b/source/experiment_classifier.py similarity index 100% rename from experiment_classifier.py rename to source/experiment_classifier.py diff --git a/lmdb_writer.py b/source/lmdb_writer.py old mode 100755 new mode 100644 similarity index 100% rename from lmdb_writer.py rename to source/lmdb_writer.py diff --git a/metrics.py b/source/metrics.py old mode 100755 new mode 100644 similarity index 100% rename from metrics.py rename to source/metrics.py diff --git a/predict.py b/source/predict.py similarity index 100% rename from predict.py rename to source/predict.py diff --git a/renderer.py b/source/renderer.py old mode 100755 new mode 100644 similarity index 100% rename from renderer.py rename to source/renderer.py diff --git a/ssim.py b/source/ssim.py old mode 100755 new mode 100644 similarity index 100% rename from ssim.py rename to source/ssim.py diff --git a/templates.py b/source/templates.py similarity index 100% rename from templates.py rename to source/templates.py diff --git a/templates_cls.py b/source/templates_cls.py similarity index 100% rename from templates_cls.py rename to source/templates_cls.py diff --git a/templates_latent.py b/source/templates_latent.py similarity index 100% rename from templates_latent.py rename to source/templates_latent.py From ea0e773d688d604052b81b6622731e139c745b72 Mon Sep 17 00:00:00 2001 From: olafhappy Date: Tue, 18 Mar 2025 15:46:36 +0000 Subject: [PATCH 2/3] fix: updating reorganizing structure --- source/config/config.py | 425 +++++++++++ source/config/config_base.py | 72 ++ source/dataset-preprocessing/align.py | 249 ++++++ .../data_resize_bedroom.py | 101 +++ .../data_resize_celeba.py | 120 +++ 
.../data_resize_celebahq.py | 120 +++ .../dataset-preprocessing/data_resize_ffhq.py | 123 +++ .../data_resize_horse.py | 100 +++ source/dataset-preprocessing/dataset.py | 716 ++++++++++++++++++ source/dataset-preprocessing/dataset_util.py | 13 + source/templates/templates.py | 323 ++++++++ source/templates/templates_cls.py | 38 + source/templates/templates_latent.py | 150 ++++ 13 files changed, 2550 insertions(+) create mode 100644 source/config/config.py create mode 100644 source/config/config_base.py create mode 100644 source/dataset-preprocessing/align.py create mode 100644 source/dataset-preprocessing/data_resize_bedroom.py create mode 100644 source/dataset-preprocessing/data_resize_celeba.py create mode 100644 source/dataset-preprocessing/data_resize_celebahq.py create mode 100644 source/dataset-preprocessing/data_resize_ffhq.py create mode 100644 source/dataset-preprocessing/data_resize_horse.py create mode 100644 source/dataset-preprocessing/dataset.py create mode 100644 source/dataset-preprocessing/dataset_util.py create mode 100644 source/templates/templates.py create mode 100644 source/templates/templates_cls.py create mode 100644 source/templates/templates_latent.py diff --git a/source/config/config.py b/source/config/config.py new file mode 100644 index 0000000..98068e8 --- /dev/null +++ b/source/config/config.py @@ -0,0 +1,425 @@ +from model.unet import ScaleAt +from model.latentnet import * +from diffusion.resample import UniformSampler +from diffusion.diffusion import space_timesteps +from typing import Tuple + +from torch.utils.data import DataLoader + +from config_base import BaseConfig +from dataset import * +from diffusion import * +from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule +from model import * +from choices import * +from multiprocessing import get_context +import os +from dataset_util import * +from torch.utils.data.distributed import DistributedSampler + +data_paths = { + 'ffhqlmdb256': + os.path.expanduser('datasets/ffhq256.lmdb'), + # used for training a classifier + 'celeba': + os.path.expanduser('datasets/celeba'), + # used for training DPM models + 'celebalmdb': + os.path.expanduser('datasets/celeba.lmdb'), + 'celebahq': + os.path.expanduser('datasets/celebahq256.lmdb'), + 'horse256': + os.path.expanduser('datasets/horse256.lmdb'), + 'bedroom256': + os.path.expanduser('datasets/bedroom256.lmdb'), + 'celeba_anno': + os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'), + 'celebahq_anno': + os.path.expanduser( + 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'), + 'celeba_relight': + os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'), +} + + +@dataclass +class PretrainConfig(BaseConfig): + name: str + path: str + + +@dataclass +class TrainConfig(BaseConfig): + # random seed + seed: int = 0 + train_mode: TrainMode = TrainMode.diffusion + train_cond0_prob: float = 0 + train_pred_xstart_detach: bool = True + train_interpolate_prob: float = 0 + train_interpolate_img: bool = False + manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all + manipulate_cls: str = None + manipulate_shots: int = None + manipulate_loss: ManipulateLossType = ManipulateLossType.bce + manipulate_znormalize: bool = False + manipulate_seed: int = 0 + accum_batches: int = 1 + autoenc_mid_attn: bool = True + batch_size: int = 16 + batch_size_eval: int = None + beatgans_gen_type: GenerativeType = GenerativeType.ddim + beatgans_loss_type: LossType = LossType.mse + beatgans_model_mean_type: 
ModelMeanType = ModelMeanType.eps + beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large + beatgans_rescale_timesteps: bool = False + latent_infer_path: str = None + latent_znormalize: bool = False + latent_gen_type: GenerativeType = GenerativeType.ddim + latent_loss_type: LossType = LossType.mse + latent_model_mean_type: ModelMeanType = ModelMeanType.eps + latent_model_var_type: ModelVarType = ModelVarType.fixed_large + latent_rescale_timesteps: bool = False + latent_T_eval: int = 1_000 + latent_clip_sample: bool = False + latent_beta_scheduler: str = 'linear' + beta_scheduler: str = 'linear' + data_name: str = '' + data_val_name: str = None + diffusion_type: str = None + dropout: float = 0.1 + ema_decay: float = 0.9999 + eval_num_images: int = 5_000 + eval_every_samples: int = 200_000 + eval_ema_every_samples: int = 200_000 + fid_use_torch: bool = True + fp16: bool = False + grad_clip: float = 1 + img_size: int = 64 + lr: float = 0.0001 + optimizer: OptimizerType = OptimizerType.adam + weight_decay: float = 0 + model_conf: ModelConfig = None + model_name: ModelName = None + model_type: ModelType = None + net_attn: Tuple[int] = None + net_beatgans_attn_head: int = 1 + # not necessarily the same as the number of style channels + net_beatgans_embed_channels: int = 512 + net_resblock_updown: bool = True + net_enc_use_time: bool = False + net_enc_pool: str = 'adaptivenonzero' + net_beatgans_gradient_checkpoint: bool = False + net_beatgans_resnet_two_cond: bool = False + net_beatgans_resnet_use_zero_module: bool = True + net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm + net_beatgans_resnet_cond_channels: int = None + net_ch_mult: Tuple[int] = None + net_ch: int = 64 + net_enc_attn: Tuple[int] = None + net_enc_k: int = None + # number of resblocks for the encoder (half-unet) + net_enc_num_res_blocks: int = 2 + net_enc_channel_mult: Tuple[int] = None + net_enc_grad_checkpoint: bool = False + net_autoenc_stochastic: bool = False + net_latent_activation: Activation = Activation.silu + net_latent_channel_mult: Tuple[int] = (1, 2, 4) + net_latent_condition_bias: float = 0 + net_latent_dropout: float = 0 + net_latent_layers: int = None + net_latent_net_last_act: Activation = Activation.none + net_latent_net_type: LatentNetType = LatentNetType.none + net_latent_num_hid_channels: int = 1024 + net_latent_num_time_layers: int = 2 + net_latent_skip_layers: Tuple[int] = None + net_latent_time_emb_channels: int = 64 + net_latent_use_norm: bool = False + net_latent_time_last_act: bool = False + net_num_res_blocks: int = 2 + # number of resblocks for the UNET + net_num_input_res_blocks: int = None + net_enc_num_cls: int = None + num_workers: int = 4 + parallel: bool = False + postfix: str = '' + sample_size: int = 64 + sample_every_samples: int = 20_000 + save_every_samples: int = 100_000 + style_ch: int = 512 + T_eval: int = 1_000 + T_sampler: str = 'uniform' + T: int = 1_000 + total_samples: int = 10_000_000 + warmup: int = 0 + pretrain: PretrainConfig = None + continue_from: PretrainConfig = None + eval_programs: Tuple[str] = None + # if present, load the checkpoint from this path instead + eval_path: str = None + base_dir: str = 'checkpoints' + use_cache_dataset: bool = False + data_cache_dir: str = os.path.expanduser('~/cache') + work_cache_dir: str = os.path.expanduser('~/mycache') + # to be overridden + name: str = '' + + def __post_init__(self): + self.batch_size_eval = self.batch_size_eval or self.batch_size + self.data_val_name = self.data_val_name or self.data_name + + def
scale_up_gpus(self, num_gpus, num_nodes=1): + self.eval_ema_every_samples *= num_gpus * num_nodes + self.eval_every_samples *= num_gpus * num_nodes + self.sample_every_samples *= num_gpus * num_nodes + self.batch_size *= num_gpus * num_nodes + self.batch_size_eval *= num_gpus * num_nodes + return self + + @property + def batch_size_effective(self): + return self.batch_size * self.accum_batches + + @property + def fid_cache(self): + # we try to use the local dirs to reduce the load over network drives + # hopefully, this would reduce the disconnection problems with sshfs + return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}' + + @property + def data_path(self): + # may use the cache dir + path = data_paths[self.data_name] + if self.use_cache_dataset and path is not None: + path = use_cached_dataset_path( + path, f'{self.data_cache_dir}/{self.data_name}') + return path + + @property + def logdir(self): + return f'{self.base_dir}/{self.name}' + + @property + def generate_dir(self): + # we try to use the local dirs to reduce the load over network drives + # hopefully, this would reduce the disconnection problems with sshfs + return f'{self.work_cache_dir}/gen_images/{self.name}' + + def _make_diffusion_conf(self, T=None): + if self.diffusion_type == 'beatgans': + # can use T < self.T for evaluation + # follows the guided-diffusion repo conventions + # t's are evenly spaced + if self.beatgans_gen_type == GenerativeType.ddpm: + section_counts = [T] + elif self.beatgans_gen_type == GenerativeType.ddim: + section_counts = f'ddim{T}' + else: + raise NotImplementedError() + + return SpacedDiffusionBeatGansConfig( + gen_type=self.beatgans_gen_type, + model_type=self.model_type, + betas=get_named_beta_schedule(self.beta_scheduler, self.T), + model_mean_type=self.beatgans_model_mean_type, + model_var_type=self.beatgans_model_var_type, + loss_type=self.beatgans_loss_type, + rescale_timesteps=self.beatgans_rescale_timesteps, + use_timesteps=space_timesteps(num_timesteps=self.T, + section_counts=section_counts), + fp16=self.fp16, + ) + else: + raise NotImplementedError() + + def _make_latent_diffusion_conf(self, T=None): + # can use T < self.T for evaluation + # follows the guided-diffusion repo conventions + # t's are evenly spaced + if self.latent_gen_type == GenerativeType.ddpm: + section_counts = [T] + elif self.latent_gen_type == GenerativeType.ddim: + section_counts = f'ddim{T}' + else: + raise NotImplementedError() + + return SpacedDiffusionBeatGansConfig( + train_pred_xstart_detach=self.train_pred_xstart_detach, + gen_type=self.latent_gen_type, + # latent's model is always ddpm + model_type=ModelType.ddpm, + # latent shares the beta scheduler and full T + betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T), + model_mean_type=self.latent_model_mean_type, + model_var_type=self.latent_model_var_type, + loss_type=self.latent_loss_type, + rescale_timesteps=self.latent_rescale_timesteps, + use_timesteps=space_timesteps(num_timesteps=self.T, + section_counts=section_counts), + fp16=self.fp16, + ) + + @property + def model_out_channels(self): + return 3 + + def make_T_sampler(self): + if self.T_sampler == 'uniform': + return UniformSampler(self.T) + else: + raise NotImplementedError() + + def make_diffusion_conf(self): + return self._make_diffusion_conf(self.T) + + def make_eval_diffusion_conf(self): + return self._make_diffusion_conf(T=self.T_eval) + + def make_latent_diffusion_conf(self): + return self._make_latent_diffusion_conf(T=self.T) 
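For context, an editorial sketch (not part of the patch): T_eval < T works because space_timesteps keeps only a subset of the training timesteps. The hypothetical helper below only illustrates the even spacing for the ddpm case with section_counts=[n]; the real logic lives in the guided-diffusion-style diffusion package.

def evenly_spaced_subset(num_timesteps: int, n: int) -> set:
    # e.g. num_timesteps=1000, n=20 -> {0, 50, 100, ..., 950}
    stride = num_timesteps / n
    return {int(i * stride) for i in range(n)}

assert evenly_spaced_subset(1000, 20) == {i * 50 for i in range(20)}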
+ + def make_latent_eval_diffusion_conf(self): + # latent can have different eval T + return self._make_latent_diffusion_conf(T=self.latent_T_eval) + + def make_dataset(self, path=None, **kwargs): + if self.data_name == 'ffhqlmdb256': + return FFHQlmdb(path=path or self.data_path, + image_size=self.img_size, + **kwargs) + elif self.data_name == 'horse256': + return Horse_lmdb(path=path or self.data_path, + image_size=self.img_size, + **kwargs) + elif self.data_name == 'bedroom256': + return Horse_lmdb(path=path or self.data_path, + image_size=self.img_size, + **kwargs) + elif self.data_name == 'celebalmdb': + # always use d2c crop + return CelebAlmdb(path=path or self.data_path, + image_size=self.img_size, + original_resolution=None, + crop_d2c=True, + **kwargs) + else: + raise NotImplementedError() + + def make_loader(self, + dataset, + shuffle: bool, + num_worker: int = None, + drop_last: bool = True, + batch_size: int = None, + parallel: bool = False): + if parallel and distributed.is_initialized(): + # drop last to make sure that there are no added duplicate indices + sampler = DistributedSampler(dataset, + shuffle=shuffle, + drop_last=True) + else: + sampler = None + return DataLoader( + dataset, + batch_size=batch_size or self.batch_size, + sampler=sampler, + # with a sampler, use the sampler instead of this option + shuffle=False if sampler else shuffle, + num_workers=num_worker or self.num_workers, + pin_memory=True, + drop_last=drop_last, + multiprocessing_context=get_context('fork'), + ) + + def make_model_conf(self): + if self.model_name == ModelName.beatgans_ddpm: + self.model_type = ModelType.ddpm + self.model_conf = BeatGANsUNetConfig( + attention_resolutions=self.net_attn, + channel_mult=self.net_ch_mult, + conv_resample=True, + dims=2, + dropout=self.dropout, + embed_channels=self.net_beatgans_embed_channels, + image_size=self.img_size, + in_channels=3, + model_channels=self.net_ch, + num_classes=None, + num_head_channels=-1, + num_heads_upsample=-1, + num_heads=self.net_beatgans_attn_head, + num_res_blocks=self.net_num_res_blocks, + num_input_res_blocks=self.net_num_input_res_blocks, + out_channels=self.model_out_channels, + resblock_updown=self.net_resblock_updown, + use_checkpoint=self.net_beatgans_gradient_checkpoint, + use_new_attention_order=False, + resnet_two_cond=self.net_beatgans_resnet_two_cond, + resnet_use_zero_module=self.
+ net_beatgans_resnet_use_zero_module, + ) + elif self.model_name in [ + ModelName.beatgans_autoenc, + ]: + cls = BeatGANsAutoencConfig + # supports both autoenc and vaeddpm + if self.model_name == ModelName.beatgans_autoenc: + self.model_type = ModelType.autoencoder + else: + raise NotImplementedError() + + if self.net_latent_net_type == LatentNetType.none: + latent_net_conf = None + elif self.net_latent_net_type == LatentNetType.skip: + latent_net_conf = MLPSkipNetConfig( + num_channels=self.style_ch, + skip_layers=self.net_latent_skip_layers, + num_hid_channels=self.net_latent_num_hid_channels, + num_layers=self.net_latent_layers, + num_time_emb_channels=self.net_latent_time_emb_channels, + activation=self.net_latent_activation, + use_norm=self.net_latent_use_norm, + condition_bias=self.net_latent_condition_bias, + dropout=self.net_latent_dropout, + last_act=self.net_latent_net_last_act, + num_time_layers=self.net_latent_num_time_layers, + time_last_act=self.net_latent_time_last_act, + ) + else: + raise NotImplementedError() + + self.model_conf = cls( + attention_resolutions=self.net_attn, + channel_mult=self.net_ch_mult, + conv_resample=True, + dims=2, + dropout=self.dropout, + embed_channels=self.net_beatgans_embed_channels, + enc_out_channels=self.style_ch, + enc_pool=self.net_enc_pool, + enc_num_res_block=self.net_enc_num_res_blocks, + enc_channel_mult=self.net_enc_channel_mult, + enc_grad_checkpoint=self.net_enc_grad_checkpoint, + enc_attn_resolutions=self.net_enc_attn, + image_size=self.img_size, + in_channels=3, + model_channels=self.net_ch, + num_classes=None, + num_head_channels=-1, + num_heads_upsample=-1, + num_heads=self.net_beatgans_attn_head, + num_res_blocks=self.net_num_res_blocks, + num_input_res_blocks=self.net_num_input_res_blocks, + out_channels=self.model_out_channels, + resblock_updown=self.net_resblock_updown, + use_checkpoint=self.net_beatgans_gradient_checkpoint, + use_new_attention_order=False, + resnet_two_cond=self.net_beatgans_resnet_two_cond, + resnet_use_zero_module=self. 
+ net_beatgans_resnet_use_zero_module, + latent_net_conf=latent_net_conf, + resnet_cond_channels=self.net_beatgans_resnet_cond_channels, + ) + else: + raise NotImplementedError(self.model_name) + + return self.model_conf diff --git a/source/config/config_base.py b/source/config/config_base.py new file mode 100644 index 0000000..385f9ee --- /dev/null +++ b/source/config/config_base.py @@ -0,0 +1,72 @@ +import json +import os +from copy import deepcopy +from dataclasses import dataclass + + +@dataclass +class BaseConfig: + def clone(self): + return deepcopy(self) + + def inherit(self, another): + """inherit common keys from a given config""" + common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) + for k in common_keys: + setattr(self, k, getattr(another, k)) + + def propagate(self): + """push down the configuration to all members""" + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + v.inherit(self) + v.propagate() + + def save(self, save_path): + """save config to json file""" + dirname = os.path.dirname(save_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + conf = self.as_dict_jsonable() + with open(save_path, 'w') as f: + json.dump(conf, f) + + def load(self, load_path): + """load json config""" + with open(load_path) as f: + conf = json.load(f) + self.from_dict(conf) + + def from_dict(self, dict, strict=False): + for k, v in dict.items(): + if not hasattr(self, k): + if strict: + raise ValueError(f"loading extra '{k}'") + else: + print(f"loading extra '{k}'") + continue + if isinstance(self.__dict__[k], BaseConfig): + self.__dict__[k].from_dict(v) + else: + self.__dict__[k] = v + + def as_dict_jsonable(self): + conf = {} + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + conf[k] = v.as_dict_jsonable() + else: + if jsonable(v): + conf[k] = v + else: + # ignore not jsonable + pass + return conf + + +def jsonable(x): + try: + json.dumps(x) + return True + except TypeError: + return False diff --git a/source/dataset-preprocessing/align.py b/source/dataset-preprocessing/align.py new file mode 100644 index 0000000..8783e37 --- /dev/null +++ b/source/dataset-preprocessing/align.py @@ -0,0 +1,249 @@ +import bz2 +import os +import os.path as osp +import sys +from multiprocessing import Pool + +import dlib +import numpy as np +import PIL.Image +import requests +import scipy.ndimage +from tqdm import tqdm +from argparse import ArgumentParser + +LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' + + +def image_align(src_file, + dst_file, + face_landmarks, + output_size=1024, + transform_size=4096, + enable_padding=True): + # Align function from FFHQ dataset pre-processing step + # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + lm = np.array(face_landmarks) + lm_chin = lm[0:17] # left-right + lm_eyebrow_left = lm[17:22] # left-right + lm_eyebrow_right = lm[22:27] # left-right + lm_nose = lm[27:31] # top-down + lm_nostrils = lm[31:36] # top-down + lm_eye_left = lm[36:42] # left-clockwise + lm_eye_right = lm[42:48] # left-clockwise + lm_mouth_outer = lm[48:60] # left-clockwise + lm_mouth_inner = lm[60:68] # left-clockwise + + # Calculate auxiliary vectors. 
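+    # Each eye's six landmarks are averaged into a stable eye center; the
+    # eye-to-eye and eye-to-mouth vectors below set the rotation and scale
+    # of the oriented crop square quad, which is centered slightly below
+    # the eye midpoint and sized by the larger of the two distances.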
+ eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # Load in-the-wild image. + if not os.path.isfile(src_file): + print( + '\nCannot find source image. Please run "--wilds" before "--align".' + ) + return + img = PIL.Image.open(src_file) + img = img.convert('RGB') + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), + int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), + min(crop[2] + border, + img.size[0]), min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, + 0), max(-pad[1] + border, + 0), max(pad[2] - img.size[0] + border, + 0), max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), + ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum( + 1.0 - + np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), 1.0 - + np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - + img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), + 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, + (quad + 0.5).flatten(), PIL.Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Save aligned image. 
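+    # (At this point img has been shrunk, cropped, blur-padded, and warped
+    # to an axis-aligned output_size x output_size square.)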
+ img.save(dst_file, 'PNG') + + +class LandmarksDetector: + def __init__(self, predictor_model_path): + """ + :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file + """ + self.detector = dlib.get_frontal_face_detector( + ) # cnn_face_detection_model_v1 also can be used + self.shape_predictor = dlib.shape_predictor(predictor_model_path) + + def get_landmarks(self, image): + img = dlib.load_rgb_image(image) + dets = self.detector(img, 1) + + for detection in dets: + face_landmarks = [ + (item.x, item.y) + for item in self.shape_predictor(img, detection).parts() + ] + yield face_landmarks + + +def unpack_bz2(src_path): + dst_path = src_path[:-4] + if os.path.exists(dst_path): + print('cached') + return dst_path + data = bz2.BZ2File(src_path).read() + with open(dst_path, 'wb') as fp: + fp.write(data) + return dst_path + + +def work_landmark(raw_img_path, img_name, face_landmarks): + face_img_name = '%s.png' % (os.path.splitext(img_name)[0], ) + aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) + if os.path.exists(aligned_face_path): + return + image_align(raw_img_path, + aligned_face_path, + face_landmarks, + output_size=256) + + +def get_file(src, tgt): + if os.path.exists(tgt): + print('cached') + return tgt + tgt_dir = os.path.dirname(tgt) + if not os.path.exists(tgt_dir): + os.makedirs(tgt_dir) + file = requests.get(src) + open(tgt, 'wb').write(file.content) + return tgt + + +if __name__ == "__main__": + """ + Extracts and aligns all faces from images using dlib and a function from the original FFHQ dataset preparation step. + Usage: python align.py -i raw_images -o aligned_images + """ + parser = ArgumentParser() + parser.add_argument("-i", + "--input_imgs_path", + type=str, + default="imgs", + help="input images directory path") + parser.add_argument("-o", + "--output_imgs_path", + type=str, + default="imgs_align", + help="output images directory path") + + args = parser.parse_args() + + # takes a very long time ...
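+    # get_file and unpack_bz2 both short-circuit (printing 'cached') on
+    # later runs, so the download and extraction only happen once.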
+ landmarks_model_path = unpack_bz2( + get_file( + 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', + 'temp/shape_predictor_68_face_landmarks.dat.bz2')) + + # RAW_IMAGES_DIR = sys.argv[1] + # ALIGNED_IMAGES_DIR = sys.argv[2] + RAW_IMAGES_DIR = args.input_imgs_path + ALIGNED_IMAGES_DIR = args.output_imgs_path + + if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR) + + files = os.listdir(RAW_IMAGES_DIR) + print(f'total img files {len(files)}') + with tqdm(total=len(files)) as progress: + + def cb(*args): + # print('update') + progress.update() + + def err_cb(e): + print('error:', e) + + with Pool(8) as pool: + res = [] + landmarks_detector = LandmarksDetector(landmarks_model_path) + for img_name in files: + raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name) + # print('img_name:', img_name) + for i, face_landmarks in enumerate( + landmarks_detector.get_landmarks(raw_img_path), + start=1): + # assert i == 1, f'{i}' + # print(i, face_landmarks) + # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i) + # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) + # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256) + + work_landmark(raw_img_path, img_name, face_landmarks) + progress.update() + + # job = pool.apply_async( + # work_landmark, + # (raw_img_path, img_name, face_landmarks), + # callback=cb, + # error_callback=err_cb, + # ) + # res.append(job) + + # pool.close() + # pool.join() + print(f"output aligned images at: {ALIGNED_IMAGES_DIR}") diff --git a/source/dataset-preprocessing/data_resize_bedroom.py b/source/dataset-preprocessing/data_resize_bedroom.py new file mode 100644 index 0000000..32a7351 --- /dev/null +++ b/source/dataset-preprocessing/data_resize_bedroom.py @@ -0,0 +1,101 @@ +import argparse +import multiprocessing +import os +from os.path import join, exists +from functools import partial +from io import BytesIO +import shutil + +import lmdb +from PIL import Image +from torchvision.datasets import LSUNClass +from torchvision.transforms import functional as trans_fn +from tqdm import tqdm + +from multiprocessing import Process, Queue + + +def resize_and_convert(img, size, resample, quality=100): + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + buffer = BytesIO() + img.save(buffer, format="webp", quality=quality) + val = buffer.getvalue() + + return val + + +def resize_multiple(img, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS, + quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(idx, img, sizes, resample): + img = img.convert("RGB") + out = resize_multiple(img, sizes=sizes, resample=resample) + return idx, out + + +from torch.utils.data import Dataset, DataLoader + + +class ConvertDataset(Dataset): + def __init__(self, data) -> None: + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + img, _ = self.data[index] + bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) + return bytes + + +# Define a top-level collate function instead of using a lambda to avoid pickling issues +def collate_fn(batch): + return batch + + +if __name__ == "__main__": + """ + converting LSUN's original lmdb to our lmdb, which is somewhat more performant.
+ """ + from tqdm import tqdm + + # path to the original lsun's lmdb + src_path = 'datasets/bedroom_train_lmdb' + out_path = 'datasets/bedroom256.lmdb' + + dataset = LSUNClass(root=os.path.expanduser(src_path)) + dataset = ConvertDataset(dataset) + loader = DataLoader(dataset, + batch_size=50, + num_workers=12, + collate_fn=collate_fn, + shuffle=False) + + target = os.path.expanduser(out_path) + if os.path.exists(target): + shutil.rmtree(target) + + with lmdb.open(target, map_size=1024**4, readahead=False) as env: + with tqdm(total=len(dataset)) as progress: + i = 0 + for batch in loader: + with env.begin(write=True) as txn: + for img in batch: + key = f"{256}-{str(i).zfill(7)}".encode("utf-8") + # print(key) + txn.put(key, img) + i += 1 + progress.update() + # if i == 1000: + # break + # if total == len(imgset): + # break + + with env.begin(write=True) as txn: + txn.put("length".encode("utf-8"), str(i).encode("utf-8")) diff --git a/source/dataset-preprocessing/data_resize_celeba.py b/source/dataset-preprocessing/data_resize_celeba.py new file mode 100644 index 0000000..02891cc --- /dev/null +++ b/source/dataset-preprocessing/data_resize_celeba.py @@ -0,0 +1,120 @@ +import argparse +import multiprocessing +import os +import shutil +from functools import partial +from io import BytesIO +from multiprocessing import Process, Queue +from os.path import exists, join +from pathlib import Path + +import lmdb +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import LSUNClass +from torchvision.transforms import functional as trans_fn +from tqdm import tqdm + + +def resize_and_convert(img, size, resample, quality=100): + if size is not None: + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + + buffer = BytesIO() + img.save(buffer, format="webp", quality=quality) + val = buffer.getvalue() + + return val + + +# Define a top-level collate function instead of using a lambda to avoid pickling issues +def collate_fn(batch): + return batch + +def resize_multiple(img, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS, + quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(idx, img, sizes, resample): + img = img.convert("RGB") + out = resize_multiple(img, sizes=sizes, resample=resample) + return idx, out + + +class ConvertDataset(Dataset): + def __init__(self, data, size) -> None: + self.data = data + self.size = size + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + img = self.data[index] + bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100) + return bytes + + +class ImageFolder(Dataset): + def __init__(self, folder, ext='jpg'): + super().__init__() + paths = sorted([p for p in Path(f'{folder}').glob(f'*.{ext}')]) + self.paths = paths + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = os.path.join(self.paths[index]) + img = Image.open(path) + return img + + +if __name__ == "__main__": + from tqdm import tqdm + + out_path = 'datasets/celeba.lmdb' + in_path = 'datasets/celeba' + ext = 'jpg' + size = None + + dataset = ImageFolder(in_path, ext) + print('len:', len(dataset)) + dataset = ConvertDataset(dataset, size) + loader = DataLoader(dataset, + batch_size=50, + num_workers=12, + collate_fn=collate_fn, + shuffle=False) + + target = os.path.expanduser(out_path) + if os.path.exists(target): + shutil.rmtree(target) + + with 
lmdb.open(target, map_size=1024**4, readahead=False) as env: + with tqdm(total=len(dataset)) as progress: + i = 0 + for batch in loader: + with env.begin(write=True) as txn: + for img in batch: + key = f"{size}-{str(i).zfill(7)}".encode("utf-8") + # print(key) + txn.put(key, img) + i += 1 + progress.update() + # if i == 1000: + # break + # if total == len(imgset): + # break + + with env.begin(write=True) as txn: + txn.put("length".encode("utf-8"), str(i).encode("utf-8")) diff --git a/source/dataset-preprocessing/data_resize_celebahq.py b/source/dataset-preprocessing/data_resize_celebahq.py new file mode 100644 index 0000000..c6a2d7f --- /dev/null +++ b/source/dataset-preprocessing/data_resize_celebahq.py @@ -0,0 +1,120 @@ +import argparse +import multiprocessing +from functools import partial +from io import BytesIO +from pathlib import Path + +import lmdb +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import functional as trans_fn +from tqdm import tqdm +import os + + +def resize_and_convert(img, size, resample, quality=100): + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + buffer = BytesIO() + img.save(buffer, format="jpeg", quality=quality) + val = buffer.getvalue() + + return val + +# Define a top-level collate function instead of using a lambda to avoid pickling issues +def collate_fn(batch): + return batch + +def resize_multiple(img, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS, + quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(img_file, sizes, resample): + i, (file, idx) = img_file + img = Image.open(file) + img = img.convert("RGB") + out = resize_multiple(img, sizes=sizes, resample=resample) + + return i, idx, out + + +def prepare(env, + paths, + n_worker, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS): + resize_fn = partial(resize_worker, sizes=sizes, resample=resample) + + # index = filename in int + indexs = [] + for each in paths: + file = os.path.basename(each) + name, ext = file.split('.') + idx = int(name) + indexs.append(idx) + + # sort by file index + files = sorted(zip(paths, indexs), key=lambda x: x[1]) + files = list(enumerate(files)) + total = 0 + + with multiprocessing.Pool(n_worker) as pool: + for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)): + for size, img in zip(sizes, imgs): + key = f"{size}-{str(idx).zfill(5)}".encode("utf-8") + + with env.begin(write=True) as txn: + txn.put(key, img) + + total += 1 + + with env.begin(write=True) as txn: + txn.put("length".encode("utf-8"), str(total).encode("utf-8")) + + +class ImageFolder(Dataset): + def __init__(self, folder, exts=['jpg']): + super().__init__() + self.paths = [ + p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') + ] + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = os.path.join(self.folder, self.paths[index]) + img = Image.open(path) + return img + + +if __name__ == "__main__": + """ + converting celebahq images to lmdb + """ + num_workers = 16 + in_path = 'datasets/celebahq' + out_path = 'datasets/celebahq256.lmdb' + + resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} + resample = resample_map['lanczos'] + + sizes = [256] + + print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) + + # imgset = datasets.ImageFolder(in_path) + # imgset = ImageFolder(in_path) + exts = ['jpg'] + paths = [p for ext in 
exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')] + + with lmdb.open(out_path, map_size=1024**4, readahead=False) as env: + prepare(env, paths, num_workers, sizes=sizes, resample=resample) diff --git a/source/dataset-preprocessing/data_resize_ffhq.py b/source/dataset-preprocessing/data_resize_ffhq.py new file mode 100644 index 0000000..226c8ec --- /dev/null +++ b/source/dataset-preprocessing/data_resize_ffhq.py @@ -0,0 +1,123 @@ +import argparse +import multiprocessing +from functools import partial +from io import BytesIO +from pathlib import Path + +import lmdb +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import functional as trans_fn +from tqdm import tqdm +import os + + +def resize_and_convert(img, size, resample, quality=100): + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + buffer = BytesIO() + img.save(buffer, format="jpeg", quality=quality) + val = buffer.getvalue() + + return val + + +def resize_multiple(img, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS, + quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(img_file, sizes, resample): + i, (file, idx) = img_file + img = Image.open(file) + img = img.convert("RGB") + out = resize_multiple(img, sizes=sizes, resample=resample) + + return i, idx, out + + +def prepare(env, + paths, + n_worker, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS): + resize_fn = partial(resize_worker, sizes=sizes, resample=resample) + + # index = filename in int + indexs = [] + for each in paths: + file = os.path.basename(each) + name, ext = file.split('.') + idx = int(name) + indexs.append(idx) + + # sort by file index + files = sorted(zip(paths, indexs), key=lambda x: x[1]) + files = list(enumerate(files)) + total = 0 + + with multiprocessing.Pool(n_worker) as pool: + for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)): + for size, img in zip(sizes, imgs): + key = f"{size}-{str(idx).zfill(5)}".encode("utf-8") + + with env.begin(write=True) as txn: + txn.put(key, img) + + total += 1 + + with env.begin(write=True) as txn: + txn.put("length".encode("utf-8"), str(total).encode("utf-8")) + + +class ImageFolder(Dataset): + def __init__(self, folder, exts=['jpg']): + super().__init__() + self.paths = [ + p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') + ] + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = os.path.join(self.folder, self.paths[index]) + img = Image.open(path) + return img + + +if __name__ == "__main__": + """ + converting ffhq images to lmdb + """ + num_workers = 16 + # original ffhq data path + in_path = 'datasets/ffhq' + # target output path + out_path = 'datasets/ffhq.lmdb' + + if not os.path.exists(out_path): + os.makedirs(out_path) + + resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} + resample = resample_map['lanczos'] + + sizes = [256] + + print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) + + # imgset = datasets.ImageFolder(in_path) + # imgset = ImageFolder(in_path) + exts = ['jpg'] + paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')] + # print(paths[:10]) + + with lmdb.open(out_path, map_size=1024**4, readahead=False) as env: + prepare(env, paths, num_workers, sizes=sizes, resample=resample) diff --git a/source/dataset-preprocessing/data_resize_horse.py b/source/dataset-preprocessing/data_resize_horse.py 
new file mode 100644 index 0000000..6893613 --- /dev/null +++ b/source/dataset-preprocessing/data_resize_horse.py @@ -0,0 +1,100 @@ +import argparse +import multiprocessing +import os +import shutil +from functools import partial +from io import BytesIO +from multiprocessing import Process, Queue +from os.path import exists, join + +import lmdb +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets import LSUNClass +from torchvision.transforms import functional as trans_fn +from tqdm import tqdm + + +def resize_and_convert(img, size, resample, quality=100): + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + buffer = BytesIO() + img.save(buffer, format="webp", quality=quality) + val = buffer.getvalue() + + return val + + +def resize_multiple(img, + sizes=(128, 256, 512, 1024), + resample=Image.LANCZOS, + quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(idx, img, sizes, resample): + img = img.convert("RGB") + out = resize_multiple(img, sizes=sizes, resample=resample) + return idx, out + + +class ConvertDataset(Dataset): + def __init__(self, data) -> None: + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + img, _ = self.data[index] + bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) + return bytes + +# Define a top-level collate function instead of using a lambda to avoid pickling issues +def collate_fn(batch): + return batch + +if __name__ == "__main__": + """ + converting lsun' original lmdb to our lmdb, which is somehow more performant. + """ + from tqdm import tqdm + + # path to the original lsun's lmdb + src_path = 'datasets/horse_train_lmdb' + out_path = 'datasets/horse256.lmdb' + + dataset = LSUNClass(root=os.path.expanduser(src_path)) + dataset = ConvertDataset(dataset) + loader = DataLoader(dataset, + batch_size=50, + num_workers=16, + collate_fn=collate_fn) + + target = os.path.expanduser(out_path) + if os.path.exists(target): + shutil.rmtree(target) + + with lmdb.open(target, map_size=1024**4, readahead=False) as env: + with tqdm(total=len(dataset)) as progress: + i = 0 + for batch in loader: + with env.begin(write=True) as txn: + for img in batch: + key = f"{256}-{str(i).zfill(7)}".encode("utf-8") + # print(key) + txn.put(key, img) + i += 1 + progress.update() + # if i == 1000: + # break + # if total == len(imgset): + # break + + with env.begin(write=True) as txn: + txn.put("length".encode("utf-8"), str(i).encode("utf-8")) diff --git a/source/dataset-preprocessing/dataset.py b/source/dataset-preprocessing/dataset.py new file mode 100644 index 0000000..627671f --- /dev/null +++ b/source/dataset-preprocessing/dataset.py @@ -0,0 +1,716 @@ +import os +from io import BytesIO +from pathlib import Path + +import lmdb +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.datasets import CIFAR10, LSUNClass +import torch +import pandas as pd + +import torchvision.transforms.functional as Ftrans + + +class ImageDataset(Dataset): + def __init__( + self, + folder, + image_size, + exts=['jpg'], + do_augment: bool = True, + do_transform: bool = True, + do_normalize: bool = True, + sort_names=False, + has_subdir: bool = True, + ): + super().__init__() + self.folder = folder + self.image_size = image_size + + # relative paths (make it shorter, saves memory and faster to sort) + if has_subdir: 
+ self.paths = [ + p.relative_to(folder) for ext in exts + for p in Path(f'{folder}').glob(f'**/*.{ext}') + ] + else: + self.paths = [ + p.relative_to(folder) for ext in exts + for p in Path(f'{folder}').glob(f'*.{ext}') + ] + if sort_names: + self.paths = sorted(self.paths) + + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = os.path.join(self.folder, self.paths[index]) + img = Image.open(path) + # if the image is 'rgba'! + img = img.convert('RGB') + if self.transform is not None: + img = self.transform(img) + return {'img': img, 'index': index} + + +class SubsetDataset(Dataset): + def __init__(self, dataset, size): + assert len(dataset) >= size + self.dataset = dataset + self.size = size + + def __len__(self): + return self.size + + def __getitem__(self, index): + assert index < self.size + return self.dataset[index] + + +class BaseLMDB(Dataset): + def __init__(self, path, original_resolution, zfill: int = 5): + self.original_resolution = original_resolution + self.zfill = zfill + self.env = lmdb.open( + path, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + if not self.env: + raise IOError('Cannot open lmdb dataset', path) + + with self.env.begin(write=False) as txn: + self.length = int( + txn.get('length'.encode('utf-8')).decode('utf-8')) + + def __len__(self): + return self.length + + def __getitem__(self, index): + with self.env.begin(write=False) as txn: + key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode( + 'utf-8') + img_bytes = txn.get(key) + + buffer = BytesIO(img_bytes) + img = Image.open(buffer) + return img + + +def make_transform( + image_size, + flip_prob=0.5, + crop_d2c=False, +): + if crop_d2c: + transform = [ + d2c_crop(), + transforms.Resize(image_size), + ] + else: + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + transform.append(transforms.RandomHorizontalFlip(p=flip_prob)) + transform.append(transforms.ToTensor()) + transform.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + transform = transforms.Compose(transform) + return transform + + +class FFHQlmdb(Dataset): + def __init__(self, + path=os.path.expanduser('datasets/ffhq256.lmdb'), + image_size=256, + original_resolution=256, + split=None, + as_tensor: bool = True, + do_augment: bool = True, + do_normalize: bool = True, + **kwargs): + self.original_resolution = original_resolution + self.data = BaseLMDB(path, original_resolution, zfill=5) + self.length = len(self.data) + + if split is None: + self.offset = 0 + elif split == 'train': + # last 60k + self.length = self.length - 10000 + self.offset = 10000 + elif split == 'test': + # first 10k + self.length = 10000 + self.offset = 0 + else: + raise NotImplementedError() + + transform = [ + transforms.Resize(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if as_tensor: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def __len__(self): + return self.length + + 
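For context, an editorial sketch (not part of the patch): BaseLMDB reads back the records the data_resize_* writers above produce; keys are '<resolution>-<zero-padded index>' (zfill=5 for the FFHQ/CelebA-HQ writers, zfill=7 for the LSUN/CelebA ones) plus a 'length' record.

import lmdb
from io import BytesIO
from PIL import Image

# assumes datasets/ffhq256.lmdb (the default from data_paths) has been built
with lmdb.open('datasets/ffhq256.lmdb', readonly=True, lock=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get(b'length').decode('utf-8'))
        img = Image.open(BytesIO(txn.get(b'256-00000')))  # record 0 at resolution 256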
def __getitem__(self, index): + assert index < self.length + index = index + self.offset + img = self.data[index] + if self.transform is not None: + img = self.transform(img) + return {'img': img, 'index': index} + + +class Crop: + def __init__(self, x1, x2, y1, y2): + self.x1 = x1 + self.x2 = x2 + self.y1 = y1 + self.y2 = y2 + + def __call__(self, img): + return Ftrans.crop(img, self.x1, self.y1, self.x2 - self.x1, + self.y2 - self.y1) + + def __repr__(self): + return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( + self.x1, self.x2, self.y1, self.y2) + + +def d2c_crop(): + # from D2C paper for CelebA dataset. + cx = 89 + cy = 121 + x1 = cy - 64 + x2 = cy + 64 + y1 = cx - 64 + y2 = cx + 64 + return Crop(x1, x2, y1, y2) + + +class CelebAlmdb(Dataset): + """ + also supports for d2c crop. + """ + def __init__(self, + path, + image_size, + original_resolution=128, + split=None, + as_tensor: bool = True, + do_augment: bool = True, + do_normalize: bool = True, + crop_d2c: bool = False, + **kwargs): + self.original_resolution = original_resolution + self.data = BaseLMDB(path, original_resolution, zfill=7) + self.length = len(self.data) + self.crop_d2c = crop_d2c + + if split is None: + self.offset = 0 + else: + raise NotImplementedError() + + if crop_d2c: + transform = [ + d2c_crop(), + transforms.Resize(image_size), + ] + else: + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if as_tensor: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def __len__(self): + return self.length + + def __getitem__(self, index): + assert index < self.length + index = index + self.offset + img = self.data[index] + if self.transform is not None: + img = self.transform(img) + return {'img': img, 'index': index} + + +class Horse_lmdb(Dataset): + def __init__(self, + path=os.path.expanduser('datasets/horse256.lmdb'), + image_size=128, + original_resolution=256, + do_augment: bool = True, + do_transform: bool = True, + do_normalize: bool = True, + **kwargs): + self.original_resolution = original_resolution + print(path) + self.data = BaseLMDB(path, original_resolution, zfill=7) + self.length = len(self.data) + + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def __len__(self): + return self.length + + def __getitem__(self, index): + img = self.data[index] + if self.transform is not None: + img = self.transform(img) + return {'img': img, 'index': index} + + +class Bedroom_lmdb(Dataset): + def __init__(self, + path=os.path.expanduser('datasets/bedroom256.lmdb'), + image_size=128, + original_resolution=256, + do_augment: bool = True, + do_transform: bool = True, + do_normalize: bool = True, + **kwargs): + self.original_resolution = original_resolution + print(path) + self.data = BaseLMDB(path, original_resolution, zfill=7) + self.length = len(self.data) + + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + 
transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def __len__(self): + return self.length + + def __getitem__(self, index): + img = self.data[index] + img = self.transform(img) + return {'img': img, 'index': index} + + +class CelebAttrDataset(Dataset): + + id_to_cls = [ + '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', + 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', + 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', + 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', + 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', + 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', + 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', + 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', + 'Wearing_Necklace', 'Wearing_Necktie', 'Young' + ] + cls_to_id = {v: k for k, v in enumerate(id_to_cls)} + + def __init__(self, + folder, + image_size=64, + attr_path=os.path.expanduser( + 'datasets/celeba_anno/list_attr_celeba.txt'), + ext='png', + only_cls_name: str = None, + only_cls_value: int = None, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True, + d2c: bool = False): + super().__init__() + self.folder = folder + self.image_size = image_size + self.ext = ext + + # relative paths (shorter strings save memory and sort faster) + paths = [ + str(p.relative_to(folder)) + for p in Path(f'{folder}').glob(f'**/*.{ext}') + ] + # rewrite extensions to the .jpg names used by the annotation index + paths = [str(each).split('.')[0] + '.jpg' for each in paths] + + if d2c: + transform = [ + d2c_crop(), + transforms.Resize(image_size), + ] + else: + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + with open(attr_path) as f: + # discard the top line + f.readline() + self.df = pd.read_csv(f, delim_whitespace=True) + self.df = self.df[self.df.index.isin(paths)] + + if only_cls_name is not None: + self.df = self.df[self.df[only_cls_name] == only_cls_value] + + def pos_count(self, cls_name): + return (self.df[cls_name] == 1).sum() + + def neg_count(self, cls_name): + return (self.df[cls_name] == -1).sum() + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + row = self.df.iloc[index] + name = row.name.split('.')[0] + name = f'{name}.{self.ext}' + + path = os.path.join(self.folder, name) + img = Image.open(path) + + labels = [0] * len(self.id_to_cls) + for k, v in row.items(): + labels[self.cls_to_id[k]] = int(v) + + if self.transform is not None: + img = self.transform(img) + + return {'img': img, 'index': index, 'labels': torch.tensor(labels)} + + +class CelebD2CAttrDataset(CelebAttrDataset): + """ + The dataset used in the D2C paper. + It uses a specific crop of the original CelebA images.
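+ (The crop is the fixed 128x128 region produced by d2c_crop above.)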
+ """ + def __init__(self, + folder, + image_size=64, + attr_path=os.path.expanduser( + 'datasets/celeba_anno/list_attr_celeba.txt'), + ext='jpg', + only_cls_name: str = None, + only_cls_value: int = None, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True, + d2c: bool = True): + super().__init__(folder, + image_size, + attr_path, + ext=ext, + only_cls_name=only_cls_name, + only_cls_value=only_cls_value, + do_augment=do_augment, + do_transform=do_transform, + do_normalize=do_normalize, + d2c=d2c) + + +class CelebAttrFewshotDataset(Dataset): + def __init__( + self, + cls_name, + K, + img_folder, + img_size=64, + ext='png', + seed=0, + only_cls_name: str = None, + only_cls_value: int = None, + all_neg: bool = False, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True, + d2c: bool = False, + ) -> None: + self.cls_name = cls_name + self.K = K + self.img_folder = img_folder + self.ext = ext + + if all_neg: + path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv' + else: + path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv' + self.df = pd.read_csv(path, index_col=0) + if only_cls_name is not None: + self.df = self.df[self.df[only_cls_name] == only_cls_value] + + if d2c: + transform = [ + d2c_crop(), + transforms.Resize(img_size), + ] + else: + transform = [ + transforms.Resize(img_size), + transforms.CenterCrop(img_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + def pos_count(self, cls_name): + return (self.df[cls_name] == 1).sum() + + def neg_count(self, cls_name): + return (self.df[cls_name] == -1).sum() + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + row = self.df.iloc[index] + name = row.name.split('.')[0] + name = f'{name}.{self.ext}' + + path = os.path.join(self.img_folder, name) + img = Image.open(path) + + # (1, 1) + label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1) + + if self.transform is not None: + img = self.transform(img) + + return {'img': img, 'index': index, 'labels': label} + + +class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset): + def __init__(self, + cls_name, + K, + img_folder, + img_size=64, + ext='jpg', + seed=0, + only_cls_name: str = None, + only_cls_value: int = None, + all_neg: bool = False, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True, + is_negative=False, + d2c: bool = True) -> None: + super().__init__(cls_name, + K, + img_folder, + img_size, + ext=ext, + seed=seed, + only_cls_name=only_cls_name, + only_cls_value=only_cls_value, + all_neg=all_neg, + do_augment=do_augment, + do_transform=do_transform, + do_normalize=do_normalize, + d2c=d2c) + self.is_negative = is_negative + + +class CelebHQAttrDataset(Dataset): + id_to_cls = [ + '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', + 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', + 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', + 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', + 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', + 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', + 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', + 'Wearing_Earrings', 'Wearing_Hat', 
'Wearing_Lipstick', + 'Wearing_Necklace', 'Wearing_Necktie', 'Young' + ] + cls_to_id = {v: k for k, v in enumerate(id_to_cls)} + + def __init__(self, + path=os.path.expanduser('datasets/celebahq256.lmdb'), + image_size=None, + attr_path=os.path.expanduser( + 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'), + original_resolution=256, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True): + super().__init__() + self.image_size = image_size + self.data = BaseLMDB(path, original_resolution, zfill=5) + + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + with open(attr_path) as f: + # discard the top line + f.readline() + self.df = pd.read_csv(f, delim_whitespace=True) + + def pos_count(self, cls_name): + return (self.df[cls_name] == 1).sum() + + def neg_count(self, cls_name): + return (self.df[cls_name] == -1).sum() + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + row = self.df.iloc[index] + img_name = row.name + img_idx, ext = img_name.split('.') + img = self.data[img_idx] + + labels = [0] * len(self.id_to_cls) + for k, v in row.items(): + labels[self.cls_to_id[k]] = int(v) + + if self.transform is not None: + img = self.transform(img) + return {'img': img, 'index': index, 'labels': torch.tensor(labels)} + + +class CelebHQAttrFewshotDataset(Dataset): + def __init__(self, + cls_name, + K, + path, + image_size, + original_resolution=256, + do_augment: bool = False, + do_transform: bool = True, + do_normalize: bool = True): + super().__init__() + self.image_size = image_size + self.cls_name = cls_name + self.K = K + self.data = BaseLMDB(path, original_resolution, zfill=5) + + transform = [ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ] + if do_augment: + transform.append(transforms.RandomHorizontalFlip()) + if do_transform: + transform.append(transforms.ToTensor()) + if do_normalize: + transform.append( + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + self.transform = transforms.Compose(transform) + + self.df = pd.read_csv(f'data/celebahq_fewshots/K{K}_{cls_name}.csv', + index_col=0) + + def pos_count(self, cls_name): + return (self.df[cls_name] == 1).sum() + + def neg_count(self, cls_name): + return (self.df[cls_name] == -1).sum() + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + row = self.df.iloc[index] + img_name = row.name + img_idx, ext = img_name.split('.') + img = self.data[img_idx] + + # (1, 1) + label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1) + + if self.transform is not None: + img = self.transform(img) + + return {'img': img, 'index': index, 'labels': label} + + +class Repeat(Dataset): + def __init__(self, dataset, new_len) -> None: + super().__init__() + self.dataset = dataset + self.original_len = len(dataset) + self.new_len = new_len + + def __len__(self): + return self.new_len + + def __getitem__(self, index): + index = index % self.original_len + return self.dataset[index] diff --git a/source/dataset-preprocessing/dataset_util.py b/source/dataset-preprocessing/dataset_util.py new file mode 100644 index 0000000..d2075ac --- /dev/null +++ b/source/dataset-preprocessing/dataset_util.py @@ -0,0 +1,13 @@ +import shutil 
+import os +from dist_utils import * + + +def use_cached_dataset_path(source_path, cache_path): + if get_rank() == 0: + if not os.path.exists(cache_path): + # shutil.rmtree(cache_path) + print(f'copying the data: {source_path} to {cache_path}') + shutil.copytree(source_path, cache_path) + # all ranks wait here until rank 0 has finished copying + barrier() + return cache_path \ No newline at end of file diff --git a/source/templates/templates.py b/source/templates/templates.py new file mode 100644 index 0000000..63cb4de --- /dev/null +++ b/source/templates/templates.py @@ -0,0 +1,323 @@ +from experiment import * + +# PyTorch vs PyTorch Lightning Module + +# If the environment variable 'BREAK_CIRCULAR_IMPORTS' is set, skip importing +# modules that are known to cause circular-import problems. +if os.environ.get('BREAK_CIRCULAR_IMPORTS'): + import sys + + # Register empty placeholder modules so that later imports of these names + # succeed without pulling in the real (circularly importing) modules. + sys.modules['torch.types'] = type('', (), {}) + sys.modules['torch.utils._python_dispatch'] = type('', (), {}) + +# Import the 'IMPORTS_READY' marker from 'fixed_imports' when it is available; +# fall back silently if the module is absent. +try: + from fixed_imports import IMPORTS_READY +except ImportError: + pass + + + +def ddpm(): + """ + base configuration for all DDIM-based models. + """ + conf = TrainConfig() + conf.batch_size = 32 + conf.beatgans_gen_type = GenerativeType.ddim + conf.beta_scheduler = 'linear' + conf.data_name = 'ffhq' + conf.diffusion_type = 'beatgans' + conf.eval_ema_every_samples = 200_000 + conf.eval_every_samples = 200_000 + conf.fp16 = True + conf.lr = 1e-4 + conf.model_name = ModelName.beatgans_ddpm + conf.net_attn = (16, ) + conf.net_beatgans_attn_head = 1 + conf.net_beatgans_embed_channels = 512 + conf.net_ch_mult = (1, 2, 4, 8) + conf.net_ch = 64 + conf.sample_size = 32 + conf.T_eval = 20 + conf.T = 1000 + conf.make_model_conf() + return conf + + +def autoenc_base(): + """ + base configuration for all Diff-AE models.
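+ Compared to ddpm() it additionally enables the semantic encoder: two-cond + ResNet blocks, encoder channel multipliers, and 'adaptivenonzero' pooling.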
+ """ + conf = TrainConfig() + conf.batch_size = 32 + conf.beatgans_gen_type = GenerativeType.ddim + conf.beta_scheduler = 'linear' + conf.data_name = 'ffhq' + conf.diffusion_type = 'beatgans' + conf.eval_ema_every_samples = 200_000 + conf.eval_every_samples = 200_000 + conf.fp16 = True + conf.lr = 1e-4 + conf.model_name = ModelName.beatgans_autoenc + conf.net_attn = (16, ) + conf.net_beatgans_attn_head = 1 + conf.net_beatgans_embed_channels = 512 + conf.net_beatgans_resnet_two_cond = True + conf.net_ch_mult = (1, 2, 4, 8) + conf.net_ch = 64 + conf.net_enc_channel_mult = (1, 2, 4, 8, 8) + conf.net_enc_pool = 'adaptivenonzero' + conf.sample_size = 32 + conf.T_eval = 20 + conf.T = 1000 + conf.make_model_conf() + return conf + + +def ffhq64_ddpm(): + conf = ddpm() + conf.data_name = 'ffhqlmdb256' + conf.warmup = 0 + conf.total_samples = 72_000_000 + conf.scale_up_gpus(4) + return conf + + +def ffhq64_autoenc(): + conf = autoenc_base() + conf.data_name = 'ffhqlmdb256' + conf.warmup = 0 + conf.total_samples = 72_000_000 + conf.net_ch_mult = (1, 2, 4, 8) + conf.net_enc_channel_mult = (1, 2, 4, 8, 8) + conf.eval_every_samples = 1_000_000 + conf.eval_ema_every_samples = 1_000_000 + conf.scale_up_gpus(4) + conf.make_model_conf() + return conf + + +def celeba64d2c_ddpm(): + conf = ffhq128_ddpm() + conf.data_name = 'celebalmdb' + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 72_000_000 + conf.name = 'celeba64d2c_ddpm' + return conf + + +def celeba64d2c_autoenc(): + conf = ffhq64_autoenc() + conf.data_name = 'celebalmdb' + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 72_000_000 + conf.name = 'celeba64d2c_autoenc' + return conf + + +def ffhq128_ddpm(): + conf = ddpm() + conf.data_name = 'ffhqlmdb256' + conf.warmup = 0 + conf.total_samples = 48_000_000 + conf.img_size = 128 + conf.net_ch = 128 + # channels: + # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4 + # sizes: + # 128 => 128 => 64 => 32 => 16 => 8 + conf.net_ch_mult = (1, 1, 2, 3, 4) + conf.eval_every_samples = 1_000_000 + conf.eval_ema_every_samples = 1_000_000 + conf.scale_up_gpus(4) + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.make_model_conf() + return conf + + +def ffhq128_autoenc_base(): + conf = autoenc_base() + conf.data_name = 'ffhqlmdb256' + conf.scale_up_gpus(4) + conf.img_size = 128 + conf.net_ch = 128 + # final resolution = 8x8 + conf.net_ch_mult = (1, 1, 2, 3, 4) + # final resolution = 4x4 + conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4) + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.make_model_conf() + return conf + + +def ffhq256_autoenc(): + conf = ffhq128_autoenc_base() + conf.img_size = 256 + conf.net_ch = 128 + conf.net_ch_mult = (1, 1, 2, 2, 4, 4) + conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 200_000_000 + conf.batch_size = 64 + conf.make_model_conf() + conf.name = 'ffhq256_autoenc' + return conf + + +def ffhq256_autoenc_eco(): + conf = ffhq128_autoenc_base() + conf.img_size = 256 + conf.net_ch = 128 + conf.net_ch_mult = (1, 1, 2, 2, 4, 4) + conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 200_000_000 + conf.batch_size = 64 + conf.make_model_conf() + conf.name = 'ffhq256_autoenc_eco' + return conf + + 
+def ffhq128_ddpm_72M(): + conf = ffhq128_ddpm() + conf.total_samples = 72_000_000 + conf.name = 'ffhq128_ddpm_72M' + return conf + + +def ffhq128_autoenc_72M(): + conf = ffhq128_autoenc_base() + conf.total_samples = 72_000_000 + conf.name = 'ffhq128_autoenc_72M' + return conf + + +def ffhq128_ddpm_130M(): + conf = ffhq128_ddpm() + conf.total_samples = 130_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.name = 'ffhq128_ddpm_130M' + return conf + + +def ffhq128_autoenc_130M(): + conf = ffhq128_autoenc_base() + conf.total_samples = 130_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.name = 'ffhq128_autoenc_130M' + return conf + + +def horse128_ddpm(): + conf = ffhq128_ddpm() + conf.data_name = 'horse256' + conf.total_samples = 130_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.name = 'horse128_ddpm' + return conf + + +def horse128_autoenc(): + conf = ffhq128_autoenc_base() + conf.data_name = 'horse256' + conf.total_samples = 130_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.name = 'horse128_autoenc' + return conf + + +def bedroom128_ddpm(): + conf = ffhq128_ddpm() + conf.data_name = 'bedroom256' + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.total_samples = 120_000_000 + conf.name = 'bedroom128_ddpm' + return conf + + +def bedroom128_autoenc(): + conf = ffhq128_autoenc_base() + conf.data_name = 'bedroom256' + conf.eval_ema_every_samples = 10_000_000 + conf.eval_every_samples = 10_000_000 + conf.total_samples = 120_000_000 + conf.name = 'bedroom128_autoenc' + return conf + + +def pretrain_celeba64d2c_72M(): + conf = celeba64d2c_autoenc() + conf.pretrain = PretrainConfig( + name='72M', + path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl' + return conf + + +def pretrain_ffhq128_autoenc72M(): + conf = ffhq128_autoenc_base() + conf.postfix = '' + conf.pretrain = PretrainConfig( + name='72M', + path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl' + return conf + + +def pretrain_ffhq128_autoenc130M(): + conf = ffhq128_autoenc_base() + conf.pretrain = PretrainConfig( + name='130M', + path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl' + return conf + + +def pretrain_ffhq256_autoenc(): + conf = ffhq256_autoenc() + conf.pretrain = PretrainConfig( + name='90M', + path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' + return conf + + +def pretrain_horse128(): + conf = horse128_autoenc() + conf.pretrain = PretrainConfig( + name='82M', + path=f'checkpoints/{horse128_autoenc().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl' + return conf + + +def pretrain_bedroom128(): + conf = bedroom128_autoenc() + conf.pretrain = PretrainConfig( + name='120M', + path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt', + ) + conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl' + return conf diff --git a/source/templates/templates_cls.py b/source/templates/templates_cls.py new file mode 100644 index 0000000..4fc3bcf --- /dev/null +++ 
b/source/templates/templates_cls.py @@ -0,0 +1,38 @@ +from templates import * + + +def ffhq128_autoenc_cls(): + conf = ffhq128_autoenc_130M() + conf.train_mode = TrainMode.manipulate + conf.manipulate_mode = ManipulateMode.celebahq_all + conf.manipulate_znormalize = True + conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl' + conf.batch_size = 32 + conf.lr = 1e-3 + conf.total_samples = 300_000 + # use the pretraining trick instead of the continuing trick + conf.pretrain = PretrainConfig( + '130M', + f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt', + ) + conf.name = 'ffhq128_autoenc_cls' + return conf + + +def ffhq256_autoenc_cls(): + '''We first train the encoder on the FFHQ dataset, then use it as a pretrained model to train a linear classifier on the CelebA dataset with attribute labels.''' + conf = ffhq256_autoenc() + conf.train_mode = TrainMode.manipulate + conf.manipulate_mode = ManipulateMode.celebahq_all + conf.manipulate_znormalize = True + conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' # we train on the CelebA dataset, not FFHQ + conf.batch_size = 32 + conf.lr = 1e-3 + conf.total_samples = 300_000 + # use the pretraining trick instead of the continuing trick + conf.pretrain = PretrainConfig( + '130M', + f'checkpoints/{ffhq256_autoenc().name}/last.ckpt', + ) + conf.name = 'ffhq256_autoenc_cls' + return conf diff --git a/source/templates/templates_latent.py b/source/templates/templates_latent.py new file mode 100644 index 0000000..b82c257 --- /dev/null +++ b/source/templates/templates_latent.py @@ -0,0 +1,150 @@ +from templates import * + + +def latent_diffusion_config(conf: TrainConfig): + conf.batch_size = 128 + conf.train_mode = TrainMode.latent_diffusion + conf.latent_gen_type = GenerativeType.ddim + conf.latent_loss_type = LossType.mse + conf.latent_model_mean_type = ModelMeanType.eps + conf.latent_model_var_type = ModelVarType.fixed_large + conf.latent_rescale_timesteps = False + conf.latent_clip_sample = False + conf.latent_T_eval = 20 + conf.latent_znormalize = True + conf.total_samples = 96_000_000 + conf.sample_every_samples = 400_000 + conf.eval_every_samples = 20_000_000 + conf.eval_ema_every_samples = 20_000_000 + conf.save_every_samples = 2_000_000 + return conf + + +def latent_diffusion128_config(conf: TrainConfig): + conf = latent_diffusion_config(conf) + conf.batch_size_eval = 32 + return conf + + +def latent_mlp_2048_norm_10layers(conf: TrainConfig): + conf.net_latent_net_type = LatentNetType.skip + conf.net_latent_layers = 10 + conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers)) + conf.net_latent_activation = Activation.silu + conf.net_latent_num_hid_channels = 2048 + conf.net_latent_use_norm = True + conf.net_latent_condition_bias = 1 + return conf + + +def latent_mlp_2048_norm_20layers(conf: TrainConfig): + conf = latent_mlp_2048_norm_10layers(conf) + conf.net_latent_layers = 20 + conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers)) + return conf + + +def latent_256_batch_size(conf: TrainConfig): + conf.batch_size = 256 + conf.eval_ema_every_samples = 100_000_000 + conf.eval_every_samples = 100_000_000 + conf.sample_every_samples = 1_000_000 + conf.save_every_samples = 2_000_000 + conf.total_samples = 301_000_000 + return conf + + +def latent_512_batch_size(conf: TrainConfig): + conf.batch_size = 512 + conf.eval_ema_every_samples = 100_000_000 + conf.eval_every_samples = 100_000_000 + conf.sample_every_samples = 1_000_000 + conf.save_every_samples = 5_000_000 + conf.total_samples = 501_000_000 +
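# roughly 501_000_000 / 512 ≈ 0.98M optimizer steps, assuming total_samples counts samples seen +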
return conf + + +def latent_2048_batch_size(conf: TrainConfig): + conf.batch_size = 2048 + conf.eval_ema_every_samples = 200_000_000 + conf.eval_every_samples = 200_000_000 + conf.sample_every_samples = 4_000_000 + conf.save_every_samples = 20_000_000 + conf.total_samples = 1_501_000_000 + return conf + + +def adamw_weight_decay(conf: TrainConfig): + conf.optimizer = OptimizerType.adamw + conf.weight_decay = 0.01 + return conf + + +def ffhq128_autoenc_latent(): + conf = pretrain_ffhq128_autoenc130M() + conf = latent_diffusion128_config(conf) + conf = latent_mlp_2048_norm_10layers(conf) + conf = latent_256_batch_size(conf) + conf = adamw_weight_decay(conf) + conf.total_samples = 101_000_000 + conf.latent_loss_type = LossType.l1 + conf.latent_beta_scheduler = 'const0.008' + conf.name = 'ffhq128_autoenc_latent' + return conf + + +def ffhq256_autoenc_latent(): + conf = pretrain_ffhq256_autoenc() + conf = latent_diffusion128_config(conf) + conf = latent_mlp_2048_norm_10layers(conf) + conf = latent_256_batch_size(conf) + conf = adamw_weight_decay(conf) + conf.total_samples = 101_000_000 + conf.latent_loss_type = LossType.l1 + conf.latent_beta_scheduler = 'const0.008' + conf.eval_ema_every_samples = 200_000_000 + conf.eval_every_samples = 200_000_000 + conf.sample_every_samples = 4_000_000 + conf.name = 'ffhq256_autoenc_latent' + return conf + + +def horse128_autoenc_latent(): + conf = pretrain_horse128() + conf = latent_diffusion128_config(conf) + conf = latent_2048_batch_size(conf) + conf = latent_mlp_2048_norm_20layers(conf) + conf.total_samples = 2_001_000_000 + conf.latent_beta_scheduler = 'const0.008' + conf.latent_loss_type = LossType.l1 + conf.name = 'horse128_autoenc_latent' + return conf + + +def bedroom128_autoenc_latent(): + conf = pretrain_bedroom128() + conf = latent_diffusion128_config(conf) + conf = latent_2048_batch_size(conf) + conf = latent_mlp_2048_norm_20layers(conf) + conf.total_samples = 2_001_000_000 + conf.latent_beta_scheduler = 'const0.008' + conf.latent_loss_type = LossType.l1 + conf.name = 'bedroom128_autoenc_latent' + return conf + + +def celeba64d2c_autoenc_latent(): + conf = pretrain_celeba64d2c_72M() + conf = latent_diffusion_config(conf) + conf = latent_512_batch_size(conf) + conf = latent_mlp_2048_norm_10layers(conf) + conf = adamw_weight_decay(conf) + # just for the name + conf.continue_from = PretrainConfig('200M', + f'log-latent/{conf.name}/last.ckpt') + conf.postfix = '_300M' + conf.total_samples = 301_000_000 + conf.latent_beta_scheduler = 'const0.008' + conf.latent_loss_type = LossType.l1 + conf.name = 'celeba64d2c_autoenc_latent' + return conf From 189cffe03b6de32e2a713faea4591c1b75721dc1 Mon Sep 17 00:00:00 2001 From: olafhappy Date: Tue, 18 Mar 2025 15:50:57 +0000 Subject: [PATCH 3/3] fix: add some new files --- .../xsem_from_image.ipynb | 0 .../xt_from_image.ipynb | 0 source/align.py | 249 ------ source/config.py | 425 ----------- source/config_base.py | 72 -- source/data_resize_bedroom.py | 101 --- source/data_resize_celeba.py | 120 --- source/data_resize_celebahq.py | 120 --- source/data_resize_ffhq.py | 123 --- source/data_resize_horse.py | 100 --- source/dataset.py | 716 ------------------ source/dataset_util.py | 13 - fixed_imports.py => source/fixed_imports.py | 0 source/templates.py | 323 -------- source/templates_cls.py | 38 - source/templates_latent.py | 150 ---- 16 files changed, 2550 deletions(-) rename xsem_from_image.ipynb => notebook/xsem_from_image.ipynb (100%) rename xt_from_image.ipynb => notebook/xt_from_image.ipynb (100%) delete mode 
100644 source/align.py delete mode 100644 source/config.py delete mode 100644 source/config_base.py delete mode 100644 source/data_resize_bedroom.py delete mode 100644 source/data_resize_celeba.py delete mode 100644 source/data_resize_celebahq.py delete mode 100644 source/data_resize_ffhq.py delete mode 100644 source/data_resize_horse.py delete mode 100644 source/dataset.py delete mode 100644 source/dataset_util.py rename fixed_imports.py => source/fixed_imports.py (100%) delete mode 100644 source/templates.py delete mode 100644 source/templates_cls.py delete mode 100644 source/templates_latent.py diff --git a/xsem_from_image.ipynb b/notebook/xsem_from_image.ipynb similarity index 100% rename from xsem_from_image.ipynb rename to notebook/xsem_from_image.ipynb diff --git a/xt_from_image.ipynb b/notebook/xt_from_image.ipynb similarity index 100% rename from xt_from_image.ipynb rename to notebook/xt_from_image.ipynb diff --git a/source/align.py b/source/align.py deleted file mode 100644 index 8783e37..0000000 --- a/source/align.py +++ /dev/null @@ -1,249 +0,0 @@ -import bz2 -import os -import os.path as osp -import sys -from multiprocessing import Pool - -import dlib -import numpy as np -import PIL.Image -import requests -import scipy.ndimage -from tqdm import tqdm -from argparse import ArgumentParser - -LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' - - -def image_align(src_file, - dst_file, - face_landmarks, - output_size=1024, - transform_size=4096, - enable_padding=True): - # Align function from FFHQ dataset pre-processing step - # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py - - lm = np.array(face_landmarks) - lm_chin = lm[0:17] # left-right - lm_eyebrow_left = lm[17:22] # left-right - lm_eyebrow_right = lm[22:27] # left-right - lm_nose = lm[27:31] # top-down - lm_nostrils = lm[31:36] # top-down - lm_eye_left = lm[36:42] # left-clockwise - lm_eye_right = lm[42:48] # left-clockwise - lm_mouth_outer = lm[48:60] # left-clockwise - lm_mouth_inner = lm[60:68] # left-clockwise - - # Calculate auxiliary vectors. - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - eye_avg = (eye_left + eye_right) * 0.5 - eye_to_eye = eye_right - eye_left - mouth_left = lm_mouth_outer[0] - mouth_right = lm_mouth_outer[6] - mouth_avg = (mouth_left + mouth_right) * 0.5 - eye_to_mouth = mouth_avg - eye_avg - - # Choose oriented crop rectangle. - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] - x /= np.hypot(*x) - x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) - y = np.flipud(x) * [-1, 1] - c = eye_avg + eye_to_mouth * 0.1 - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) - qsize = np.hypot(*x) * 2 - - # Load in-the-wild image. - if not os.path.isfile(src_file): - print( - '\nCannot find source image. Please run "--wilds" before "--align".' - ) - return - img = PIL.Image.open(src_file) - img = img.convert('RGB') - - # Shrink. - shrink = int(np.floor(qsize / output_size * 0.5)) - if shrink > 1: - rsize = (int(np.rint(float(img.size[0]) / shrink)), - int(np.rint(float(img.size[1]) / shrink))) - img = img.resize(rsize, PIL.Image.ANTIALIAS) - quad /= shrink - qsize /= shrink - - # Crop. 
- border = max(int(np.rint(qsize * 0.1)), 3) - crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) - crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), - min(crop[2] + border, - img.size[0]), min(crop[3] + border, img.size[1])) - if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: - img = img.crop(crop) - quad -= crop[0:2] - - # Pad. - pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) - pad = (max(-pad[0] + border, - 0), max(-pad[1] + border, - 0), max(pad[2] - img.size[0] + border, - 0), max(pad[3] - img.size[1] + border, 0)) - if enable_padding and max(pad) > border - 4: - pad = np.maximum(pad, int(np.rint(qsize * 0.3))) - img = np.pad(np.float32(img), - ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') - h, w, _ = img.shape - y, x, _ = np.ogrid[:h, :w, :1] - mask = np.maximum( - 1.0 - - np.minimum(np.float32(x) / pad[0], - np.float32(w - 1 - x) / pad[2]), 1.0 - - np.minimum(np.float32(y) / pad[1], - np.float32(h - 1 - y) / pad[3])) - blur = qsize * 0.02 - img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) - img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) - img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), - 'RGB') - quad += pad[:2] - - # Transform. - img = img.transform((transform_size, transform_size), PIL.Image.QUAD, - (quad + 0.5).flatten(), PIL.Image.BILINEAR) - if output_size < transform_size: - img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) - - # Save aligned image. - img.save(dst_file, 'PNG') - - -class LandmarksDetector: - def __init__(self, predictor_model_path): - """ - :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file - """ - self.detector = dlib.get_frontal_face_detector( - ) # cnn_face_detection_model_v1 also can be used - self.shape_predictor = dlib.shape_predictor(predictor_model_path) - - def get_landmarks(self, image): - img = dlib.load_rgb_image(image) - dets = self.detector(img, 1) - - for detection in dets: - face_landmarks = [ - (item.x, item.y) - for item in self.shape_predictor(img, detection).parts() - ] - yield face_landmarks - - -def unpack_bz2(src_path): - dst_path = src_path[:-4] - if os.path.exists(dst_path): - print('cached') - return dst_path - data = bz2.BZ2File(src_path).read() - with open(dst_path, 'wb') as fp: - fp.write(data) - return dst_path - - -def work_landmark(raw_img_path, img_name, face_landmarks): - face_img_name = '%s.png' % (os.path.splitext(img_name)[0], ) - aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) - if os.path.exists(aligned_face_path): - return - image_align(raw_img_path, - aligned_face_path, - face_landmarks, - output_size=256) - - -def get_file(src, tgt): - if os.path.exists(tgt): - print('cached') - return tgt - tgt_dir = os.path.dirname(tgt) - if not os.path.exists(tgt_dir): - os.makedirs(tgt_dir) - file = requests.get(src) - open(tgt, 'wb').write(file.content) - return tgt - - -if __name__ == "__main__": - """ - Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step - python align_images.py /raw_images /aligned_images - """ - parser = ArgumentParser() - parser.add_argument("-i", - "--input_imgs_path", - type=str, - default="imgs", - help="input images directory path") - parser.add_argument("-o", - 
"--output_imgs_path", - type=str, - default="imgs_align", - help="output images directory path") - - args = parser.parse_args() - - # takes very long time ... - landmarks_model_path = unpack_bz2( - get_file( - 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', - 'temp/shape_predictor_68_face_landmarks.dat.bz2')) - - # RAW_IMAGES_DIR = sys.argv[1] - # ALIGNED_IMAGES_DIR = sys.argv[2] - RAW_IMAGES_DIR = args.input_imgs_path - ALIGNED_IMAGES_DIR = args.output_imgs_path - - if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR) - - files = os.listdir(RAW_IMAGES_DIR) - print(f'total img files {len(files)}') - with tqdm(total=len(files)) as progress: - - def cb(*args): - # print('update') - progress.update() - - def err_cb(e): - print('error:', e) - - with Pool(8) as pool: - res = [] - landmarks_detector = LandmarksDetector(landmarks_model_path) - for img_name in files: - raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name) - # print('img_name:', img_name) - for i, face_landmarks in enumerate( - landmarks_detector.get_landmarks(raw_img_path), - start=1): - # assert i == 1, f'{i}' - # print(i, face_landmarks) - # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i) - # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) - # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256) - - work_landmark(raw_img_path, img_name, face_landmarks) - progress.update() - - # job = pool.apply_async( - # work_landmark, - # (raw_img_path, img_name, face_landmarks), - # callback=cb, - # error_callback=err_cb, - # ) - # res.append(job) - - # pool.close() - # pool.join() - print(f"output aligned images at: {ALIGNED_IMAGES_DIR}") diff --git a/source/config.py b/source/config.py deleted file mode 100644 index 98068e8..0000000 --- a/source/config.py +++ /dev/null @@ -1,425 +0,0 @@ -from model.unet import ScaleAt -from model.latentnet import * -from diffusion.resample import UniformSampler -from diffusion.diffusion import space_timesteps -from typing import Tuple - -from torch.utils.data import DataLoader - -from config_base import BaseConfig -from dataset import * -from diffusion import * -from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule -from model import * -from choices import * -from multiprocessing import get_context -import os -from dataset_util import * -from torch.utils.data.distributed import DistributedSampler - -data_paths = { - 'ffhqlmdb256': - os.path.expanduser('datasets/ffhq256.lmdb'), - # used for training a classifier - 'celeba': - os.path.expanduser('datasets/celeba'), - # used for training DPM models - 'celebalmdb': - os.path.expanduser('datasets/celeba.lmdb'), - 'celebahq': - os.path.expanduser('datasets/celebahq256.lmdb'), - 'horse256': - os.path.expanduser('datasets/horse256.lmdb'), - 'bedroom256': - os.path.expanduser('datasets/bedroom256.lmdb'), - 'celeba_anno': - os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'), - 'celebahq_anno': - os.path.expanduser( - 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'), - 'celeba_relight': - os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'), -} - - -@dataclass -class PretrainConfig(BaseConfig): - name: str - path: str - - -@dataclass -class TrainConfig(BaseConfig): - # random seed - seed: int = 0 - train_mode: TrainMode = TrainMode.diffusion - train_cond0_prob: float = 0 - train_pred_xstart_detach: bool = True - train_interpolate_prob: float = 0 - train_interpolate_img: bool = False - 
manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all - manipulate_cls: str = None - manipulate_shots: int = None - manipulate_loss: ManipulateLossType = ManipulateLossType.bce - manipulate_znormalize: bool = False - manipulate_seed: int = 0 - accum_batches: int = 1 - autoenc_mid_attn: bool = True - batch_size: int = 16 - batch_size_eval: int = None - beatgans_gen_type: GenerativeType = GenerativeType.ddim - beatgans_loss_type: LossType = LossType.mse - beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps - beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large - beatgans_rescale_timesteps: bool = False - latent_infer_path: str = None - latent_znormalize: bool = False - latent_gen_type: GenerativeType = GenerativeType.ddim - latent_loss_type: LossType = LossType.mse - latent_model_mean_type: ModelMeanType = ModelMeanType.eps - latent_model_var_type: ModelVarType = ModelVarType.fixed_large - latent_rescale_timesteps: bool = False - latent_T_eval: int = 1_000 - latent_clip_sample: bool = False - latent_beta_scheduler: str = 'linear' - beta_scheduler: str = 'linear' - data_name: str = '' - data_val_name: str = None - diffusion_type: str = None - dropout: float = 0.1 - ema_decay: float = 0.9999 - eval_num_images: int = 5_000 - eval_every_samples: int = 200_000 - eval_ema_every_samples: int = 200_000 - fid_use_torch: bool = True - fp16: bool = False - grad_clip: float = 1 - img_size: int = 64 - lr: float = 0.0001 - optimizer: OptimizerType = OptimizerType.adam - weight_decay: float = 0 - model_conf: ModelConfig = None - model_name: ModelName = None - model_type: ModelType = None - net_attn: Tuple[int] = None - net_beatgans_attn_head: int = 1 - # not necessarily the same as the the number of style channels - net_beatgans_embed_channels: int = 512 - net_resblock_updown: bool = True - net_enc_use_time: bool = False - net_enc_pool: str = 'adaptivenonzero' - net_beatgans_gradient_checkpoint: bool = False - net_beatgans_resnet_two_cond: bool = False - net_beatgans_resnet_use_zero_module: bool = True - net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm - net_beatgans_resnet_cond_channels: int = None - net_ch_mult: Tuple[int] = None - net_ch: int = 64 - net_enc_attn: Tuple[int] = None - net_enc_k: int = None - # number of resblocks for the encoder (half-unet) - net_enc_num_res_blocks: int = 2 - net_enc_channel_mult: Tuple[int] = None - net_enc_grad_checkpoint: bool = False - net_autoenc_stochastic: bool = False - net_latent_activation: Activation = Activation.silu - net_latent_channel_mult: Tuple[int] = (1, 2, 4) - net_latent_condition_bias: float = 0 - net_latent_dropout: float = 0 - net_latent_layers: int = None - net_latent_net_last_act: Activation = Activation.none - net_latent_net_type: LatentNetType = LatentNetType.none - net_latent_num_hid_channels: int = 1024 - net_latent_num_time_layers: int = 2 - net_latent_skip_layers: Tuple[int] = None - net_latent_time_emb_channels: int = 64 - net_latent_use_norm: bool = False - net_latent_time_last_act: bool = False - net_num_res_blocks: int = 2 - # number of resblocks for the UNET - net_num_input_res_blocks: int = None - net_enc_num_cls: int = None - num_workers: int = 4 - parallel: bool = False - postfix: str = '' - sample_size: int = 64 - sample_every_samples: int = 20_000 - save_every_samples: int = 100_000 - style_ch: int = 512 - T_eval: int = 1_000 - T_sampler: str = 'uniform' - T: int = 1_000 - total_samples: int = 10_000_000 - warmup: int = 0 - pretrain: PretrainConfig = None - continue_from: PretrainConfig = None 
- eval_programs: Tuple[str] = None - # if present load the checkpoint from this path instead - eval_path: str = None - base_dir: str = 'checkpoints' - use_cache_dataset: bool = False - data_cache_dir: str = os.path.expanduser('~/cache') - work_cache_dir: str = os.path.expanduser('~/mycache') - # to be overridden - name: str = '' - - def __post_init__(self): - self.batch_size_eval = self.batch_size_eval or self.batch_size - self.data_val_name = self.data_val_name or self.data_name - - def scale_up_gpus(self, num_gpus, num_nodes=1): - self.eval_ema_every_samples *= num_gpus * num_nodes - self.eval_every_samples *= num_gpus * num_nodes - self.sample_every_samples *= num_gpus * num_nodes - self.batch_size *= num_gpus * num_nodes - self.batch_size_eval *= num_gpus * num_nodes - return self - - @property - def batch_size_effective(self): - return self.batch_size * self.accum_batches - - @property - def fid_cache(self): - # we try to use the local dirs to reduce the load over network drives - # hopefully, this would reduce the disconnection problems with sshfs - return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}' - - @property - def data_path(self): - # may use the cache dir - path = data_paths[self.data_name] - if self.use_cache_dataset and path is not None: - path = use_cached_dataset_path( - path, f'{self.data_cache_dir}/{self.data_name}') - return path - - @property - def logdir(self): - return f'{self.base_dir}/{self.name}' - - @property - def generate_dir(self): - # we try to use the local dirs to reduce the load over network drives - # hopefully, this would reduce the disconnection problems with sshfs - return f'{self.work_cache_dir}/gen_images/{self.name}' - - def _make_diffusion_conf(self, T=None): - if self.diffusion_type == 'beatgans': - # can use T < self.T for evaluation - # follows the guided-diffusion repo conventions - # t's are evenly spaced - if self.beatgans_gen_type == GenerativeType.ddpm: - section_counts = [T] - elif self.beatgans_gen_type == GenerativeType.ddim: - section_counts = f'ddim{T}' - else: - raise NotImplementedError() - - return SpacedDiffusionBeatGansConfig( - gen_type=self.beatgans_gen_type, - model_type=self.model_type, - betas=get_named_beta_schedule(self.beta_scheduler, self.T), - model_mean_type=self.beatgans_model_mean_type, - model_var_type=self.beatgans_model_var_type, - loss_type=self.beatgans_loss_type, - rescale_timesteps=self.beatgans_rescale_timesteps, - use_timesteps=space_timesteps(num_timesteps=self.T, - section_counts=section_counts), - fp16=self.fp16, - ) - else: - raise NotImplementedError() - - def _make_latent_diffusion_conf(self, T=None): - # can use T < self.T for evaluation - # follows the guided-diffusion repo conventions - # t's are evenly spaced - if self.latent_gen_type == GenerativeType.ddpm: - section_counts = [T] - elif self.latent_gen_type == GenerativeType.ddim: - section_counts = f'ddim{T}' - else: - raise NotImplementedError() - - return SpacedDiffusionBeatGansConfig( - train_pred_xstart_detach=self.train_pred_xstart_detach, - gen_type=self.latent_gen_type, - # latent's model is always ddpm - model_type=ModelType.ddpm, - # latent shares the beta scheduler and full T - betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T), - model_mean_type=self.latent_model_mean_type, - model_var_type=self.latent_model_var_type, - loss_type=self.latent_loss_type, - rescale_timesteps=self.latent_rescale_timesteps, - use_timesteps=space_timesteps(num_timesteps=self.T, - 
section_counts=section_counts), - fp16=self.fp16, - ) - - @property - def model_out_channels(self): - return 3 - - def make_T_sampler(self): - if self.T_sampler == 'uniform': - return UniformSampler(self.T) - else: - raise NotImplementedError() - - def make_diffusion_conf(self): - return self._make_diffusion_conf(self.T) - - def make_eval_diffusion_conf(self): - return self._make_diffusion_conf(T=self.T_eval) - - def make_latent_diffusion_conf(self): - return self._make_latent_diffusion_conf(T=self.T) - - def make_latent_eval_diffusion_conf(self): - # latent can have different eval T - return self._make_latent_diffusion_conf(T=self.latent_T_eval) - - def make_dataset(self, path=None, **kwargs): - if self.data_name == 'ffhqlmdb256': - return FFHQlmdb(path=path or self.data_path, - image_size=self.img_size, - **kwargs) - elif self.data_name == 'horse256': - return Horse_lmdb(path=path or self.data_path, - image_size=self.img_size, - **kwargs) - elif self.data_name == 'bedroom256': - return Horse_lmdb(path=path or self.data_path, - image_size=self.img_size, - **kwargs) - elif self.data_name == 'celebalmdb': - # always use d2c crop - return CelebAlmdb(path=path or self.data_path, - image_size=self.img_size, - original_resolution=None, - crop_d2c=True, - **kwargs) - else: - raise NotImplementedError() - - def make_loader(self, - dataset, - shuffle: bool, - num_worker: bool = None, - drop_last: bool = True, - batch_size: int = None, - parallel: bool = False): - if parallel and distributed.is_initialized(): - # drop last to make sure that there is no added special indexes - sampler = DistributedSampler(dataset, - shuffle=shuffle, - drop_last=True) - else: - sampler = None - return DataLoader( - dataset, - batch_size=batch_size or self.batch_size, - sampler=sampler, - # with sampler, use the sample instead of this option - shuffle=False if sampler else shuffle, - num_workers=num_worker or self.num_workers, - pin_memory=True, - drop_last=drop_last, - multiprocessing_context=get_context('fork'), - ) - - def make_model_conf(self): - if self.model_name == ModelName.beatgans_ddpm: - self.model_type = ModelType.ddpm - self.model_conf = BeatGANsUNetConfig( - attention_resolutions=self.net_attn, - channel_mult=self.net_ch_mult, - conv_resample=True, - dims=2, - dropout=self.dropout, - embed_channels=self.net_beatgans_embed_channels, - image_size=self.img_size, - in_channels=3, - model_channels=self.net_ch, - num_classes=None, - num_head_channels=-1, - num_heads_upsample=-1, - num_heads=self.net_beatgans_attn_head, - num_res_blocks=self.net_num_res_blocks, - num_input_res_blocks=self.net_num_input_res_blocks, - out_channels=self.model_out_channels, - resblock_updown=self.net_resblock_updown, - use_checkpoint=self.net_beatgans_gradient_checkpoint, - use_new_attention_order=False, - resnet_two_cond=self.net_beatgans_resnet_two_cond, - resnet_use_zero_module=self. 
- net_beatgans_resnet_use_zero_module, - ) - elif self.model_name in [ - ModelName.beatgans_autoenc, - ]: - cls = BeatGANsAutoencConfig - # supports both autoenc and vaeddpm - if self.model_name == ModelName.beatgans_autoenc: - self.model_type = ModelType.autoencoder - else: - raise NotImplementedError() - - if self.net_latent_net_type == LatentNetType.none: - latent_net_conf = None - elif self.net_latent_net_type == LatentNetType.skip: - latent_net_conf = MLPSkipNetConfig( - num_channels=self.style_ch, - skip_layers=self.net_latent_skip_layers, - num_hid_channels=self.net_latent_num_hid_channels, - num_layers=self.net_latent_layers, - num_time_emb_channels=self.net_latent_time_emb_channels, - activation=self.net_latent_activation, - use_norm=self.net_latent_use_norm, - condition_bias=self.net_latent_condition_bias, - dropout=self.net_latent_dropout, - last_act=self.net_latent_net_last_act, - num_time_layers=self.net_latent_num_time_layers, - time_last_act=self.net_latent_time_last_act, - ) - else: - raise NotImplementedError() - - self.model_conf = cls( - attention_resolutions=self.net_attn, - channel_mult=self.net_ch_mult, - conv_resample=True, - dims=2, - dropout=self.dropout, - embed_channels=self.net_beatgans_embed_channels, - enc_out_channels=self.style_ch, - enc_pool=self.net_enc_pool, - enc_num_res_block=self.net_enc_num_res_blocks, - enc_channel_mult=self.net_enc_channel_mult, - enc_grad_checkpoint=self.net_enc_grad_checkpoint, - enc_attn_resolutions=self.net_enc_attn, - image_size=self.img_size, - in_channels=3, - model_channels=self.net_ch, - num_classes=None, - num_head_channels=-1, - num_heads_upsample=-1, - num_heads=self.net_beatgans_attn_head, - num_res_blocks=self.net_num_res_blocks, - num_input_res_blocks=self.net_num_input_res_blocks, - out_channels=self.model_out_channels, - resblock_updown=self.net_resblock_updown, - use_checkpoint=self.net_beatgans_gradient_checkpoint, - use_new_attention_order=False, - resnet_two_cond=self.net_beatgans_resnet_two_cond, - resnet_use_zero_module=self. 
- net_beatgans_resnet_use_zero_module, - latent_net_conf=latent_net_conf, - resnet_cond_channels=self.net_beatgans_resnet_cond_channels, - ) - else: - raise NotImplementedError(self.model_name) - - return self.model_conf diff --git a/source/config_base.py b/source/config_base.py deleted file mode 100644 index 385f9ee..0000000 --- a/source/config_base.py +++ /dev/null @@ -1,72 +0,0 @@ -import json -import os -from copy import deepcopy -from dataclasses import dataclass - - -@dataclass -class BaseConfig: - def clone(self): - return deepcopy(self) - - def inherit(self, another): - """inherit common keys from a given config""" - common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) - for k in common_keys: - setattr(self, k, getattr(another, k)) - - def propagate(self): - """push down the configuration to all members""" - for k, v in self.__dict__.items(): - if isinstance(v, BaseConfig): - v.inherit(self) - v.propagate() - - def save(self, save_path): - """save config to json file""" - dirname = os.path.dirname(save_path) - if not os.path.exists(dirname): - os.makedirs(dirname) - conf = self.as_dict_jsonable() - with open(save_path, 'w') as f: - json.dump(conf, f) - - def load(self, load_path): - """load json config""" - with open(load_path) as f: - conf = json.load(f) - self.from_dict(conf) - - def from_dict(self, dict, strict=False): - for k, v in dict.items(): - if not hasattr(self, k): - if strict: - raise ValueError(f"loading extra '{k}'") - else: - print(f"loading extra '{k}'") - continue - if isinstance(self.__dict__[k], BaseConfig): - self.__dict__[k].from_dict(v) - else: - self.__dict__[k] = v - - def as_dict_jsonable(self): - conf = {} - for k, v in self.__dict__.items(): - if isinstance(v, BaseConfig): - conf[k] = v.as_dict_jsonable() - else: - if jsonable(v): - conf[k] = v - else: - # ignore not jsonable - pass - return conf - - -def jsonable(x): - try: - json.dumps(x) - return True - except TypeError: - return False diff --git a/source/data_resize_bedroom.py b/source/data_resize_bedroom.py deleted file mode 100644 index 32a7351..0000000 --- a/source/data_resize_bedroom.py +++ /dev/null @@ -1,101 +0,0 @@ -import argparse -import multiprocessing -import os -from os.path import join, exists -from functools import partial -from io import BytesIO -import shutil - -import lmdb -from PIL import Image -from torchvision.datasets import LSUNClass -from torchvision.transforms import functional as trans_fn -from tqdm import tqdm - -from multiprocessing import Process, Queue - - -def resize_and_convert(img, size, resample, quality=100): - img = trans_fn.resize(img, size, resample) - img = trans_fn.center_crop(img, size) - buffer = BytesIO() - img.save(buffer, format="webp", quality=quality) - val = buffer.getvalue() - - return val - - -def resize_multiple(img, - sizes=(128, 256, 512, 1024), - resample=Image.LANCZOS, - quality=100): - imgs = [] - - for size in sizes: - imgs.append(resize_and_convert(img, size, resample, quality)) - - return imgs - - -def resize_worker(idx, img, sizes, resample): - img = img.convert("RGB") - out = resize_multiple(img, sizes=sizes, resample=resample) - return idx, out - - -from torch.utils.data import Dataset, DataLoader - - -class ConvertDataset(Dataset): - def __init__(self, data) -> None: - self.data = data - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - img, _ = self.data[index] - bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) - return bytes - - -if __name__ == "__main__": - """ - converting 
lsun' original lmdb to our lmdb, which is somehow more performant. - """ - from tqdm import tqdm - - # path to the original lsun's lmdb - src_path = 'datasets/bedroom_train_lmdb' - out_path = 'datasets/bedroom256.lmdb' - - dataset = LSUNClass(root=os.path.expanduser(src_path)) - dataset = ConvertDataset(dataset) - loader = DataLoader(dataset, - batch_size=50, - num_workers=12, - collate_fn=collate_fn, - shuffle=False) - - target = os.path.expanduser(out_path) - if os.path.exists(target): - shutil.rmtree(target) - - with lmdb.open(target, map_size=1024**4, readahead=False) as env: - with tqdm(total=len(dataset)) as progress: - i = 0 - for batch in loader: - with env.begin(write=True) as txn: - for img in batch: - key = f"{256}-{str(i).zfill(7)}".encode("utf-8") - # print(key) - txn.put(key, img) - i += 1 - progress.update() - # if i == 1000: - # break - # if total == len(imgset): - # break - - with env.begin(write=True) as txn: - txn.put("length".encode("utf-8"), str(i).encode("utf-8")) diff --git a/source/data_resize_celeba.py b/source/data_resize_celeba.py deleted file mode 100644 index 02891cc..0000000 --- a/source/data_resize_celeba.py +++ /dev/null @@ -1,120 +0,0 @@ -import argparse -import multiprocessing -import os -import shutil -from functools import partial -from io import BytesIO -from multiprocessing import Process, Queue -from os.path import exists, join -from pathlib import Path - -import lmdb -from PIL import Image -from torch.utils.data import DataLoader, Dataset -from torchvision.datasets import LSUNClass -from torchvision.transforms import functional as trans_fn -from tqdm import tqdm - - -def resize_and_convert(img, size, resample, quality=100): - if size is not None: - img = trans_fn.resize(img, size, resample) - img = trans_fn.center_crop(img, size) - - buffer = BytesIO() - img.save(buffer, format="webp", quality=quality) - val = buffer.getvalue() - - return val - - -# Define a top-level collate function instead of using a lambda to avoid pickling issues -def collate_fn(batch): - return batch - -def resize_multiple(img, - sizes=(128, 256, 512, 1024), - resample=Image.LANCZOS, - quality=100): - imgs = [] - - for size in sizes: - imgs.append(resize_and_convert(img, size, resample, quality)) - - return imgs - - -def resize_worker(idx, img, sizes, resample): - img = img.convert("RGB") - out = resize_multiple(img, sizes=sizes, resample=resample) - return idx, out - - -class ConvertDataset(Dataset): - def __init__(self, data, size) -> None: - self.data = data - self.size = size - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - img = self.data[index] - bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100) - return bytes - - -class ImageFolder(Dataset): - def __init__(self, folder, ext='jpg'): - super().__init__() - paths = sorted([p for p in Path(f'{folder}').glob(f'*.{ext}')]) - self.paths = paths - - def __len__(self): - return len(self.paths) - - def __getitem__(self, index): - path = os.path.join(self.paths[index]) - img = Image.open(path) - return img - - -if __name__ == "__main__": - from tqdm import tqdm - - out_path = 'datasets/celeba.lmdb' - in_path = 'datasets/celeba' - ext = 'jpg' - size = None - - dataset = ImageFolder(in_path, ext) - print('len:', len(dataset)) - dataset = ConvertDataset(dataset, size) - loader = DataLoader(dataset, - batch_size=50, - num_workers=12, - collate_fn=collate_fn, - shuffle=False) - - target = os.path.expanduser(out_path) - if os.path.exists(target): - shutil.rmtree(target) - - with 
lmdb.open(target, map_size=1024**4, readahead=False) as env: - with tqdm(total=len(dataset)) as progress: - i = 0 - for batch in loader: - with env.begin(write=True) as txn: - for img in batch: - key = f"{size}-{str(i).zfill(7)}".encode("utf-8") - # print(key) - txn.put(key, img) - i += 1 - progress.update() - # if i == 1000: - # break - # if total == len(imgset): - # break - - with env.begin(write=True) as txn: - txn.put("length".encode("utf-8"), str(i).encode("utf-8")) diff --git a/source/data_resize_celebahq.py b/source/data_resize_celebahq.py deleted file mode 100644 index c6a2d7f..0000000 --- a/source/data_resize_celebahq.py +++ /dev/null @@ -1,120 +0,0 @@ -import argparse -import multiprocessing -from functools import partial -from io import BytesIO -from pathlib import Path - -import lmdb -from PIL import Image -from torch.utils.data import Dataset -from torchvision.transforms import functional as trans_fn -from tqdm import tqdm -import os - - -def resize_and_convert(img, size, resample, quality=100): - img = trans_fn.resize(img, size, resample) - img = trans_fn.center_crop(img, size) - buffer = BytesIO() - img.save(buffer, format="jpeg", quality=quality) - val = buffer.getvalue() - - return val - -# Define a top-level collate function instead of using a lambda to avoid pickling issues -def collate_fn(batch): - return batch - -def resize_multiple(img, - sizes=(128, 256, 512, 1024), - resample=Image.LANCZOS, - quality=100): - imgs = [] - - for size in sizes: - imgs.append(resize_and_convert(img, size, resample, quality)) - - return imgs - - -def resize_worker(img_file, sizes, resample): - i, (file, idx) = img_file - img = Image.open(file) - img = img.convert("RGB") - out = resize_multiple(img, sizes=sizes, resample=resample) - - return i, idx, out - - -def prepare(env, - paths, - n_worker, - sizes=(128, 256, 512, 1024), - resample=Image.LANCZOS): - resize_fn = partial(resize_worker, sizes=sizes, resample=resample) - - # index = filename in int - indexs = [] - for each in paths: - file = os.path.basename(each) - name, ext = file.split('.') - idx = int(name) - indexs.append(idx) - - # sort by file index - files = sorted(zip(paths, indexs), key=lambda x: x[1]) - files = list(enumerate(files)) - total = 0 - - with multiprocessing.Pool(n_worker) as pool: - for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)): - for size, img in zip(sizes, imgs): - key = f"{size}-{str(idx).zfill(5)}".encode("utf-8") - - with env.begin(write=True) as txn: - txn.put(key, img) - - total += 1 - - with env.begin(write=True) as txn: - txn.put("length".encode("utf-8"), str(total).encode("utf-8")) - - -class ImageFolder(Dataset): - def __init__(self, folder, exts=['jpg']): - super().__init__() - self.paths = [ - p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') - ] - - def __len__(self): - return len(self.paths) - - def __getitem__(self, index): - path = os.path.join(self.folder, self.paths[index]) - img = Image.open(path) - return img - - -if __name__ == "__main__": - """ - converting celebahq images to lmdb - """ - num_workers = 16 - in_path = 'datasets/celebahq' - out_path = 'datasets/celebahq256.lmdb' - - resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} - resample = resample_map['lanczos'] - - sizes = [256] - - print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) - - # imgset = datasets.ImageFolder(in_path) - # imgset = ImageFolder(in_path) - exts = ['jpg'] - paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')] - - with 
-
-    with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
-        prepare(env, paths, num_workers, sizes=sizes, resample=resample)
diff --git a/source/data_resize_ffhq.py b/source/data_resize_ffhq.py
deleted file mode 100644
index 226c8ec..0000000
--- a/source/data_resize_ffhq.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import argparse
-import multiprocessing
-from functools import partial
-from io import BytesIO
-from pathlib import Path
-
-import lmdb
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision.transforms import functional as trans_fn
-from tqdm import tqdm
-import os
-
-
-def resize_and_convert(img, size, resample, quality=100):
-    img = trans_fn.resize(img, size, resample)
-    img = trans_fn.center_crop(img, size)
-    buffer = BytesIO()
-    img.save(buffer, format="jpeg", quality=quality)
-    val = buffer.getvalue()
-
-    return val
-
-
-def resize_multiple(img,
-                    sizes=(128, 256, 512, 1024),
-                    resample=Image.LANCZOS,
-                    quality=100):
-    imgs = []
-
-    for size in sizes:
-        imgs.append(resize_and_convert(img, size, resample, quality))
-
-    return imgs
-
-
-def resize_worker(img_file, sizes, resample):
-    i, (file, idx) = img_file
-    img = Image.open(file)
-    img = img.convert("RGB")
-    out = resize_multiple(img, sizes=sizes, resample=resample)
-
-    return i, idx, out
-
-
-def prepare(env,
-            paths,
-            n_worker,
-            sizes=(128, 256, 512, 1024),
-            resample=Image.LANCZOS):
-    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
-
-    # index = the integer part of the filename
-    indices = []
-    for each in paths:
-        file = os.path.basename(each)
-        name, ext = file.split('.')
-        idx = int(name)
-        indices.append(idx)
-
-    # sort by file index
-    files = sorted(zip(paths, indices), key=lambda x: x[1])
-    files = list(enumerate(files))
-    total = 0
-
-    with multiprocessing.Pool(n_worker) as pool:
-        for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
-            for size, img in zip(sizes, imgs):
-                key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
-
-                with env.begin(write=True) as txn:
-                    txn.put(key, img)
-
-            total += 1
-
-    with env.begin(write=True) as txn:
-        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
-
-
-class ImageFolder(Dataset):
-    def __init__(self, folder, exts=['jpg']):
-        super().__init__()
-        self.paths = [
-            p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
-        ]
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, index):
-        # the glob above already returns paths rooted at the folder
-        img = Image.open(self.paths[index])
-        return img
-
-
-if __name__ == "__main__":
-    """
-    Convert FFHQ images to lmdb.
-    """
-    num_workers = 16
-    # original ffhq data path
-    in_path = 'datasets/ffhq'
-    # target output path
-    out_path = 'datasets/ffhq.lmdb'
-
-    if not os.path.exists(out_path):
-        os.makedirs(out_path)
-
-    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
-    resample = resample_map['lanczos']
-
-    sizes = [256]
-
-    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
-
-    exts = ['jpg']
-    paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')]
-
-    with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
-        prepare(env, paths, num_workers, sizes=sizes, resample=resample)
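
prepare() assumes FFHQ/CelebA-HQ style integer file names; anything non-numeric raises ValueError in int(name). A quick sketch of how a file name becomes an lmdb key (values illustrative):

    file = '00012.jpg'                        # hypothetical file name
    name, ext = file.split('.')
    idx = int(name)                           # 12
    key = f"{256}-{str(idx).zfill(5)}"
    print(key)                                # 256-00012
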
diff --git a/source/data_resize_horse.py b/source/data_resize_horse.py
deleted file mode 100644
index 6893613..0000000
--- a/source/data_resize_horse.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import argparse
-import multiprocessing
-import os
-import shutil
-from functools import partial
-from io import BytesIO
-from multiprocessing import Process, Queue
-from os.path import exists, join
-
-import lmdb
-from PIL import Image
-from torch.utils.data import DataLoader, Dataset
-from torchvision.datasets import LSUNClass
-from torchvision.transforms import functional as trans_fn
-from tqdm import tqdm
-
-
-def resize_and_convert(img, size, resample, quality=100):
-    img = trans_fn.resize(img, size, resample)
-    img = trans_fn.center_crop(img, size)
-    buffer = BytesIO()
-    img.save(buffer, format="webp", quality=quality)
-    val = buffer.getvalue()
-
-    return val
-
-
-def resize_multiple(img,
-                    sizes=(128, 256, 512, 1024),
-                    resample=Image.LANCZOS,
-                    quality=100):
-    imgs = []
-
-    for size in sizes:
-        imgs.append(resize_and_convert(img, size, resample, quality))
-
-    return imgs
-
-
-def resize_worker(idx, img, sizes, resample):
-    img = img.convert("RGB")
-    out = resize_multiple(img, sizes=sizes, resample=resample)
-    return idx, out
-
-
-class ConvertDataset(Dataset):
-    def __init__(self, data) -> None:
-        self.data = data
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        img, _ = self.data[index]
-        img_bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
-        return img_bytes
-
-
-# Define a top-level collate function instead of using a lambda to avoid pickling issues
-def collate_fn(batch):
-    return batch
-
-
-if __name__ == "__main__":
-    """
-    Convert LSUN's original lmdb into our lmdb format, which is more performant.
-    """
-    from tqdm import tqdm
-
-    # path to the original LSUN lmdb
-    src_path = 'datasets/horse_train_lmdb'
-    out_path = 'datasets/horse256.lmdb'
-
-    dataset = LSUNClass(root=os.path.expanduser(src_path))
-    dataset = ConvertDataset(dataset)
-    loader = DataLoader(dataset,
-                        batch_size=50,
-                        num_workers=16,
-                        collate_fn=collate_fn)
-
-    target = os.path.expanduser(out_path)
-    if os.path.exists(target):
-        shutil.rmtree(target)
-
-    with lmdb.open(target, map_size=1024**4, readahead=False) as env:
-        with tqdm(total=len(dataset)) as progress:
-            i = 0
-            for batch in loader:
-                with env.begin(write=True) as txn:
-                    for img in batch:
-                        key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
-                        txn.put(key, img)
-                        i += 1
-                        progress.update()
-
-        with env.begin(write=True) as txn:
-            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
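
The top-level collate_fn comment above is worth spelling out: DataLoader workers receive the collate function via pickling, and a module-level def survives that while a lambda does not. A self-contained check:

    import pickle

    def collate_fn(batch):
        return batch          # identity collate: keep the batch as a list of bytes

    pickle.dumps(collate_fn)  # fine
    try:
        pickle.dumps(lambda batch: batch)
    except (pickle.PicklingError, AttributeError) as e:
        print('lambda is not picklable:', e)
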
diff --git a/source/dataset.py b/source/dataset.py
deleted file mode 100644
index 627671f..0000000
--- a/source/dataset.py
+++ /dev/null
@@ -1,716 +0,0 @@
-import os
-from io import BytesIO
-from pathlib import Path
-
-import lmdb
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision import transforms
-from torchvision.datasets import CIFAR10, LSUNClass
-import torch
-import pandas as pd
-
-import torchvision.transforms.functional as Ftrans
-
-
-class ImageDataset(Dataset):
-    def __init__(
-        self,
-        folder,
-        image_size,
-        exts=['jpg'],
-        do_augment: bool = True,
-        do_transform: bool = True,
-        do_normalize: bool = True,
-        sort_names=False,
-        has_subdir: bool = True,
-    ):
-        super().__init__()
-        self.folder = folder
-        self.image_size = image_size
-
-        # store relative paths (shorter strings save memory and sort faster)
-        if has_subdir:
-            self.paths = [
-                p.relative_to(folder) for ext in exts
-                for p in Path(f'{folder}').glob(f'**/*.{ext}')
-            ]
-        else:
-            self.paths = [
-                p.relative_to(folder) for ext in exts
-                for p in Path(f'{folder}').glob(f'*.{ext}')
-            ]
-        if sort_names:
-            self.paths = sorted(self.paths)
-
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, index):
-        path = os.path.join(self.folder, self.paths[index])
-        img = Image.open(path)
-        # convert to RGB in case the image has an alpha channel
-        img = img.convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-        return {'img': img, 'index': index}
-
-
-class SubsetDataset(Dataset):
-    def __init__(self, dataset, size):
-        assert len(dataset) >= size
-        self.dataset = dataset
-        self.size = size
-
-    def __len__(self):
-        return self.size
-
-    def __getitem__(self, index):
-        assert index < self.size
-        return self.dataset[index]
-
-
-class BaseLMDB(Dataset):
-    def __init__(self, path, original_resolution, zfill: int = 5):
-        self.original_resolution = original_resolution
-        self.zfill = zfill
-        self.env = lmdb.open(
-            path,
-            max_readers=32,
-            readonly=True,
-            lock=False,
-            readahead=False,
-            meminit=False,
-        )
-
-        if not self.env:
-            raise IOError('Cannot open lmdb dataset', path)
-
-        with self.env.begin(write=False) as txn:
-            self.length = int(
-                txn.get('length'.encode('utf-8')).decode('utf-8'))
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        with self.env.begin(write=False) as txn:
-            key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
-                'utf-8')
-            img_bytes = txn.get(key)
-
-        buffer = BytesIO(img_bytes)
-        img = Image.open(buffer)
-        return img
-
-
-def make_transform(
-    image_size,
-    flip_prob=0.5,
-    crop_d2c=False,
-):
-    if crop_d2c:
-        transform = [
-            d2c_crop(),
-            transforms.Resize(image_size),
-        ]
-    else:
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-    transform.append(transforms.RandomHorizontalFlip(p=flip_prob))
-    transform.append(transforms.ToTensor())
-    transform.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-    transform = transforms.Compose(transform)
-    return transform
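
A minimal sketch of how BaseLMDB and make_transform compose, assuming datasets/ffhq256.lmdb was built by the resize scripts above:

    data = BaseLMDB('datasets/ffhq256.lmdb', original_resolution=256, zfill=5)
    transform = make_transform(image_size=128, flip_prob=0.5)
    img = transform(data[0])   # PIL image -> normalized tensor in [-1, 1]
    print(img.shape)           # torch.Size([3, 128, 128])
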
-class FFHQlmdb(Dataset):
-    def __init__(self,
-                 path=os.path.expanduser('datasets/ffhq256.lmdb'),
-                 image_size=256,
-                 original_resolution=256,
-                 split=None,
-                 as_tensor: bool = True,
-                 do_augment: bool = True,
-                 do_normalize: bool = True,
-                 **kwargs):
-        self.original_resolution = original_resolution
-        self.data = BaseLMDB(path, original_resolution, zfill=5)
-        self.length = len(self.data)
-
-        if split is None:
-            self.offset = 0
-        elif split == 'train':
-            # last 60k
-            self.length = self.length - 10000
-            self.offset = 10000
-        elif split == 'test':
-            # first 10k
-            self.length = 10000
-            self.offset = 0
-        else:
-            raise NotImplementedError()
-
-        transform = [
-            transforms.Resize(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if as_tensor:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        assert index < self.length
-        index = index + self.offset
-        img = self.data[index]
-        if self.transform is not None:
-            img = self.transform(img)
-        return {'img': img, 'index': index}
-
-
-class Crop:
-    def __init__(self, x1, x2, y1, y2):
-        self.x1 = x1
-        self.x2 = x2
-        self.y1 = y1
-        self.y2 = y2
-
-    def __call__(self, img):
-        return Ftrans.crop(img, self.x1, self.y1, self.x2 - self.x1,
-                           self.y2 - self.y1)
-
-    def __repr__(self):
-        return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
-            self.x1, self.x2, self.y1, self.y2)
-
-
-def d2c_crop():
-    # from the D2C paper, for the CelebA dataset
-    cx = 89
-    cy = 121
-    x1 = cy - 64
-    x2 = cy + 64
-    y1 = cx - 64
-    y2 = cx + 64
-    return Crop(x1, x2, y1, y2)
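
The d2c crop is a fixed 128x128 window on the 178x218 aligned CelebA images, centered at (cx, cy) = (89, 121). A quick check via Crop's __repr__:

    crop = d2c_crop()
    print(crop)   # Crop(x1=57, x2=185, y1=25, y2=153)
    # i.e. rows 57..185 and columns 25..153 of the aligned image
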
-class CelebAlmdb(Dataset):
-    """
-    Also supports the d2c crop.
-    """
-    def __init__(self,
-                 path,
-                 image_size,
-                 original_resolution=128,
-                 split=None,
-                 as_tensor: bool = True,
-                 do_augment: bool = True,
-                 do_normalize: bool = True,
-                 crop_d2c: bool = False,
-                 **kwargs):
-        self.original_resolution = original_resolution
-        self.data = BaseLMDB(path, original_resolution, zfill=7)
-        self.length = len(self.data)
-        self.crop_d2c = crop_d2c
-
-        if split is None:
-            self.offset = 0
-        else:
-            raise NotImplementedError()
-
-        if crop_d2c:
-            transform = [
-                d2c_crop(),
-                transforms.Resize(image_size),
-            ]
-        else:
-            transform = [
-                transforms.Resize(image_size),
-                transforms.CenterCrop(image_size),
-            ]
-
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if as_tensor:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        assert index < self.length
-        index = index + self.offset
-        img = self.data[index]
-        if self.transform is not None:
-            img = self.transform(img)
-        return {'img': img, 'index': index}
-
-
-class Horse_lmdb(Dataset):
-    def __init__(self,
-                 path=os.path.expanduser('datasets/horse256.lmdb'),
-                 image_size=128,
-                 original_resolution=256,
-                 do_augment: bool = True,
-                 do_transform: bool = True,
-                 do_normalize: bool = True,
-                 **kwargs):
-        self.original_resolution = original_resolution
-        print(path)
-        self.data = BaseLMDB(path, original_resolution, zfill=7)
-        self.length = len(self.data)
-
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        img = self.data[index]
-        if self.transform is not None:
-            img = self.transform(img)
-        return {'img': img, 'index': index}
-
-
-class Bedroom_lmdb(Dataset):
-    def __init__(self,
-                 path=os.path.expanduser('datasets/bedroom256.lmdb'),
-                 image_size=128,
-                 original_resolution=256,
-                 do_augment: bool = True,
-                 do_transform: bool = True,
-                 do_normalize: bool = True,
-                 **kwargs):
-        self.original_resolution = original_resolution
-        print(path)
-        self.data = BaseLMDB(path, original_resolution, zfill=7)
-        self.length = len(self.data)
-
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, index):
-        img = self.data[index]
-        img = self.transform(img)
-        return {'img': img, 'index': index}
- """ - def __init__(self, - folder, - image_size=64, - attr_path=os.path.expanduser( - 'datasets/celeba_anno/list_attr_celeba.txt'), - ext='jpg', - only_cls_name: str = None, - only_cls_value: int = None, - do_augment: bool = False, - do_transform: bool = True, - do_normalize: bool = True, - d2c: bool = True): - super().__init__(folder, - image_size, - attr_path, - ext=ext, - only_cls_name=only_cls_name, - only_cls_value=only_cls_value, - do_augment=do_augment, - do_transform=do_transform, - do_normalize=do_normalize, - d2c=d2c) - - -class CelebAttrFewshotDataset(Dataset): - def __init__( - self, - cls_name, - K, - img_folder, - img_size=64, - ext='png', - seed=0, - only_cls_name: str = None, - only_cls_value: int = None, - all_neg: bool = False, - do_augment: bool = False, - do_transform: bool = True, - do_normalize: bool = True, - d2c: bool = False, - ) -> None: - self.cls_name = cls_name - self.K = K - self.img_folder = img_folder - self.ext = ext - - if all_neg: - path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv' - else: - path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv' - self.df = pd.read_csv(path, index_col=0) - if only_cls_name is not None: - self.df = self.df[self.df[only_cls_name] == only_cls_value] - - if d2c: - transform = [ - d2c_crop(), - transforms.Resize(img_size), - ] - else: - transform = [ - transforms.Resize(img_size), - transforms.CenterCrop(img_size), - ] - if do_augment: - transform.append(transforms.RandomHorizontalFlip()) - if do_transform: - transform.append(transforms.ToTensor()) - if do_normalize: - transform.append( - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) - self.transform = transforms.Compose(transform) - - def pos_count(self, cls_name): - return (self.df[cls_name] == 1).sum() - - def neg_count(self, cls_name): - return (self.df[cls_name] == -1).sum() - - def __len__(self): - return len(self.df) - - def __getitem__(self, index): - row = self.df.iloc[index] - name = row.name.split('.')[0] - name = f'{name}.{self.ext}' - - path = os.path.join(self.img_folder, name) - img = Image.open(path) - - # (1, 1) - label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1) - - if self.transform is not None: - img = self.transform(img) - - return {'img': img, 'index': index, 'labels': label} - - -class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset): - def __init__(self, - cls_name, - K, - img_folder, - img_size=64, - ext='jpg', - seed=0, - only_cls_name: str = None, - only_cls_value: int = None, - all_neg: bool = False, - do_augment: bool = False, - do_transform: bool = True, - do_normalize: bool = True, - is_negative=False, - d2c: bool = True) -> None: - super().__init__(cls_name, - K, - img_folder, - img_size, - ext=ext, - seed=seed, - only_cls_name=only_cls_name, - only_cls_value=only_cls_value, - all_neg=all_neg, - do_augment=do_augment, - do_transform=do_transform, - do_normalize=do_normalize, - d2c=d2c) - self.is_negative = is_negative - - -class CelebHQAttrDataset(Dataset): - id_to_cls = [ - '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', - 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', - 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', - 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', - 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', - 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', - 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', - 'Wearing_Earrings', 'Wearing_Hat', 
-class CelebAttrFewshotDataset(Dataset):
-    def __init__(
-        self,
-        cls_name,
-        K,
-        img_folder,
-        img_size=64,
-        ext='png',
-        seed=0,
-        only_cls_name: str = None,
-        only_cls_value: int = None,
-        all_neg: bool = False,
-        do_augment: bool = False,
-        do_transform: bool = True,
-        do_normalize: bool = True,
-        d2c: bool = False,
-    ) -> None:
-        self.cls_name = cls_name
-        self.K = K
-        self.img_folder = img_folder
-        self.ext = ext
-
-        if all_neg:
-            path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
-        else:
-            path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv'
-        self.df = pd.read_csv(path, index_col=0)
-        if only_cls_name is not None:
-            self.df = self.df[self.df[only_cls_name] == only_cls_value]
-
-        if d2c:
-            transform = [
-                d2c_crop(),
-                transforms.Resize(img_size),
-            ]
-        else:
-            transform = [
-                transforms.Resize(img_size),
-                transforms.CenterCrop(img_size),
-            ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-    def pos_count(self, cls_name):
-        return (self.df[cls_name] == 1).sum()
-
-    def neg_count(self, cls_name):
-        return (self.df[cls_name] == -1).sum()
-
-    def __len__(self):
-        return len(self.df)
-
-    def __getitem__(self, index):
-        row = self.df.iloc[index]
-        name = row.name.split('.')[0]
-        name = f'{name}.{self.ext}'
-
-        path = os.path.join(self.img_folder, name)
-        img = Image.open(path)
-
-        # label shape: (1,)
-        label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
-
-        if self.transform is not None:
-            img = self.transform(img)
-
-        return {'img': img, 'index': index, 'labels': label}
-
-
-class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
-    def __init__(self,
-                 cls_name,
-                 K,
-                 img_folder,
-                 img_size=64,
-                 ext='jpg',
-                 seed=0,
-                 only_cls_name: str = None,
-                 only_cls_value: int = None,
-                 all_neg: bool = False,
-                 do_augment: bool = False,
-                 do_transform: bool = True,
-                 do_normalize: bool = True,
-                 is_negative=False,
-                 d2c: bool = True) -> None:
-        super().__init__(cls_name,
-                         K,
-                         img_folder,
-                         img_size,
-                         ext=ext,
-                         seed=seed,
-                         only_cls_name=only_cls_name,
-                         only_cls_value=only_cls_value,
-                         all_neg=all_neg,
-                         do_augment=do_augment,
-                         do_transform=do_transform,
-                         do_normalize=do_normalize,
-                         d2c=d2c)
-        self.is_negative = is_negative
-
-
-class CelebHQAttrDataset(Dataset):
-    id_to_cls = [
-        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
-        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
-        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
-        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
-        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
-        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
-        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
-        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
-        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
-    ]
-    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
-
-    def __init__(self,
-                 path=os.path.expanduser('datasets/celebahq256.lmdb'),
-                 image_size=None,
-                 attr_path=os.path.expanduser(
-                     'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
-                 original_resolution=256,
-                 do_augment: bool = False,
-                 do_transform: bool = True,
-                 do_normalize: bool = True):
-        super().__init__()
-        self.image_size = image_size
-        self.data = BaseLMDB(path, original_resolution, zfill=5)
-
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-        with open(attr_path) as f:
-            # discard the top line
-            f.readline()
-            self.df = pd.read_csv(f, delim_whitespace=True)
-
-    def pos_count(self, cls_name):
-        return (self.df[cls_name] == 1).sum()
-
-    def neg_count(self, cls_name):
-        return (self.df[cls_name] == -1).sum()
-
-    def __len__(self):
-        return len(self.df)
-
-    def __getitem__(self, index):
-        row = self.df.iloc[index]
-        img_name = row.name
-        img_idx, ext = img_name.split('.')
-        img = self.data[img_idx]
-
-        labels = [0] * len(self.id_to_cls)
-        for k, v in row.items():
-            labels[self.cls_to_id[k]] = int(v)
-
-        if self.transform is not None:
-            img = self.transform(img)
-        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
-
-
-class CelebHQAttrFewshotDataset(Dataset):
-    def __init__(self,
-                 cls_name,
-                 K,
-                 path,
-                 image_size,
-                 original_resolution=256,
-                 do_augment: bool = False,
-                 do_transform: bool = True,
-                 do_normalize: bool = True):
-        super().__init__()
-        self.image_size = image_size
-        self.cls_name = cls_name
-        self.K = K
-        self.data = BaseLMDB(path, original_resolution, zfill=5)
-
-        transform = [
-            transforms.Resize(image_size),
-            transforms.CenterCrop(image_size),
-        ]
-        if do_augment:
-            transform.append(transforms.RandomHorizontalFlip())
-        if do_transform:
-            transform.append(transforms.ToTensor())
-        if do_normalize:
-            transform.append(
-                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
-        self.transform = transforms.Compose(transform)
-
-        self.df = pd.read_csv(f'data/celebahq_fewshots/K{K}_{cls_name}.csv',
-                              index_col=0)
-
-    def pos_count(self, cls_name):
-        return (self.df[cls_name] == 1).sum()
-
-    def neg_count(self, cls_name):
-        return (self.df[cls_name] == -1).sum()
-
-    def __len__(self):
-        return len(self.df)
-
-    def __getitem__(self, index):
-        row = self.df.iloc[index]
-        img_name = row.name
-        img_idx, ext = img_name.split('.')
-        img = self.data[img_idx]
-
-        # label shape: (1,)
-        label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
-
-        if self.transform is not None:
-            img = self.transform(img)
-
-        return {'img': img, 'index': index, 'labels': label}
-
-
-class Repeat(Dataset):
-    def __init__(self, dataset, new_len) -> None:
-        super().__init__()
-        self.dataset = dataset
-        self.original_len = len(dataset)
-        self.new_len = new_len
-
-    def __len__(self):
-        return self.new_len
-
-    def __getitem__(self, index):
-        index = index % self.original_len
-        return self.dataset[index]
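
Repeat exists to stretch a small (e.g. few-shot) dataset to a fixed epoch length by wrapping indices modulo the original size. A sketch, assuming the few-shot CSV exists:

    fewshot = CelebAttrFewshotDataset('Smiling', K=10, img_folder='datasets/celeba')
    train_set = Repeat(fewshot, new_len=10_000)
    print(len(train_set))   # 10000; index i maps to fewshot[i % len(fewshot)]
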
diff --git a/source/dataset_util.py b/source/dataset_util.py
deleted file mode 100644
index d2075ac..0000000
--- a/source/dataset_util.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import shutil
-import os
-from dist_utils import *
-
-
-def use_cached_dataset_path(source_path, cache_path):
-    # only rank 0 copies; the other ranks wait at the barrier
-    if get_rank() == 0:
-        if not os.path.exists(cache_path):
-            print(f'copying the data: {source_path} to {cache_path}')
-            shutil.copytree(source_path, cache_path)
-    barrier()
-    return cache_path
\ No newline at end of file
diff --git a/fixed_imports.py b/source/fixed_imports.py
similarity index 100%
rename from fixed_imports.py
rename to source/fixed_imports.py
diff --git a/source/templates.py b/source/templates.py
deleted file mode 100644
index 63cb4de..0000000
--- a/source/templates.py
+++ /dev/null
@@ -1,323 +0,0 @@
-from experiment import *
-
-# PyTorch vs PyTorch Lightning Module
-
-# If BREAK_CIRCULAR_IMPORTS is set, stub out modules known to cause
-# circular-import problems by registering placeholder objects for them.
-if os.environ.get('BREAK_CIRCULAR_IMPORTS'):
-    import sys
-
-    sys.modules['torch.types'] = type('', (), {})
-    sys.modules['torch.utils._python_dispatch'] = type('', (), {})
-
-# Import IMPORTS_READY from fixed_imports if available; continue silently otherwise.
-try:
-    from fixed_imports import IMPORTS_READY
-except ImportError:
-    pass
-
-
-def ddpm():
-    """
-    base configuration for all DDIM-based models.
-    """
-    conf = TrainConfig()
-    conf.batch_size = 32
-    conf.beatgans_gen_type = GenerativeType.ddim
-    conf.beta_scheduler = 'linear'
-    conf.data_name = 'ffhq'
-    conf.diffusion_type = 'beatgans'
-    conf.eval_ema_every_samples = 200_000
-    conf.eval_every_samples = 200_000
-    conf.fp16 = True
-    conf.lr = 1e-4
-    conf.model_name = ModelName.beatgans_ddpm
-    conf.net_attn = (16, )
-    conf.net_beatgans_attn_head = 1
-    conf.net_beatgans_embed_channels = 512
-    conf.net_ch_mult = (1, 2, 4, 8)
-    conf.net_ch = 64
-    conf.sample_size = 32
-    conf.T_eval = 20
-    conf.T = 1000
-    conf.make_model_conf()
-    return conf
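
The sys.modules stub above works because the import machinery returns whatever object is already registered under a name, without executing anything. The same idea with a hypothetical module name:

    import sys

    sys.modules['fake_heavy_module'] = type('', (), {})
    import fake_heavy_module   # no ImportError; returns the stub object
    print(fake_heavy_module)
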
- """ - conf = TrainConfig() - conf.batch_size = 32 - conf.beatgans_gen_type = GenerativeType.ddim - conf.beta_scheduler = 'linear' - conf.data_name = 'ffhq' - conf.diffusion_type = 'beatgans' - conf.eval_ema_every_samples = 200_000 - conf.eval_every_samples = 200_000 - conf.fp16 = True - conf.lr = 1e-4 - conf.model_name = ModelName.beatgans_autoenc - conf.net_attn = (16, ) - conf.net_beatgans_attn_head = 1 - conf.net_beatgans_embed_channels = 512 - conf.net_beatgans_resnet_two_cond = True - conf.net_ch_mult = (1, 2, 4, 8) - conf.net_ch = 64 - conf.net_enc_channel_mult = (1, 2, 4, 8, 8) - conf.net_enc_pool = 'adaptivenonzero' - conf.sample_size = 32 - conf.T_eval = 20 - conf.T = 1000 - conf.make_model_conf() - return conf - - -def ffhq64_ddpm(): - conf = ddpm() - conf.data_name = 'ffhqlmdb256' - conf.warmup = 0 - conf.total_samples = 72_000_000 - conf.scale_up_gpus(4) - return conf - - -def ffhq64_autoenc(): - conf = autoenc_base() - conf.data_name = 'ffhqlmdb256' - conf.warmup = 0 - conf.total_samples = 72_000_000 - conf.net_ch_mult = (1, 2, 4, 8) - conf.net_enc_channel_mult = (1, 2, 4, 8, 8) - conf.eval_every_samples = 1_000_000 - conf.eval_ema_every_samples = 1_000_000 - conf.scale_up_gpus(4) - conf.make_model_conf() - return conf - - -def celeba64d2c_ddpm(): - conf = ffhq128_ddpm() - conf.data_name = 'celebalmdb' - conf.eval_every_samples = 10_000_000 - conf.eval_ema_every_samples = 10_000_000 - conf.total_samples = 72_000_000 - conf.name = 'celeba64d2c_ddpm' - return conf - - -def celeba64d2c_autoenc(): - conf = ffhq64_autoenc() - conf.data_name = 'celebalmdb' - conf.eval_every_samples = 10_000_000 - conf.eval_ema_every_samples = 10_000_000 - conf.total_samples = 72_000_000 - conf.name = 'celeba64d2c_autoenc' - return conf - - -def ffhq128_ddpm(): - conf = ddpm() - conf.data_name = 'ffhqlmdb256' - conf.warmup = 0 - conf.total_samples = 48_000_000 - conf.img_size = 128 - conf.net_ch = 128 - # channels: - # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4 - # sizes: - # 128 => 128 => 64 => 32 => 16 => 8 - conf.net_ch_mult = (1, 1, 2, 3, 4) - conf.eval_every_samples = 1_000_000 - conf.eval_ema_every_samples = 1_000_000 - conf.scale_up_gpus(4) - conf.eval_ema_every_samples = 10_000_000 - conf.eval_every_samples = 10_000_000 - conf.make_model_conf() - return conf - - -def ffhq128_autoenc_base(): - conf = autoenc_base() - conf.data_name = 'ffhqlmdb256' - conf.scale_up_gpus(4) - conf.img_size = 128 - conf.net_ch = 128 - # final resolution = 8x8 - conf.net_ch_mult = (1, 1, 2, 3, 4) - # final resolution = 4x4 - conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4) - conf.eval_ema_every_samples = 10_000_000 - conf.eval_every_samples = 10_000_000 - conf.make_model_conf() - return conf - - -def ffhq256_autoenc(): - conf = ffhq128_autoenc_base() - conf.img_size = 256 - conf.net_ch = 128 - conf.net_ch_mult = (1, 1, 2, 2, 4, 4) - conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) - conf.eval_every_samples = 10_000_000 - conf.eval_ema_every_samples = 10_000_000 - conf.total_samples = 200_000_000 - conf.batch_size = 64 - conf.make_model_conf() - conf.name = 'ffhq256_autoenc' - return conf - - -def ffhq256_autoenc_eco(): - conf = ffhq128_autoenc_base() - conf.img_size = 256 - conf.net_ch = 128 - conf.net_ch_mult = (1, 1, 2, 2, 4, 4) - conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) - conf.eval_every_samples = 10_000_000 - conf.eval_ema_every_samples = 10_000_000 - conf.total_samples = 200_000_000 - conf.batch_size = 64 - conf.make_model_conf() - conf.name = 'ffhq256_autoenc_eco' - return conf - - 
-def ffhq128_ddpm_72M():
-    conf = ffhq128_ddpm()
-    conf.total_samples = 72_000_000
-    conf.name = 'ffhq128_ddpm_72M'
-    return conf
-
-
-def ffhq128_autoenc_72M():
-    conf = ffhq128_autoenc_base()
-    conf.total_samples = 72_000_000
-    conf.name = 'ffhq128_autoenc_72M'
-    return conf
-
-
-def ffhq128_ddpm_130M():
-    conf = ffhq128_ddpm()
-    conf.total_samples = 130_000_000
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.name = 'ffhq128_ddpm_130M'
-    return conf
-
-
-def ffhq128_autoenc_130M():
-    conf = ffhq128_autoenc_base()
-    conf.total_samples = 130_000_000
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.name = 'ffhq128_autoenc_130M'
-    return conf
-
-
-def horse128_ddpm():
-    conf = ffhq128_ddpm()
-    conf.data_name = 'horse256'
-    conf.total_samples = 130_000_000
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.name = 'horse128_ddpm'
-    return conf
-
-
-def horse128_autoenc():
-    conf = ffhq128_autoenc_base()
-    conf.data_name = 'horse256'
-    conf.total_samples = 130_000_000
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.name = 'horse128_autoenc'
-    return conf
-
-
-def bedroom128_ddpm():
-    conf = ffhq128_ddpm()
-    conf.data_name = 'bedroom256'
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.total_samples = 120_000_000
-    conf.name = 'bedroom128_ddpm'
-    return conf
-
-
-def bedroom128_autoenc():
-    conf = ffhq128_autoenc_base()
-    conf.data_name = 'bedroom256'
-    conf.eval_ema_every_samples = 10_000_000
-    conf.eval_every_samples = 10_000_000
-    conf.total_samples = 120_000_000
-    conf.name = 'bedroom128_autoenc'
-    return conf
-
-
-def pretrain_celeba64d2c_72M():
-    conf = celeba64d2c_autoenc()
-    conf.pretrain = PretrainConfig(
-        name='72M',
-        path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
-    return conf
-
-
-def pretrain_ffhq128_autoenc72M():
-    conf = ffhq128_autoenc_base()
-    conf.postfix = ''
-    conf.pretrain = PretrainConfig(
-        name='72M',
-        path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
-    return conf
-
-
-def pretrain_ffhq128_autoenc130M():
-    conf = ffhq128_autoenc_base()
-    conf.pretrain = PretrainConfig(
-        name='130M',
-        path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
-    return conf
-
-
-def pretrain_ffhq256_autoenc():
-    conf = ffhq256_autoenc()
-    conf.pretrain = PretrainConfig(
-        name='90M',
-        path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
-    return conf
-
-
-def pretrain_horse128():
-    conf = horse128_autoenc()
-    conf.pretrain = PretrainConfig(
-        name='82M',
-        path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
-    return conf
-
-
-def pretrain_bedroom128():
-    conf = bedroom128_autoenc()
-    conf.pretrain = PretrainConfig(
-        name='120M',
-        path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
-    )
-    conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
-    return conf
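
The pretrain_* templates only point at the artifacts of a finished run; nothing is loaded at config time. A sketch of what they resolve to:

    conf = pretrain_ffhq128_autoenc130M()
    print(conf.pretrain.path)      # checkpoints/ffhq128_autoenc_130M/last.ckpt
    print(conf.latent_infer_path)  # checkpoints/ffhq128_autoenc_130M/latent.pkl
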
diff --git a/source/templates_cls.py b/source/templates_cls.py
deleted file mode 100644
index 4fc3bcf..0000000
--- a/source/templates_cls.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from templates import *
-
-
-def ffhq128_autoenc_cls():
-    conf = ffhq128_autoenc_130M()
-    conf.train_mode = TrainMode.manipulate
-    conf.manipulate_mode = ManipulateMode.celebahq_all
-    conf.manipulate_znormalize = True
-    conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
-    conf.batch_size = 32
-    conf.lr = 1e-3
-    conf.total_samples = 300_000
-    # use the pretraining trick instead of the continuing trick
-    conf.pretrain = PretrainConfig(
-        '130M',
-        f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
-    )
-    conf.name = 'ffhq128_autoenc_cls'
-    return conf
-
-
-def ffhq256_autoenc_cls():
-    '''We first train the encoder on the FFHQ dataset, then use it as a pretrained backbone to train a linear classifier on the CelebA dataset with attribute labels.'''
-    conf = ffhq256_autoenc()
-    conf.train_mode = TrainMode.manipulate
-    conf.manipulate_mode = ManipulateMode.celebahq_all
-    conf.manipulate_znormalize = True
-    conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'  # we train on the CelebA dataset, not FFHQ
-    conf.batch_size = 32
-    conf.lr = 1e-3
-    conf.total_samples = 300_000
-    # use the pretraining trick instead of the continuing trick
-    conf.pretrain = PretrainConfig(
-        '130M',
-        f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
-    )
-    conf.name = 'ffhq256_autoenc_cls'
-    return conf
diff --git a/source/templates_latent.py b/source/templates_latent.py
deleted file mode 100644
index b82c257..0000000
--- a/source/templates_latent.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from templates import *
-
-
-def latent_diffusion_config(conf: TrainConfig):
-    conf.batch_size = 128
-    conf.train_mode = TrainMode.latent_diffusion
-    conf.latent_gen_type = GenerativeType.ddim
-    conf.latent_loss_type = LossType.mse
-    conf.latent_model_mean_type = ModelMeanType.eps
-    conf.latent_model_var_type = ModelVarType.fixed_large
-    conf.latent_rescale_timesteps = False
-    conf.latent_clip_sample = False
-    conf.latent_T_eval = 20
-    conf.latent_znormalize = True
-    conf.total_samples = 96_000_000
-    conf.sample_every_samples = 400_000
-    conf.eval_every_samples = 20_000_000
-    conf.eval_ema_every_samples = 20_000_000
-    conf.save_every_samples = 2_000_000
-    return conf
-
-
-def latent_diffusion128_config(conf: TrainConfig):
-    conf = latent_diffusion_config(conf)
-    conf.batch_size_eval = 32
-    return conf
-
-
-def latent_mlp_2048_norm_10layers(conf: TrainConfig):
-    conf.net_latent_net_type = LatentNetType.skip
-    conf.net_latent_layers = 10
-    conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
-    conf.net_latent_activation = Activation.silu
-    conf.net_latent_num_hid_channels = 2048
-    conf.net_latent_use_norm = True
-    conf.net_latent_condition_bias = 1
-    return conf
-
-
-def latent_mlp_2048_norm_20layers(conf: TrainConfig):
-    conf = latent_mlp_2048_norm_10layers(conf)
-    conf.net_latent_layers = 20
-    conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
-    return conf
-
-
-def latent_256_batch_size(conf: TrainConfig):
-    conf.batch_size = 256
-    conf.eval_ema_every_samples = 100_000_000
-    conf.eval_every_samples = 100_000_000
-    conf.sample_every_samples = 1_000_000
-    conf.save_every_samples = 2_000_000
-    conf.total_samples = 301_000_000
-    return conf
-
-
-def latent_512_batch_size(conf: TrainConfig):
-    conf.batch_size = 512
-    conf.eval_ema_every_samples = 100_000_000
-    conf.eval_every_samples = 100_000_000
-    conf.sample_every_samples = 1_000_000
-    conf.save_every_samples = 5_000_000
-    conf.total_samples = 501_000_000
-    return conf
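
The skip-layer spec in latent_mlp_2048_norm_10layers expands to every layer except the input layer:

    layers = 10
    print(list(range(1, layers)))   # [1, 2, 3, 4, 5, 6, 7, 8, 9]
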
-def latent_2048_batch_size(conf: TrainConfig):
-    conf.batch_size = 2048
-    conf.eval_ema_every_samples = 200_000_000
-    conf.eval_every_samples = 200_000_000
-    conf.sample_every_samples = 4_000_000
-    conf.save_every_samples = 20_000_000
-    conf.total_samples = 1_501_000_000
-    return conf
-
-
-def adamw_weight_decay(conf: TrainConfig):
-    conf.optimizer = OptimizerType.adamw
-    conf.weight_decay = 0.01
-    return conf
-
-
-def ffhq128_autoenc_latent():
-    conf = pretrain_ffhq128_autoenc130M()
-    conf = latent_diffusion128_config(conf)
-    conf = latent_mlp_2048_norm_10layers(conf)
-    conf = latent_256_batch_size(conf)
-    conf = adamw_weight_decay(conf)
-    conf.total_samples = 101_000_000
-    conf.latent_loss_type = LossType.l1
-    conf.latent_beta_scheduler = 'const0.008'
-    conf.name = 'ffhq128_autoenc_latent'
-    return conf
-
-
-def ffhq256_autoenc_latent():
-    conf = pretrain_ffhq256_autoenc()
-    conf = latent_diffusion128_config(conf)
-    conf = latent_mlp_2048_norm_10layers(conf)
-    conf = latent_256_batch_size(conf)
-    conf = adamw_weight_decay(conf)
-    conf.total_samples = 101_000_000
-    conf.latent_loss_type = LossType.l1
-    conf.latent_beta_scheduler = 'const0.008'
-    conf.eval_ema_every_samples = 200_000_000
-    conf.eval_every_samples = 200_000_000
-    conf.sample_every_samples = 4_000_000
-    conf.name = 'ffhq256_autoenc_latent'
-    return conf
-
-
-def horse128_autoenc_latent():
-    conf = pretrain_horse128()
-    conf = latent_diffusion128_config(conf)
-    conf = latent_2048_batch_size(conf)
-    conf = latent_mlp_2048_norm_20layers(conf)
-    conf.total_samples = 2_001_000_000
-    conf.latent_beta_scheduler = 'const0.008'
-    conf.latent_loss_type = LossType.l1
-    conf.name = 'horse128_autoenc_latent'
-    return conf
-
-
-def bedroom128_autoenc_latent():
-    conf = pretrain_bedroom128()
-    conf = latent_diffusion128_config(conf)
-    conf = latent_2048_batch_size(conf)
-    conf = latent_mlp_2048_norm_20layers(conf)
-    conf.total_samples = 2_001_000_000
-    conf.latent_beta_scheduler = 'const0.008'
-    conf.latent_loss_type = LossType.l1
-    conf.name = 'bedroom128_autoenc_latent'
-    return conf
-
-
-def celeba64d2c_autoenc_latent():
-    conf = pretrain_celeba64d2c_72M()
-    conf = latent_diffusion_config(conf)
-    conf = latent_512_batch_size(conf)
-    conf = latent_mlp_2048_norm_10layers(conf)
-    conf = adamw_weight_decay(conf)
-    # just for the name
-    conf.continue_from = PretrainConfig('200M',
-                                        f'log-latent/{conf.name}/last.ckpt')
-    conf.postfix = '_300M'
-    conf.total_samples = 301_000_000
-    conf.latent_beta_scheduler = 'const0.008'
-    conf.latent_loss_type = LossType.l1
-    conf.name = 'celeba64d2c_autoenc_latent'
-    return conf
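
The latent templates are plain function composition over a TrainConfig, so a new experiment is just another chain. A sketch following the ffhq128_autoenc_latent() pattern (the name and settings are illustrative):

    def my_latent_experiment():
        conf = pretrain_ffhq128_autoenc130M()
        conf = latent_diffusion128_config(conf)
        conf = latent_mlp_2048_norm_10layers(conf)
        conf = latent_256_batch_size(conf)
        conf = adamw_weight_decay(conf)
        conf.latent_loss_type = LossType.l1
        conf.name = 'my_latent_experiment'
        return conf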