From 5fdc3aa76aa2afcf43549e4f2f35bfbaa8b25715 Mon Sep 17 00:00:00 2001
From: autumn <2>
Date: Thu, 20 Jul 2023 09:03:22 +0800
Subject: [PATCH 1/5] fix some

---
 basics/base_task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index 059440bdb..6f3da7bd8 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -276,7 +276,7 @@ def build_optimizer(self, model):
             optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
         optimizer = build_object_from_config(
             optimizer_args['optimizer_cls'],
-            filter(lambda p: p.requires_grad, model.parameters()),
+            model.parameters(), 
             **optimizer_args
         )
         return optimizer

From ce6a7593e706189bc38dbce7c669e3d0b88943d4 Mon Sep 17 00:00:00 2001
From: autumn <2>
Date: Fri, 21 Jul 2023 16:52:35 +0800
Subject: [PATCH 2/5] support freeze model params

---
 basics/base_task.py   | 40 +++++++++++++++++++++++++++++++++++++---
 configs/acoustic.yaml |  3 +++
 configs/base.yaml     |  3 +++
 configs/variance.yaml |  3 +++
 4 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index 6f3da7bd8..67a8b3cad 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -88,16 +88,50 @@ def setup(self, stage):
         self.phone_encoder = self.build_phone_encoder()
         self.model = self.build_model()
         # utils.load_warp(self)
+        self.unfreeze_all_params()
+        if hparams['freezed_params_on_train_enabled']:
+            self.freeze_params()
         if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
-            self.load_finetune_ckpt( self.load_pre_train_model())
+            self.load_finetune_ckpt(self.load_pre_train_model())
         self.print_arch()
         self.build_losses()
         self.train_dataset = self.dataset_cls(hparams['train_set_name'])
         self.valid_dataset = self.dataset_cls(hparams['valid_set_name'])
 
+    def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
+        key_list = []
+        for i in hparams['freezed_params']:
+            for j in model_state_dict:
+                if j.startswith(i):
+                    key_list.append(j)
+        return list(set(key_list))
+
+    def freeze_params(self) -> None:
+        model_state_dict = self.state_dict().keys()
+        freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
+        for i in freeze_key:
+            i: str
+            key_s = i.split('.')
+            key_str = ''
+            for j in key_s:
+                try:
+                    int(j)
+                    key_str = key_str + '[' + j + ']'
+                except:
+                    key_str = key_str + '.' + j
+
+            key_str = key_str.strip('.')
+
+            objects = eval(f'self.{key_str} ')
+            objects.requires_grad = False
+
+    def unfreeze_all_params(self) -> None:
+        for i in self.model.parameters():
+            i.requires_grad = True
+
     def load_finetune_ckpt(
             self, state_dict
-    ):
+    ) -> None:
 
         adapt_shapes = hparams['finetune_strict_shapes']
         if not adapt_shapes:
@@ -276,7 +310,7 @@ def build_optimizer(self, model):
             optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
         optimizer = build_object_from_config(
             optimizer_args['optimizer_cls'],
-            model.parameters(), 
+            model.parameters(),
             **optimizer_args
         )
         return optimizer
diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 368a7b5af..f095c693c 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -102,3 +102,6 @@ finetune_ignored_params:
   - model.fs2.txt_embed
   - model.fs2.spk_embed
 finetune_strict_shapes: true
+
+freezed_params_on_train_enabled: false
+freezed_params: []
diff --git a/configs/base.yaml b/configs/base.yaml
index bc570e46a..395659b5d 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -100,3 +100,6 @@
 
 finetune_ignored_params: []
 finetune_strict_shapes: true
+
+freezed_params_on_train_enabled: false
+freezed_params: []
diff --git a/configs/variance.yaml b/configs/variance.yaml
index d4dc91a3b..675ba5d19 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -111,3 +111,6 @@ finetune_ignored_params:
   - model.fs2.txt_embed
   - model.fs2.encoder.embed_tokens
 finetune_strict_shapes: true
+
+freezed_params_on_train_enabled: false
+freezed_params: []

From ad2f4c7905b7650658b929f48626382addd302ba Mon Sep 17 00:00:00 2001
From: autumn <2>
Date: Fri, 21 Jul 2023 17:43:54 +0800
Subject: [PATCH 3/5] support freeze model params

---
 basics/base_task.py   | 4 ++--
 configs/acoustic.yaml | 4 ++--
 configs/base.yaml     | 4 ++--
 configs/variance.yaml | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index 67a8b3cad..f054febbb 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -89,7 +89,7 @@ def setup(self, stage):
         self.model = self.build_model()
         # utils.load_warp(self)
         self.unfreeze_all_params()
-        if hparams['freezed_params_on_train_enabled']:
+        if hparams['frozen_params_on_train_enabled']:
             self.freeze_params()
         if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
             self.load_finetune_ckpt(self.load_pre_train_model())
@@ -100,7 +100,7 @@
 
     def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
         key_list = []
-        for i in hparams['freezed_params']:
+        for i in hparams['frozen_params']:
             for j in model_state_dict:
                 if j.startswith(i):
                     key_list.append(j)
diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index f095c693c..1453936b9 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -103,5 +103,5 @@ finetune_ignored_params:
   - model.fs2.spk_embed
 finetune_strict_shapes: true
 
-freezed_params_on_train_enabled: false
-freezed_params: []
+frozen_params_on_train_enabled: false
+frozen_params: []
diff --git a/configs/base.yaml b/configs/base.yaml
index 395659b5d..4a3ba2a7d 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -101,5 +101,5 @@
 finetune_ignored_params: []
 finetune_strict_shapes: true
 
-freezed_params_on_train_enabled: false
-freezed_params: []
+frozen_params_on_train_enabled: false
+frozen_params: []
diff --git a/configs/variance.yaml b/configs/variance.yaml
index 675ba5d19..ec74e1e28 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -112,5 +112,5 @@ finetune_ignored_params:
   - model.fs2.encoder.embed_tokens
 finetune_strict_shapes: true
 
-freezed_params_on_train_enabled: false
-freezed_params: []
+frozen_params_on_train_enabled: false
+frozen_params: []

From 9688d31678a92c4caa9e7123e15aa761fab388f7 Mon Sep 17 00:00:00 2001
From: autumn <2>
Date: Fri, 21 Jul 2023 18:33:47 +0800
Subject: [PATCH 4/5] support freeze model params

---
 basics/base_task.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index f054febbb..c1c498b7e 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -109,21 +109,11 @@ def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
     def freeze_params(self) -> None:
         model_state_dict = self.state_dict().keys()
         freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
+
         for i in freeze_key:
-            i: str
-            key_s = i.split('.')
-            key_str = ''
-            for j in key_s:
-                try:
-                    int(j)
-                    key_str = key_str + '[' + j + ']'
-                except:
-                    key_str = key_str + '.' + j
-
-            key_str = key_str.strip('.')
-
-            objects = eval(f'self.{key_str} ')
-            objects.requires_grad = False
+            params=self.get_parameter(i)
+
+            params.requires_grad = False
 
     def unfreeze_all_params(self) -> None:
         for i in self.model.parameters():

From 9d35cbcf1ea2b0dcc2bfc5481808a15ec51a726f Mon Sep 17 00:00:00 2001
From: yqzhishen
Date: Fri, 4 Aug 2023 19:39:26 +0800
Subject: [PATCH 5/5] Rename parameter and add docs

---
 basics/base_task.py          |  2 +-
 configs/acoustic.yaml        |  2 +-
 configs/base.yaml            |  2 +-
 configs/variance.yaml        |  2 +-
 docs/ConfigurationSchemas.md | 49 +++++++++++++++++++++++++++++++++++-
 5 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/basics/base_task.py b/basics/base_task.py
index c1c498b7e..4e3af855c 100644
--- a/basics/base_task.py
+++ b/basics/base_task.py
@@ -89,7 +89,7 @@ def setup(self, stage):
         self.model = self.build_model()
         # utils.load_warp(self)
         self.unfreeze_all_params()
-        if hparams['frozen_params_on_train_enabled']:
+        if hparams['freezing_enabled']:
             self.freeze_params()
         if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
             self.load_finetune_ckpt(self.load_pre_train_model())
diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 1453936b9..88ae1b12b 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -103,5 +103,5 @@ finetune_ignored_params:
   - model.fs2.spk_embed
 finetune_strict_shapes: true
 
-frozen_params_on_train_enabled: false
+freezing_enabled: false
 frozen_params: []
diff --git a/configs/base.yaml b/configs/base.yaml
index 4a3ba2a7d..4561b3325 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -101,5 +101,5 @@
 finetune_ignored_params: []
 finetune_strict_shapes: true
 
-frozen_params_on_train_enabled: false
+freezing_enabled: false
 frozen_params: []
diff --git a/configs/variance.yaml b/configs/variance.yaml
index ec74e1e28..b23c2a5a8 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -112,5 +112,5 @@ finetune_ignored_params:
   - model.fs2.encoder.embed_tokens
 finetune_strict_shapes: true
 
-frozen_params_on_train_enabled: false
+freezing_enabled: false
 frozen_params: []
diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md
index 417350a45..b37423400 100644
--- a/docs/ConfigurationSchemas.md
+++ b/docs/ConfigurationSchemas.md
@@ -1422,6 +1422,54 @@ int
 
 16000
 
+### freezing_enabled
+
+Whether enabling parameter freezing during training.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+bool
+
+#### default
+
+False
+
+### frozen_params
+
+Parameter name prefixes to freeze during training.
+
+#### visibility
+
+all
+
+#### scope
+
+training
+
+#### customizability
+
+normal
+
+#### type
+
+list
+
+#### default
+
+[]
+
 ### fmin
 
 Minimum frequency of mel extraction.
@@ -3416,4 +3464,3 @@ int
 
 2048
 
-