diff --git a/.compatibility b/.compatibility
index a918cb162216..d90a74b584d8 100644
--- a/.compatibility
+++ b/.compatibility
@@ -1,2 +1 @@
-2.0.0-11.7.0
-2.1.0-11.8.0
+2.1.0-12.1.0
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 510665b46f4b..3ff19b37b4bf 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -67,7 +67,6 @@ jobs:
--durations=0 \
tests/
env:
- NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
@@ -83,4 +82,4 @@ jobs:
SERVER_URL: ${{github.server_url }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
- WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
\ No newline at end of file
+ WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index a6f9582ac901..76493880651c 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -87,9 +87,8 @@ jobs:
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
- NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index ede6c380a8ec..f582b30907bf 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -82,9 +82,8 @@ jobs:
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
- NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 1cf456ff62c1..3348b51ecc6e 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -80,10 +80,9 @@ jobs:
- name: Unit Testing
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
- NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
diff --git a/README.md b/README.md
index f045c56043be..7c234b15e75e 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
@@ -40,7 +41,7 @@
Colossal-AI for Real World Applications
- - Open-Sora: Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million
+ - Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
- Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution
- ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
- AIGC: Acceleration of Stable Diffusion
@@ -126,18 +127,19 @@ distributed training and inference in a few lines.
## Colossal-AI in the Real World
### Open-Sora
-[Open-Sora](https://github.com/hpcaitech/Open-Sora):Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million
+[Open-Sora](https://github.com/hpcaitech/Open-Sora):Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
[[code]](https://github.com/hpcaitech/Open-Sora)
-[[blog]](https://hpc-ai.com/blog/open-sora)
+[[blog]](https://hpc-ai.com/blog/open-sora-v1.0)
+[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora)
+[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
-
-
-
-
-
-
-
+
+(back to top)
### Colossal-LLaMA-2
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 2e4bab75a085..d97da61e4dc8 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -56,6 +56,7 @@ def format_numel_str(numel: int) -> str:
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor = tensor.data
tensor.div_(dist.get_world_size())
return tensor
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
index e80befdaccfa..a6e87e6bea9f 100644
--- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
@@ -117,8 +117,8 @@ def _call(
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
# if rejection_trigger_keywords is not given, return the response from LLM directly
- rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
- answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
+ rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
+ answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None:
@@ -161,8 +161,8 @@ async def _acall(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
# if rejection_trigger_keywords is not given, return the response from LLM directly
- rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
- answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None
+ rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
+ answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
diff --git a/applications/ColossalQA/colossalqa/prompt/prompt.py b/applications/ColossalQA/colossalqa/prompt/prompt.py
index 533f0bd552b9..d62249ba9c51 100644
--- a/applications/ColossalQA/colossalqa/prompt/prompt.py
+++ b/applications/ColossalQA/colossalqa/prompt/prompt.py
@@ -75,7 +75,7 @@
# Below are English retrieval qa prompts
_EN_RETRIEVAL_QA_PROMPT = """[INST] <>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content.
-If the answer cannot be infered based on the given context, please say "I cannot answer the question based on the information given.".<>
+If the answer cannot be inferred based on the given context, please say "I cannot answer the question based on the information given.".<>
Use the context and chat history to answer the question.
context:
@@ -97,8 +97,8 @@
Human: I have a friend, Mike. Do you know him?
Assistant: Yes, I know a person named Mike
-sentence: What's his favorate food?
-disambiguated sentence: What's Mike's favorate food?
+sentence: What's his favorite food?
+disambiguated sentence: What's Mike's favorite food?
[/INST]
Chat history:
{chat_history}
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
index d2626321d68d..96bce82b9ee0 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
@@ -80,7 +80,7 @@ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[s
self.retrieval_chain.run(
query=user_input,
stop=[self.memory.human_prefix + ": "],
- rejection_trigger_keywrods=["cannot answer the question"],
+ rejection_trigger_keywords=["cannot answer the question"],
rejection_answer="Sorry, this question cannot be answered based on the information provided.",
).split("\n")[0],
self.memory,
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
index 76bec715fb6e..b23058d6dbe3 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
@@ -103,7 +103,7 @@ def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: Tex
break
data_name = input("Enter a short description of the data:")
separator = input(
- "Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:"
+ "Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n', press ENTER directly to skip:"
)
separator = separator if separator != "" else "\n\n"
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
index 484be21c1553..4eef41947d11 100644
--- a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
@@ -87,7 +87,7 @@ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[s
query=user_input,
stop=["答案>"],
doc_prefix="支持文档",
- rejection_trigger_keywrods=["无法回答该问题"],
+ rejection_trigger_keywords=["无法回答该问题"],
rejection_answer="抱歉,根据提供的信息无法回答该问题。",
).split("\n")[0],
self.memory,
diff --git a/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py
index 00b920d274bc..1042adbf2095 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py
@@ -61,7 +61,7 @@
information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
prompt_template = """Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
- If the answer cannot be infered based on the given context, please don't share false information.
+ If the answer cannot be inferred based on the given context, please don't share false information.
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
context:
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py
index e0fe46ae6322..fe2b9b4db3c2 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_en.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_en.py
@@ -67,7 +67,7 @@ def disambiguity(input):
break
data_name = input("Enter a short description of the data:")
separator = input(
- "Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n'. Note that"
+ "Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n'. Note that"
+ "we use neural text spliter to split texts into chunks, the seperator only serves as a delimiter to force split long passage into"
+ " chunks before passing to the neural network. Press ENTER directly to skip:"
)
@@ -112,7 +112,7 @@ def disambiguity(input):
agent_response = retrieval_chain.run(
query=user_input,
stop=["Human: "],
- rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_trigger_keywords=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER,
)
agent_response = agent_response.split("\n")[0]
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
index d98a75592372..d4ba73b9468c 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
@@ -142,7 +142,7 @@ def metadata_func(data_sample, additional_fields):
agent_response = retrieval_chain.run(
query=user_input,
stop=["Human: "],
- rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_trigger_keywords=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER,
)
agent_response = agent_response.split("\n")[0]
diff --git a/applications/ColossalQA/examples/retrieval_conversation_universal.py b/applications/ColossalQA/examples/retrieval_conversation_universal.py
index 361aa9833d27..5d13a63c3fad 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_universal.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_universal.py
@@ -11,7 +11,7 @@
parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing')
args = parser.parse_args()
- # Will ask for documents path in runnning time
+ # Will ask for documents path in running time
session = UniversalRetrievalConversation(files_en=None,
files_zh=None,
zh_model_path=args.zh_model_path, en_model_path=args.en_model_path,
diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py
index cbbbefad7c7b..b143b9baacc1 100644
--- a/applications/ColossalQA/examples/retrieval_conversation_zh.py
+++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py
@@ -107,7 +107,7 @@ def disambiguity(input: str):
query=user_input,
stop=["答案>"],
doc_prefix="支持文档",
- rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_trigger_keywords=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
)
print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
index c58be9c33477..526328dda11b 100644
--- a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
+++ b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
@@ -140,7 +140,7 @@ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[s
result = self.rag_chain.run(
query=user_input,
stop=[memory.human_prefix + ": "],
- rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_trigger_keywords=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
)
return result, memory
diff --git a/applications/README.md b/applications/README.md
index 8abe1e52d96c..120767d5c9ea 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -4,7 +4,7 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
-- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million
+- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 8cc76dd3e0f3..c37a6b4df72d 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
- param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+ param_info = {
+ "param_groups": [],
+ "param2id": {},
+ "id2param": {},
+ "param2shape": {},
+ }
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+ parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -939,6 +945,7 @@ def __init__(
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
+ parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ def __init__(
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
+ parallel_output=parallel_output,
)
self.amp_config = dict(
initial_scale=initial_scale,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 454710fccaa7..ae372dd034e0 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -182,7 +182,7 @@ def __init__(
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
- checkpoint_io: Optional[MoECheckpintIO] = None,
+ checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -341,7 +341,6 @@ def seed_worker(worker_id):
**_kwargs,
)
-
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index dfac7cfd9be9..287853a86383 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -89,7 +89,7 @@ docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
# enter into docker container
-cd /path/to/CollossalAI
+cd /path/to/ColossalAI
pip install -e .
```
diff --git a/colossalai/legacy/inference/README.md b/colossalai/legacy/inference/README.md
index f466f46c1629..63b5f2a75fa8 100644
--- a/colossalai/legacy/inference/README.md
+++ b/colossalai/legacy/inference/README.md
@@ -86,7 +86,7 @@ docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
# enter into docker container
-cd /path/to/CollossalAI
+cd /path/to/ColossalAI
pip install -e .
# install lightllm
diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py
index bb0b4c77a2a7..bc4e4fd199c0 100644
--- a/colossalai/legacy/inference/hybridengine/engine.py
+++ b/colossalai/legacy/inference/hybridengine/engine.py
@@ -46,7 +46,7 @@ class CaiInferEngine:
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
- # assume the model is infered with 2 pipeline stages
+ # assume the model is inferred with 2 pipeline stages
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
@@ -70,7 +70,7 @@ def __init__(
max_input_len: int = 32,
max_output_len: int = 32,
verbose: bool = False,
- # TODO: implement early_stopping, and various gerneration options
+ # TODO: implement early_stopping, and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index e89e6217d596..d3f8badc7313 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -47,7 +47,7 @@ be optimized jointly to further speed up training.
2. Model Accuracy
- Communication Efficiency
- - Reduce Volumn of Comm.
+ - Reduce Volume of Comm.
- Reduce Frequency of Comm.
- Memory Efficiency
- Mix-Precision Training
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index d6a6aec63a12..48ae54c1fe54 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -164,7 +164,7 @@ def _gen_token_action(self, model: Module):
self.timestamps[self.mb_manager.idx].append(time.time())
assert (
"logits" in logits
- ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+ ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
@@ -401,7 +401,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
self.timestamps[self.mb_manager.idx].append(time.time())
assert (
"logits" in logits
- ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+ ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 3e5cc6015adc..1e22d9094eae 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -25,6 +25,7 @@
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
+from ..layer._operation import gather_forward_split_backward
class GPT2PipelineForwards:
@@ -337,6 +338,9 @@ def gpt2_lmhead_model_forward(
else:
loss = loss_fct(shift_logits, shift_labels)
+ if not shard_config.parallel_output:
+ lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -793,11 +797,12 @@ def forward(
scale = scale * (1 / float(self.layer_idx + 1))
# use coloattention
- attention = ColoAttention(
- embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
- )
+ if not hasattr(self, "attention"):
+ self.attention = ColoAttention(
+ embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+ )
- attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+ attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -1083,6 +1088,9 @@ def forward(
else:
loss = loss_fct(shift_logits, shift_labels)
+ if not shard_config.parallel_output:
+ lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f20ceb2d6760..eb8e9f748527 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -16,7 +16,7 @@
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ def llama_for_causal_lm_forward(
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
- logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+ logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -485,8 +485,9 @@ def forward(
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
- attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
- attn_output = attention(
+ if not hasattr(self, "attention"):
+ self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+ attn_output = self.attention(
query_states,
key_states,
value_states,
@@ -593,7 +594,7 @@ def forward(
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
- logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+ logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 1d2b7a570681..9a49b1ba6a14 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -242,4 +242,4 @@ def get_stage_index(
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])
- return stage_indices[0] if num_model_chunks == 1 else stage_indices
+ return stage_indices[0] if num_model_chunks == 1 else stage_indices
\ No newline at end of file
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 415fc6dd5f06..da27341d9c29 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -34,8 +34,10 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
- parallel_output = True
+ parallel_output: bool = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+ # TODO padding vocab
+ # make_vocab_size_divisible_by: int = 128
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 110e1a502b0f..93045ea6adc6 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b)
@@ -39,7 +40,7 @@
-
Colossal-AI 成功案例
- - Open-Sora:开源Sora复现方案,成本降低46%,序列扩充至近百万
+ - Open-Sora:全面开源类Sora模型参数和所有训练细节
- Colossal-LLaMA-2: 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
- ColossalChat:完整RLHF流程0门槛克隆ChatGPT
- AIGC: 加速 Stable Diffusion
@@ -121,17 +122,17 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
## Colossal-AI 成功案例
### Open-Sora
-[Open-Sora](https://github.com/hpcaitech/Open-Sora):开源Sora复现方案,成本降低46%,序列扩充至近百万
+[Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节
[[代码]](https://github.com/hpcaitech/Open-Sora)
-[[博客]](https://hpc-ai.com/blog/open-sora)
+[[博客]](https://hpc-ai.com/blog/open-sora-v1.0)
+[[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora)
+[[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
-
-
-
-
-
-
-
+
### Colossal-LLaMA-2
diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
index 614fe510f20e..6c80f3229ce3 100644
--- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
+++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py
@@ -338,7 +338,7 @@ def count_flops_attn(model, _x, y):
class QKVAttentionLegacy(nn.Module):
"""
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md
new file mode 100644
index 000000000000..c523f941262d
--- /dev/null
+++ b/examples/language/grok-1/README.md
@@ -0,0 +1,43 @@
+# Grok-1 Inference
+
+## Install
+
+```bash
+# Make sure you install colossalai from the latest source code
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+pip install .
+cd examples/language/grok-1
+pip install -r requirements.txt
+```
+
+## Tokenizer preparation
+
+You should download the tokenizer from the official grok-1 repository.
+
+```bash
+wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
+```
+
+## Inference
+
+You need 8x A100 80GB or equivalent GPUs to run the inference.
+
+We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, and it is faster. `run_inference_slow.sh` uses auto device provided by transformers, and it is slower.
+
+Command format:
+
+```bash
+./run_inference_fast.sh
+./run_inference_slow.sh
+```
+
+`model_name_or_path` can be a local path or a model name from Hugging Face model hub. We provided weights on model hub, named `hpcaitech/grok-1`.
+
+Command example:
+
+```bash
+./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model
+```
+
+It will take 5-10 minutes to load checkpoints. Don't worry, it's not stuck.
diff --git a/examples/language/grok-1/grok1_policy.py b/examples/language/grok-1/grok1_policy.py
new file mode 100644
index 000000000000..aefea6f3df1c
--- /dev/null
+++ b/examples/language/grok-1/grok1_policy.py
@@ -0,0 +1,99 @@
+from typing import Dict, Union
+
+import torch.nn as nn
+
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+
+class Grok1Policy(Policy):
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self) -> nn.Module:
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+ if self.shard_config.enable_tensor_parallelism:
+ decoder_attribute_replacement = {
+ "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
+ decoder_submodule_replacement = [
+ SubModuleReplacementDescription(
+ suffix="attn.q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.o_proj",
+ target_module=Linear1D_Row,
+ ),
+ ]
+ for i in range(self.model.config.num_experts):
+ decoder_submodule_replacement.extend(
+ [
+ SubModuleReplacementDescription(
+ suffix=f"moe_block.experts[{i}].linear",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix=f"moe_block.experts[{i}].linear_v",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix=f"moe_block.experts[{i}].linear_1",
+ target_module=Linear1D_Row,
+ ),
+ ]
+ )
+
+ policy["DecoderLayer"] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=decoder_submodule_replacement,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=VocabParallelEmbedding1D,
+ ),
+ policy=policy,
+ target_key="Grok1Model",
+ )
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+
+class Grok1ModelPolicy(Grok1Policy):
+ pass
+
+
+class Grok1ForCausalLMPolicy(Grok1Policy):
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = super().module_policy()
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs={"gather_output": not self.shard_config.parallel_output},
+ ),
+ policy=policy,
+ target_key="Grok1ModelForCausalLM",
+ )
+ return policy
diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py
new file mode 100644
index 000000000000..ca0ad0d4fe95
--- /dev/null
+++ b/examples/language/grok-1/inference.py
@@ -0,0 +1,32 @@
+import time
+
+import torch
+from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM
+from utils import get_defualt_parser, inference, print_output
+
+if __name__ == "__main__":
+ parser = get_defualt_parser()
+ args = parser.parse_args()
+ start = time.time()
+ torch.set_default_dtype(torch.bfloat16)
+ model = AutoModelForCausalLM.from_pretrained(
+ args.pretrained,
+ trust_remote_code=True,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+ sp = SentencePieceProcessor(model_file=args.tokenizer)
+ for text in args.text:
+ output = inference(
+ model,
+ sp,
+ text,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ )
+ print_output(text, sp.decode(output))
+ print(f"Overall time: {time.time() - start} seconds.")
diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py
new file mode 100644
index 000000000000..99de60e1f6be
--- /dev/null
+++ b/examples/language/grok-1/inference_tp.py
@@ -0,0 +1,50 @@
+import time
+
+import torch
+from grok1_policy import Grok1ForCausalLMPolicy
+from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM
+from utils import get_defualt_parser, inference, print_output
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.utils import get_current_device
+
+if __name__ == "__main__":
+ parser = get_defualt_parser()
+ args = parser.parse_args()
+ start = time.time()
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+ plugin = HybridParallelPlugin(
+ tp_size=coordinator.world_size,
+ pp_size=1,
+ precision="bf16",
+ parallel_output=False,
+ custom_policy=Grok1ForCausalLMPolicy(),
+ )
+ booster = Booster(plugin=plugin)
+ torch.set_default_dtype(torch.bfloat16)
+ with LazyInitContext(default_device=get_current_device()):
+ model = AutoModelForCausalLM.from_pretrained(
+ args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
+ )
+ model, *_ = booster.boost(model)
+ sp = SentencePieceProcessor(model_file=args.tokenizer)
+ for text in args.text:
+ output = inference(
+ model.unwrap(),
+ sp,
+ text,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ )
+ if coordinator.is_master():
+ print_output(text, sp.decode(output))
+ coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")
diff --git a/examples/language/grok-1/requirements.txt b/examples/language/grok-1/requirements.txt
new file mode 100644
index 000000000000..15d5ea53a15e
--- /dev/null
+++ b/examples/language/grok-1/requirements.txt
@@ -0,0 +1,4 @@
+torch>=2.1.0,<2.2.0
+colossalai>=0.3.6
+sentencepiece==0.1.99
+transformers==4.35.0
diff --git a/examples/language/grok-1/run_inference_fast.sh b/examples/language/grok-1/run_inference_fast.sh
new file mode 100755
index 000000000000..0dc398c53e33
--- /dev/null
+++ b/examples/language/grok-1/run_inference_fast.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+PRETRAINED=${1:-"hpcaitech/grok-1"}
+TOKENIZER=${2:-"tokenizer.model"}
+
+torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \
+ --tokenizer "$TOKENIZER" \
+ --max_new_tokens 64 \
+ --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
+ "将以下句子翻译成英语。 我喜欢看电影和读书。" \
+ "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"
diff --git a/examples/language/grok-1/run_inference_slow.sh b/examples/language/grok-1/run_inference_slow.sh
new file mode 100755
index 000000000000..c64dd93b9e62
--- /dev/null
+++ b/examples/language/grok-1/run_inference_slow.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+PRETRAINED=${1:-"hpcaitech/grok-1"}
+TOKENIZER=${2:-"tokenizer.model"}
+
+python3 inference.py --pretrained "$PRETRAINED" \
+ --tokenizer "$TOKENIZER" \
+ --max_new_tokens 64 \
+ --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
+ "将以下句子翻译成英语。 我喜欢看电影和读书。" \
+ "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"
diff --git a/examples/language/grok-1/test_ci.sh b/examples/language/grok-1/test_ci.sh
new file mode 100755
index 000000000000..f6a0d658462b
--- /dev/null
+++ b/examples/language/grok-1/test_ci.sh
@@ -0,0 +1 @@
+pip install -r requirements.txt
diff --git a/examples/language/grok-1/utils.py b/examples/language/grok-1/utils.py
new file mode 100644
index 000000000000..f113f852eff6
--- /dev/null
+++ b/examples/language/grok-1/utils.py
@@ -0,0 +1,46 @@
+import argparse
+
+import torch
+
+
+class Bcolors:
+ HEADER = "\033[95m"
+ OKBLUE = "\033[94m"
+ OKCYAN = "\033[96m"
+ OKGREEN = "\033[92m"
+ WARNING = "\033[93m"
+ FAIL = "\033[91m"
+ ENDC = "\033[0m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+
+
+def print_output(text, output):
+ print(f"-----\n{Bcolors.OKBLUE}{text}{Bcolors.ENDC}{output[len(text):]}")
+
+
+@torch.no_grad()
+def inference(model, sp, text, **generate_kwargs):
+ input_ids = sp.encode(text)
+ input_ids = torch.tensor([input_ids]).cuda()
+ attention_mask = torch.ones_like(input_ids)
+ inputs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ **generate_kwargs,
+ }
+ outputs = model.generate(**inputs)
+ return outputs[0].tolist()
+
+
+def get_defualt_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
+ parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
+ parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"])
+ parser.add_argument("--max_new_tokens", type=int, default=30)
+ parser.add_argument("--do_sample", action="store_true", default=False)
+ parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
+ parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
+ parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
+ return parser
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 285c4866c441..38361d803c49 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
- assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+ assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
def run_dist(rank, world_size, port, early_stop: bool = True):
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 0f72d2bcd3e4..89214477239b 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -1,7 +1,6 @@
from contextlib import nullcontext
from typing import Optional
-import pytest
import torch
import torch.distributed as dist
@@ -12,13 +11,7 @@
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import (
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- skip_if_not_enough_gpus,
- spawn,
-)
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@@ -177,12 +170,5 @@ def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
-@pytest.mark.largedist
-@skip_if_not_enough_gpus(8)
-@rerun_if_address_is_in_use()
-def test_gemini_plugin_3d(early_stop: bool = True):
- spawn(run_dist, 8, early_stop=early_stop)
-
-
if __name__ == "__main__":
- test_gemini_plugin(early_stop=False)
\ No newline at end of file
+ test_gemini_plugin(early_stop=False)
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index daddf6dc7ca0..ece3b40360e8 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -16,7 +16,6 @@
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
- skip_if_not_enough_gpus,
spawn,
)
from tests.kit.model_zoo import model_zoo
@@ -178,12 +177,5 @@ def test_gemini_ckpIO():
spawn(run_dist, 4)
-@pytest.mark.largedist
-@skip_if_not_enough_gpus(min_gpus=8)
-@rerun_if_address_is_in_use()
-def test_gemini_ckpIO_3d():
- spawn(run_dist, 8)
-
-
if __name__ == "__main__":
- test_gemini_ckpIO()
\ No newline at end of file
+ test_gemini_ckpIO()
diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py
index 9630451799c0..5e2efcd80367 100644
--- a/tests/test_shardformer/test_model/test_shard_falcon.py
+++ b/tests/test_shardformer/test_model/test_shard_falcon.py
@@ -1,5 +1,6 @@
import pytest
import torch
+import torch.distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
@@ -72,6 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
+ if dist.get_world_size() > 4:
+ atol, rtol = 4e-4, 3e-2
else:
atol, rtol = 5e-3, 5e-3
check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
diff --git a/version.txt b/version.txt
index c2c0004f0e2a..449d7e73a966 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.5
+0.3.6