Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tests/kit/model_zoo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def register(
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

def get_sub_registry(self, keyword: Union[str, List[str]]):
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
"""
Get a sub registry with models that contain the keyword.

Expand All @@ -76,10 +76,24 @@ def get_sub_registry(self, keyword: Union[str, List[str]]):
keyword_list = keyword
assert isinstance(keyword_list, (list, tuple))

if exclude is None:
exclude_keywords = []
elif isinstance(exclude, str):
exclude_keywords = [exclude]
else:
exclude_keywords = exclude
assert isinstance(exclude_keywords, (list, tuple))

for k, v in self.items():
for kw in keyword_list:
if kw in k:
new_dict[k] = v
should_exclude = False
for ex_kw in exclude_keywords:
if ex_kw in k:
should_exclude = True

if not should_exclude:
new_dict[k] = v

assert len(new_dict) > 0, f"No model found with keyword {keyword}"
return new_dict
Expand Down
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_with_torch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@parameterize("lazy_init", [True, False])
def check_shardformer_with_ddp(lazy_init: bool):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")

# create shardformer
# ranks: [0, 1, 2, 3]
Expand Down