From f6575b4af8d553801da602e8f1aa2c0ed5818375 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Thu, 11 Jan 2024 08:32:28 +0000 Subject: [PATCH 1/2] [ci] fixed ddp test --- tests/kit/model_zoo/registry.py | 18 ++++++++++++++++-- tests/test_shardformer/test_with_torch_ddp.py | 3 ++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 44a0adc6a3af..5e8e0b3822df 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -61,7 +61,7 @@ def register( """ self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) - def get_sub_registry(self, keyword: Union[str, List[str]]): + def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None): """ Get a sub registry with models that contain the keyword. @@ -76,10 +76,24 @@ def get_sub_registry(self, keyword: Union[str, List[str]]): keyword_list = keyword assert isinstance(keyword_list, (list, tuple)) + if exclude is None: + exclude_keywords = [] + elif isinstance(exclude, str): + exclude_keywords = [exclude] + else: + exclude_keywords = exclude + assert isinstance(exclude_keywords, (list, tuple)) + for k, v in self.items(): for kw in keyword_list: if kw in k: - new_dict[k] = v + should_exclude = False + for ex_kw in exclude_keywords: + if ex_kw in k: + should_exclude = True + + if not should_exclude: + new_dict[k] = v assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index f642a9dcada4..29e9cedfe9e9 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -16,7 +16,7 @@ @parameterize("lazy_init", [True, False]) def check_shardformer_with_ddp(lazy_init: bool): - sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") # create shardformer # ranks: [0, 1, 2, 3] @@ -45,6 +45,7 @@ def check_shardformer_with_ddp(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + print(name) # create and shard model with ctx: model = model_fn().cuda() From 69f76ba481aa858b2eeb6d7180b611109e8ce741 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Thu, 11 Jan 2024 08:43:02 +0000 Subject: [PATCH 2/2] polish --- tests/test_shardformer/test_with_torch_ddp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 29e9cedfe9e9..4b741c21b48c 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -45,7 +45,6 @@ def check_shardformer_with_ddp(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - print(name) # create and shard model with ctx: model = model_fn().cuda()