From c2086ef13c3bf78c723f0f9fede31a27fc3f10db Mon Sep 17 00:00:00 2001
From: digger yu
Date: Wed, 3 Apr 2024 11:49:01 +0800
Subject: [PATCH] fix typo s/get_defualt_parser /get_default_parser

---
 examples/language/grok-1/inference.py    | 4 ++--
 examples/language/grok-1/inference_tp.py | 4 ++--
 examples/language/grok-1/utils.py        | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py
index faef7ae9d7ca..58ba3872f856 100644
--- a/examples/language/grok-1/inference.py
+++ b/examples/language/grok-1/inference.py
@@ -2,10 +2,10 @@
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import get_defualt_parser, inference, print_output
+from utils import get_default_parser, inference, print_output
 
 if __name__ == "__main__":
-    parser = get_defualt_parser()
+    parser = get_default_parser()
     args = parser.parse_args()
     start = time.time()
     torch.set_default_dtype(torch.bfloat16)
diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py
index cf05880dc21d..e10c4929cdbf 100644
--- a/examples/language/grok-1/inference_tp.py
+++ b/examples/language/grok-1/inference_tp.py
@@ -3,7 +3,7 @@
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import get_defualt_parser, inference, print_output
+from utils import get_default_parser, inference, print_output
 
 import colossalai
 from colossalai.booster import Booster
@@ -13,7 +13,7 @@ from colossalai.utils import get_current_device
 
 
 if __name__ == "__main__":
-    parser = get_defualt_parser()
+    parser = get_default_parser()
     args = parser.parse_args()
     start = time.time()
     colossalai.launch_from_torch({})
diff --git a/examples/language/grok-1/utils.py b/examples/language/grok-1/utils.py
index 7663127a5515..29c86e411db1 100644
--- a/examples/language/grok-1/utils.py
+++ b/examples/language/grok-1/utils.py
@@ -33,7 +33,7 @@ def inference(model, tokenizer, text, **generate_kwargs):
     return outputs[0].tolist()
 
 
-def get_defualt_parser():
+def get_default_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
     parser.add_argument("--tokenizer", type=str, default="tokenizer.model")