diff --git a/basics/base_task.py b/basics/base_task.py
index 059440bdb..4e3af855c 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -88,16 +88,40 @@ def setup(self, stage):
         self.phone_encoder = self.build_phone_encoder()
         self.model = self.build_model()
         # utils.load_warp(self)
+        self.unfreeze_all_params()
+        if hparams['freezing_enabled']:
+            self.freeze_params()
         if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
-            self.load_finetune_ckpt( self.load_pre_train_model())
+            self.load_finetune_ckpt(self.load_pre_train_model())
         self.print_arch()
         self.build_losses()
         self.train_dataset = self.dataset_cls(hparams['train_set_name'])
         self.valid_dataset = self.dataset_cls(hparams['valid_set_name'])

+    def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
+        key_list = []
+        for i in hparams['frozen_params']:
+            for j in model_state_dict:
+                if j.startswith(i):
+                    key_list.append(j)
+        return list(set(key_list))
+
+    def freeze_params(self) -> None:
+        model_state_dict = self.state_dict().keys()
+        freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
+
+        for i in freeze_key:
+            params = self.get_parameter(i)
+
+            params.requires_grad = False
+
+    def unfreeze_all_params(self) -> None:
+        for i in self.model.parameters():
+            i.requires_grad = True
+
     def load_finetune_ckpt(
             self, state_dict
-    ):
+    ) -> None:
         adapt_shapes = hparams['finetune_strict_shapes']
         if not adapt_shapes:
@@ -276,7 +300,7 @@ def build_optimizer(self, model):
             optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
         optimizer = build_object_from_config(
             optimizer_args['optimizer_cls'],
-            filter(lambda p: p.requires_grad, model.parameters()),
+            model.parameters(),
             **optimizer_args
         )
         return optimizer
diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 368a7b5af..88ae1b12b 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -102,3 +102,6 @@ finetune_ignored_params:
 - model.fs2.txt_embed
 - model.fs2.spk_embed
 finetune_strict_shapes: true
+
+freezing_enabled: false
+frozen_params: []
diff --git a/configs/base.yaml b/configs/base.yaml
index bc570e46a..4561b3325 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -100,3 +100,6 @@
 finetune_ignored_params: []
 finetune_strict_shapes: true
+
+freezing_enabled: false
+frozen_params: []
diff --git a/configs/variance.yaml b/configs/variance.yaml
index d4dc91a3b..b23c2a5a8 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -111,3 +111,6 @@ finetune_ignored_params:
 - model.fs2.txt_embed
 - model.fs2.encoder.embed_tokens
 finetune_strict_shapes: true
+
+freezing_enabled: false
+frozen_params: []
diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md
index 417350a45..b37423400 100644
--- a/docs/ConfigurationSchemas.md
+++ b/docs/ConfigurationSchemas.md
@@ -1422,6 +1422,54 @@
 int

 16000

+### freezing_enabled
+
+Whether to enable parameter freezing during training.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+bool
+
+#### default
+
+False
+
+### frozen_params
+
+Parameter name prefixes to freeze during training.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+list
+
+#### default
+
+[]
+
 ### fmin

 Minimum frequency of mel extraction.
@@ -3416,4 +3464,3 @@
 int

 2048
-
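
Usage note: with this change, freezing is driven entirely by the config. Below is a minimal sketch of how the new keys could be set in a project config; the two prefixes are illustrative placeholders, not part of this change — any prefix matched against the task's `state_dict()` keys via `startswith()` works.

```yaml
# Hypothetical example — prefix values below are placeholders.
freezing_enabled: true
frozen_params:
  - model.fs2.encoder    # freezes every parameter whose state-dict key starts with this prefix
  - model.fs2.txt_embed  # multiple prefixes may be listed; duplicate matches are deduplicated
```

Because `setup()` calls `unfreeze_all_params()` before `freeze_params()`, the frozen set is recomputed from the current config on every run rather than carried over from a checkpoint. The optimizer change above (dropping the `requires_grad` filter) hands all parameters to the optimizer; frozen ones simply accumulate no gradients.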