diff --git a/basics/base_task.py b/basics/base_task.py
index faf18812e..a9f471596 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -87,11 +87,70 @@ def __init__(self, *args, **kwargs):
     def setup(self, stage):
         self.phone_encoder = self.build_phone_encoder()
         self.model = self.build_model()
+        # utils.load_warp(self)
+        if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
+            self.load_finetune_ckpt(self.load_pre_train_model())
         self.print_arch()
         self.build_losses()
         self.train_dataset = self.dataset_cls(hparams['train_set_name'])
         self.valid_dataset = self.dataset_cls(hparams['valid_set_name'])
 
+    def load_finetune_ckpt(
+            self, state_dict
+    ):
+
+        adapt_shapes = hparams['finetune_strict_shapes']
+        if not adapt_shapes:
+            cur_model_state_dict = self.state_dict()
+            unmatched_keys = []
+            for key, param in state_dict.items():
+                if key in cur_model_state_dict:
+                    new_param = cur_model_state_dict[key]
+                    if new_param.shape != param.shape:
+                        unmatched_keys.append(key)
+                        print('| Unmatched keys: ', key, new_param.shape, param.shape)
+            for key in unmatched_keys:
+                del state_dict[key]
+        self.load_state_dict(state_dict, strict=False)
+
+    def load_pre_train_model(self):
+
+        pre_train_ckpt_path = hparams.get('finetune_ckpt_path')
+        blacklist = hparams.get('finetune_ignored_params')
+        # whitelist=hparams.get('pre_train_whitelist')
+        if blacklist is None:
+            blacklist = []
+        # if whitelist is None:
+        #     raise RuntimeError("")
+
+        if pre_train_ckpt_path is not None:
+            ckpt = torch.load(pre_train_ckpt_path)
+            # if ckpt.get('category') is None:
+            #     raise RuntimeError("")
+
+            if isinstance(self.model, CategorizedModule):
+                self.model.check_category(ckpt.get('category'))
+
+            state_dict = {}
+            for i in ckpt['state_dict']:
+                # if 'diffusion' in i:
+                #     if i in rrrr:
+                #         continue
+                skip = False
+                for b in blacklist:
+                    if i.startswith(b):
+                        skip = True
+                        break
+
+                if skip:
+                    continue
+
+                state_dict[i] = ckpt['state_dict'][i]
+                print(i)
+            return state_dict
+        else:
+            raise RuntimeError("finetune_enabled is true, but finetune_ckpt_path is not set.")
+
     @staticmethod
     def build_phone_encoder():
         phone_list = build_phoneme_list()
@@ -291,6 +350,11 @@ def on_test_end(self):
     def start(cls):
         pl.seed_everything(hparams['seed'], workers=True)
         task = cls()
+
+        # if pre_train is not None:
+        #     task.load_state_dict(pre_train,strict=False)
+        #     print("load success-------------------------------------------------------------------")
+
         work_dir = pathlib.Path(hparams['work_dir'])
         trainer = pl.Trainer(
             accelerator=hparams['pl_trainer_accelerator'],
@@ -378,16 +442,16 @@ def on_load_checkpoint(self, checkpoint):
         from utils import simulate_lr_scheduler
         if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
             self.skip_immediate_validation = True
-
+
         optimizer_args = hparams['optimizer_args']
         scheduler_args = hparams['lr_scheduler_args']
-
+
         if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
             optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
 
         if checkpoint.get('optimizer_states', None):
             opt_states = checkpoint['optimizer_states']
-            assert len(opt_states) == 1 # only support one optimizer
+            assert len(opt_states) == 1  # only support one optimizer
             opt_state = opt_states[0]
             for param_group in opt_state['param_groups']:
                 for k, v in optimizer_args.items():
@@ -397,13 +461,14 @@ def on_load_checkpoint(self, checkpoint):
                         rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
                         param_group[k] = v
                 if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
-                    rank_zero_info(f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
+                    rank_zero_info(
+                        f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
                     param_group['initial_lr'] = optimizer_args['lr']
 
         if checkpoint.get('lr_schedulers', None):
             assert checkpoint.get('optimizer_states', False)
             schedulers = checkpoint['lr_schedulers']
-            assert len(schedulers) == 1 # only support one scheduler
+            assert len(schedulers) == 1  # only support one scheduler
             scheduler = schedulers[0]
             for k, v in scheduler_args.items():
                 if k in scheduler and scheduler[k] != v:
@@ -418,5 +483,6 @@ def on_load_checkpoint(self, checkpoint):
             scheduler['_last_lr'] = new_lrs
             for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs):
                 if param_group['lr'] != new_lr:
-                    rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
+                    rank_zero_info(
+                        f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
                     param_group['lr'] = new_lr
diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 66b79c5ea..368a7b5af 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -92,3 +92,13 @@ max_updates: 320000
 num_ckpt_keep: 5
 permanent_ckpt_start: 200000
 permanent_ckpt_interval: 40000
+
+
+finetune_enabled: false
+finetune_ckpt_path: null
+
+finetune_ignored_params:
+  - model.fs2.encoder.embed_tokens
+  - model.fs2.txt_embed
+  - model.fs2.spk_embed
+finetune_strict_shapes: true
diff --git a/configs/base.yaml b/configs/base.yaml
index 7d9f37325..bc570e46a 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -89,3 +89,14 @@ pl_trainer_precision: '32-true'
 pl_trainer_num_nodes: 1
 pl_trainer_strategy: 'auto'
 ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p'
+
+###########
+# finetune
+###########
+
+finetune_enabled: false
+finetune_ckpt_path: null
+finetune_ignored_params: []
+
+
+finetune_strict_shapes: true
diff --git a/configs/variance.yaml b/configs/variance.yaml
index ebfae4091..d4dc91a3b 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -103,3 +103,11 @@ max_updates: 288000
 num_ckpt_keep: 5
 permanent_ckpt_start: 180000
 permanent_ckpt_interval: 10000
+
+finetune_enabled: false
+finetune_ckpt_path: null
+finetune_ignored_params:
+  - model.spk_embed
+  - model.fs2.txt_embed
+  - model.fs2.encoder.embed_tokens
+finetune_strict_shapes: true
diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md
index de5fdb523..417350a45 100644
--- a/docs/ConfigurationSchemas.md
+++ b/docs/ConfigurationSchemas.md
@@ -1306,6 +1306,98 @@ int
 
 2048
 
+### finetune_enabled
+
+Whether to finetune from a pretrained model.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+bool
+
+#### default
+
+False
+
+### finetune_ckpt_path
+
+Path to the pretrained model used for finetuning.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+str
+
+#### default
+
+null
+
+### finetune_ignored_params
+
+Prefixes of parameter keys in the state dict of the pretrained model that should be dropped before finetuning.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+list
+
+### finetune_strict_shapes
+
+Whether to raise an error if the tensor shapes of any parameters in the pretrained model and the target model do not match. If set to `False`, parameters with mismatching shapes will be skipped.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+bool
+
+#### default
+
+True
+
 ### fmax
 
 Maximum frequency of mel extraction.
@@ -3324,3 +3416,4 @@ int
 
 2048
 
+
diff --git a/scripts/train.py b/scripts/train.py
index 0ce341cac..1df7b6bc7 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,5 +1,6 @@
 import importlib
 import os
+
 import sys
 from pathlib import Path
 
@@ -22,6 +23,7 @@ def run_task():
     pkg = ".".join(hparams["task_cls"].split(".")[:-1])
     cls_name = hparams["task_cls"].split(".")[-1]
     task_cls = getattr(importlib.import_module(pkg), cls_name)
+
     task_cls.start()
 
 
diff --git a/utils/__init__.py b/utils/__init__.py
index e9b1b0178..4c88f1598 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -11,6 +11,8 @@
 import torch.nn.functional as F
 
 from basics.base_module import CategorizedModule
+from utils.hparams import hparams
+from utils.training_utils import get_latest_checkpoint_path
 
 
 def tensors_to_scalars(metrics):
@@ -149,7 +151,8 @@ def filter_kwargs(dict_to_filter, kwarg_obj):
     sig = inspect.signature(kwarg_obj)
     filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
-    filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if filter_key in dict_to_filter}
+    filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if
+                     filter_key in dict_to_filter}
     return filtered_dict
 
 
@@ -208,6 +211,14 @@ def load_ckpt(
     print(f'| load {shown_model_name} from \'{checkpoint_path}\'.')
 
 
+
+
+
+
+
+    # return load_pre_train_model()
+
+
 def remove_padding(x, padding_idx=0):
     if x is None:
         return None
@@ -265,7 +276,8 @@ def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_par
         [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
         **optimizer_args
     )
-    scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, **scheduler_args)
+    scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch,
+                                         **scheduler_args)
 
     if hasattr(scheduler, '_get_closed_form_lr'):
         return scheduler._get_closed_form_lr()
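
Usage note (not part of the diff): the sketch below shows how the four `finetune_*` keys introduced above might be combined in a user configuration to finetune an acoustic model. The `base_config` key and all paths are hypothetical placeholders; only the `finetune_*` keys are defined by this change.

```yaml
# Hypothetical finetuning config; paths and base config are placeholders.
base_config: configs/acoustic.yaml

finetune_enabled: true
finetune_ckpt_path: checkpoints/pretrained_acoustic/model_ckpt_steps_320000.ckpt
finetune_ignored_params:   # parameter key prefixes dropped from the pretrained state dict
  - model.fs2.encoder.embed_tokens
  - model.fs2.txt_embed
  - model.fs2.spk_embed
finetune_strict_shapes: false  # skip shape-mismatched parameters instead of raising
```

Per the new `setup()` logic in `basics/base_task.py`, the pretrained weights are only loaded when `finetune_enabled` is true and the work directory does not yet contain a checkpoint, so resuming an existing run is unaffected.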