diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 3dcc4dfd182a..1778d64ee287 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -72,7 +72,7 @@ jobs: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - name: Download cub for CUDA 10.2 run: | - CUDA_VERSION=$(cat $CUDA_HOME/version.txt | grep "CUDA Version" | awk '{print $NF}' | cut -d. -f1,2) + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') # check if it is CUDA 10.2 # download cub diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 5098b8e364d0..c0f45c65a7fc 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -66,7 +66,7 @@ jobs: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - name: Download cub for CUDA 10.2 run: | - CUDA_VERSION=$(cat $CUDA_HOME/version.txt | grep "CUDA Version" | awk '{print $NF}' | cut -d. -f1,2) + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') # check if it is CUDA 10.2 # download cub diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 9802795fad24..15ac4f1a92bb 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -61,6 +61,18 @@ jobs: with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi + - name: Install Colossal-AI run: | pip install -v --no-cache-dir . diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml index eba5bb98ec07..686f0f395c73 100644 --- a/.github/workflows/cuda_ext_check_before_merge.yml +++ b/.github/workflows/cuda_ext_check_before_merge.yml @@ -37,6 +37,18 @@ jobs: - name: Install PyTorch run: eval ${{ matrix.build.torch_command }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi + - name: Build run: | CUDA_EXT=1 pip install -v . 
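Review note on the workflow hunks above: newer CUDA images no longer ship $CUDA_HOME/version.txt, so the compatibility and CUDA-extension workflows now read the toolkit version from `nvcc -V`, and (in the schedule and cuda-ext-check workflows) gain a step that fetches cub 1.8.0 into the kernel include directory only when CUDA 10.2 is detected, since 10.2 does not bundle cub. A minimal Python sketch of the same detection logic, useful for checking the parsing locally; the nvcc output line and the cub URL come from the diff, while the helper itself is purely illustrative and not part of the PR:

import re
import subprocess

def detect_cuda_version() -> str:
    """Return the CUDA release reported by nvcc (e.g. '10.2'), or '' if nvcc is unavailable."""
    try:
        output = subprocess.run(["nvcc", "-V"], capture_output=True, text=True, check=True).stdout
    except (OSError, subprocess.CalledProcessError):
        return ""
    # nvcc prints a line such as "Cuda compilation tools, release 10.2, V10.2.89";
    # the workflow extracts the same field with: awk -F ',| ' '/release/{print $6}'
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else ""

if __name__ == "__main__":
    if detect_cuda_version() == "10.2":
        # CUDA 10.2 has no bundled cub, so the workflow downloads
        # https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip and copies cub/ into
        # colossalai/kernel/cuda_native/csrc/kernels/include/ before building.
        print("CUDA 10.2 detected: cub 1.8.0 is required")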
diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index 510f6b6f0985..650689498fda 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -43,7 +43,9 @@ jobs: run: | cd applications/Chat rm -rf ~/.cache/colossalai - ./examples/test_ci.sh + ./tests/test_inference.sh + ./tests/test_benchmarks.sh + ./tests/test_train.sh env: NCCL_SHM_DISABLE: 1 MAX_JOBS: 8 diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py index f650668e90b0..bd4e5460d11e 100644 --- a/applications/Chat/coati/dataset/__init__.py +++ b/applications/Chat/coati/dataset/__init__.py @@ -1,9 +1,10 @@ from .prompt_dataset import PromptDataset from .reward_dataset import HhRlhfDataset, RmStaticDataset -from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from .sft_dataset import SFTDataset, SupervisedDataset from .utils import is_rank_0 __all__ = [ - 'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset', - 'DataCollatorForSupervisedDataset', 'PromptDataset' + 'RmStaticDataset', 'HhRlhfDataset', + 'SFTDataset', 'SupervisedDataset', + 'PromptDataset', 'is_rank_0', ] diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py index 0bdcbbc5928e..2c953fffa513 100644 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -1,20 +1,13 @@ -import copy -import random from collections import defaultdict -from dataclasses import dataclass, field -from typing import Callable, Dict, Sequence +from typing import Dict import torch -import torch.distributed as dist import transformers from torch.utils.data import Dataset -from tqdm import tqdm from colossalai.logging import get_dist_logger -from .utils import is_rank_0, jload - -logger = get_dist_logger() +from .utils import jload class PromptDataset(Dataset): @@ -27,12 +20,13 @@ def __init__(self, max_length: int = 96): super(PromptDataset, self).__init__() self.keyed_prompt = defaultdict(list) - logger.info("Loading data...") + self.logger = get_dist_logger() + self.logger.info("Loading data...") list_data_dict = jload(data_path) - logger.info(f"Loaded {len(list_data_dict)} examples.") + self.logger.info(f"Loaded {len(list_data_dict)} examples.") if max_datasets_size is not None: - logger.info(f"Limiting dataset to {max_datasets_size} examples.") + self.logger.info(f"Limiting dataset to {max_datasets_size} examples.") list_data_dict = list_data_dict[:max_datasets_size] instructions = [data_dict["instruction"] for data_dict in list_data_dict] diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py index 5dacf7e81464..3c4ec8b214bb 100644 --- a/applications/Chat/coati/dataset/reward_dataset.py +++ b/applications/Chat/coati/dataset/reward_dataset.py @@ -20,44 +20,44 @@ class RmStaticDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.chosen = [] - self.reject = [] - if special_token is None: - self.end_token = tokenizer.eos_token - else: - self.end_token = special_token - for data in tqdm(dataset, disable=not is_rank_0()): - prompt = data['prompt'] - - chosen = prompt + data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - 
self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = prompt + data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + self.end_token = tokenizer.eos_token \ + if special_token is None else special_token + + chosen = [ + data["prompt"] + data["chosen"] + self.end_token + for data in tqdm(dataset, disable=not is_rank_0()) + ] + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen = { + "input_ids": chosen_token["input_ids"], + "attention_mask": chosen_token["attention_mask"] + } + + reject = [ + data["prompt"] + data["rejected"] + self.end_token + for data in tqdm(dataset, disable=not is_rank_0()) + ] + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject = { + "input_ids": reject_token["input_ids"], + "attention_mask": reject_token["attention_mask"] + } def __len__(self): - length = len(self.chosen) + length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] + return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ + self.reject["input_ids"][idx], self.reject["attention_mask"][idx] # Anthropic/hh-rlhf @@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.chosen = [] - self.reject = [] - if special_token is None: - self.end_token = tokenizer.eos_token - else: - self.end_token = special_token - for data in tqdm(dataset, disable=not is_rank_0()): - chosen = data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + self.end_token = tokenizer.eos_token \ + if special_token is None else special_token + + chosen = [ + data["chosen"] + self.end_token + for data in tqdm(dataset, disable=not is_rank_0()) + ] + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen = { + "input_ids": chosen_token["input_ids"], + "attention_mask": chosen_token["attention_mask"] + } + + reject = [ + data["rejected"] + self.end_token + for data in tqdm(dataset, disable=not is_rank_0()) + ] + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject = { + "input_ids": reject_token["input_ids"], + "attention_mask": reject_token["attention_mask"] + } def __len__(self): - length = len(self.chosen) + length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - 
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] + return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ + self.reject["input_ids"][idx], self.reject["attention_mask"][idx] diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index 0b04cf79ee54..636b4e6772cb 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -13,44 +13,64 @@ # limitations under the License. import copy -import random -from dataclasses import dataclass, field -from typing import Callable, Dict, List, Sequence, Tuple +from typing import Dict, Sequence, Tuple import torch -import torch.distributed as dist -import transformers from torch.utils.data import Dataset from tqdm import tqdm +from transformers import PreTrainedTokenizer from colossalai.logging import get_dist_logger -from .conversation import default_conversation from .utils import is_rank_0, jload -# The following is a template prompt for a 4-round conversation. -""" -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - -Human: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxx -""" -# Please note that we only calculate loss on assistant's answer tokens. - logger = get_dist_logger() IGNORE_INDEX = -100 -DEFAULT_EOS_TOKEN = "" PROMPT_DICT = { - "prompt_input": - ("Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), + "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), "prompt_no_input": ("Below is an instruction that describes a task. 
" "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response:"), } +def _preprocess(sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Preprocess the data by tokenizing.""" + sequences = [s + t for s, t in zip(sources, targets)] + sequences_token = tokenizer(sequences, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + sources_token = tokenizer(sources, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + + labels = copy.deepcopy(sequences_token["input_ids"]) + for i in range(labels.shape[0]): + source_len = sources_token["attention_mask"][i].sum().item() + pad_len = max_length - sequences_token["attention_mask"][i].sum().item() + if tokenizer.padding_side == "right": + # |prompt|completion|eos|pad| + labels[i][:source_len] = IGNORE_INDEX + elif tokenizer.padding_side == "left": + # |pad|prompt|completion|eos| + labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX + else: + raise RuntimeError() + + return sequences_token["input_ids"], labels, sequences_token["attention_mask"] + + class SFTDataset(Dataset): """ Dataset for sft model @@ -61,115 +81,31 @@ class SFTDataset(Dataset): max_length: max length of input """ - def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None: + def __init__(self, + dataset: Dict, + tokenizer: PreTrainedTokenizer, + max_length: int = 512 + ) -> None: super().__init__() self.input_ids = [] - for data in tqdm(dataset, disable=not is_rank_0()): - prompt = data['prompt'] + data['completion'] + tokenizer.eos_token - prompt_token = tokenizer(prompt, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") + sources = [data["prompt"] for data in dataset] + targets = [ + data["completion"] + tokenizer.eos_token + for data in tqdm(dataset, disable=not is_rank_0()) + ] - self.input_ids.append(prompt_token['input_ids'][0]) - self.labels = copy.deepcopy(self.input_ids) + self.input_ids, self.labels, self.attention_mask = \ + _preprocess(sources, targets, tokenizer, max_length) def __len__(self): - length = len(self.input_ids) + length = self.input_ids.shape[0] return length def __getitem__(self, idx): - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) - - -def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, - max_length: int) -> Dict[str, torch.Tensor]: - """Tokenize a list of strings.""" - tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True) - input_ids = labels = tokenized_list["input_ids"] - input_ids_lens = labels_lens = \ - tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1) - return dict( - input_ids=input_ids, - labels=labels, - input_ids_lens=input_ids_lens, - labels_lens=labels_lens, - ) - - -def preprocess( - sources: Sequence[str], - targets: Sequence[str], - tokenizer: transformers.PreTrainedTokenizer, - max_length: int, -) -> Dict: - """Preprocess the data by tokenizing.""" - examples = [s + t for s, t in zip(sources, targets)] - examples_tokenized, sources_tokenized = [ - _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources) - ] - input_ids = examples_tokenized["input_ids"] - labels = copy.deepcopy(input_ids) - for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): - 
label[:source_len] = IGNORE_INDEX - return dict(input_ids=input_ids, labels=labels) - - -def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer, - max_length: int) -> Dict: - """Preprocess the conversation data by tokenizing.""" - conversations = [] - intermediates = [] - for source in sources: - header = f"{default_conversation.system}" - conversation, intermediate = _add_speaker_and_signal(header, source) - conversations.append(conversation) - intermediates.append(intermediate) - - conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length) - input_ids = conversations_tokenized["input_ids"] - targets = copy.deepcopy(input_ids) - - assert len(targets) == len(intermediates) - for target, inters in zip(targets, intermediates): - mask = torch.zeros_like(target, dtype=torch.bool) - for inter in inters: - tokenized = _tokenize_fn(inter, tokenizer, max_length) - - start_idx = tokenized["input_ids"][0].size(0) - 1 - end_idx = tokenized["input_ids"][1].size(0) - - mask[start_idx:end_idx] = True - target[~mask] = IGNORE_INDEX - - return dict(input_ids=input_ids, labels=targets) - - -def _add_speaker_and_signal(header: str, - source: List[Dict], - get_conversation: bool = True) -> Tuple[str, List[List[str]]]: - END_SIGNAL = DEFAULT_EOS_TOKEN - conversation = header - intermediate = [] - for sentence in source: - from_str = sentence["from"] - if from_str.lower() == "human": - from_str = default_conversation.roles[0] - elif from_str.lower() == "gpt": - from_str = default_conversation.roles[1] - else: - from_str = 'unknown' - - value = from_str + ": " + sentence["value"] + END_SIGNAL - if sentence["from"].lower() == "gpt": - start = conversation + from_str + ": " - end = conversation + value - intermediate.append([start, end]) - if get_conversation: - conversation += value - return conversation, intermediate + return dict(input_ids=self.input_ids[idx], + labels=self.labels[idx], + attention_mask=self.attention_mask[idx]) class SupervisedDataset(Dataset): @@ -177,10 +113,10 @@ class SupervisedDataset(Dataset): def __init__(self, data_path: str, - tokenizer: transformers.PreTrainedTokenizer, + tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512): - super(SupervisedDataset, self).__init__() + super().__init__() logger.info("Loading data...") list_data_dict = jload(data_path) logger.info(f"Loaded {len(list_data_dict)} examples.") @@ -190,52 +126,25 @@ def __init__(self, list_data_dict = list_data_dict[:max_datasets_size] logger.info("Formatting inputs...") - if "conversations" not in list_data_dict[0]: - prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] - sources = [ - prompt_input.format_map(example) - if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict - ] - targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] - - if is_rank_0(): - logger.info("Tokenizing inputs... This may take some time...") - - data_dict = preprocess(sources, targets, tokenizer, max_length) - else: - if is_rank_0(): - logger.info("Tokenizing inputs... 
This may take some time...") - - sources = [conv["conversations"] for conv in list_data_dict] - data_dict = preprocess_conversation(sources, tokenizer, max_length) - - if is_rank_0(): - logger.info("Tokenizing finish.") - - self.input_ids = data_dict["input_ids"] - self.labels = data_dict["labels"] + prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + sources = [ + prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) + for example in list_data_dict + ] + targets = [ + example['output'] + tokenizer.eos_token + for example in list_data_dict + ] + + logger.info("Tokenizing inputs... This may take some time...") + self.input_ids, self.labels, self.attention_mask = \ + _preprocess(sources, targets, tokenizer, max_length) def __len__(self): - return len(self.input_ids) - - def __getitem__(self, i) -> Dict[str, torch.Tensor]: - return dict(input_ids=self.input_ids[i], labels=self.labels[i]) - - -@dataclass -class DataCollatorForSupervisedDataset(object): - """Collate examples for supervised fine-tuning.""" - - tokenizer: transformers.PreTrainedTokenizer + length = self.input_ids.shape[0] + return length - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) - input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) - labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) - return dict( - input_ids=input_ids, - labels=labels, - attention_mask=input_ids.ne(self.tokenizer.pad_token_id), - ) + def __getitem__(self, idx): + return dict(input_ids=self.input_ids[idx], + labels=self.labels[idx], + attention_mask=self.attention_mask[idx]) diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py new file mode 100644 index 000000000000..c0188dc4a471 --- /dev/null +++ b/applications/Chat/coati/experience_buffer/__init__.py @@ -0,0 +1,4 @@ +from .base import ExperienceBuffer +from .naive import NaiveExperienceBuffer + +__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer'] diff --git a/applications/Chat/coati/replay_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py similarity index 91% rename from applications/Chat/coati/replay_buffer/base.py rename to applications/Chat/coati/experience_buffer/base.py index 4c3812461a10..9ccdc935d506 100644 --- a/applications/Chat/coati/replay_buffer/base.py +++ b/applications/Chat/coati/experience_buffer/base.py @@ -4,8 +4,8 @@ from coati.experience_maker.base import Experience -class ReplayBuffer(ABC): - """Replay buffer base class. It stores experience. +class ExperienceBuffer(ABC): + """Experience buffer base class. It stores experience. Args: sample_batch_size (int): Batch size when sampling. 
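Review note on the module move above: `coati.replay_buffer` becomes `coati.experience_buffer`, and `ReplayBuffer`/`NaiveReplayBuffer` are renamed to `ExperienceBuffer`/`NaiveExperienceBuffer`; only the names and import path change. A minimal usage sketch under that assumption — the three constructor arguments mirror the positional call `NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)` visible later in this diff, and their exact semantics are assumed from the pre-rename buffer, not stated by the PR:

from coati.experience_buffer import NaiveExperienceBuffer

# sample batch size, capacity limit (assumed: <= 0 means unlimited), CPU offload flag
data_buffer = NaiveExperienceBuffer(8, 0, True)

# collect phase: the trainer appends each rollout; the update phase reads batches back
# through strategy.setup_dataloader(data_buffer, ...) and clears the buffer afterwards,
# as shown in the trainer/base.py hunks later in this diff.
# data_buffer.append(experience)
# data_buffer.clear()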
diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py similarity index 92% rename from applications/Chat/coati/replay_buffer/naive.py rename to applications/Chat/coati/experience_buffer/naive.py index 938f500643c9..bd5213b38993 100644 --- a/applications/Chat/coati/replay_buffer/naive.py +++ b/applications/Chat/coati/experience_buffer/naive.py @@ -4,12 +4,12 @@ import torch from coati.experience_maker.base import Experience -from .base import ReplayBuffer +from .base import ExperienceBuffer from .utils import BufferItem, make_experience_batch, split_experience_batch -class NaiveReplayBuffer(ReplayBuffer): - """Naive replay buffer class. It stores experience. +class NaiveExperienceBuffer(ExperienceBuffer): + """Naive experience buffer class. It stores experience. Args: sample_batch_size (int): Batch size when sampling. diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py similarity index 83% rename from applications/Chat/coati/replay_buffer/utils.py rename to applications/Chat/coati/experience_buffer/utils.py index 6ad0db2c3b60..c2a34212e2f4 100644 --- a/applications/Chat/coati/replay_buffer/utils.py +++ b/applications/Chat/coati/experience_buffer/utils.py @@ -33,7 +33,8 @@ class BufferItem: def split_experience_batch(experience: Experience) -> List[BufferItem]: batch_size = experience.sequences.size(0) batch_kwargs = [{} for _ in range(batch_size)] - keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + keys = ('sequences', 'action_log_probs', 'values', + 'reward', 'advantages', 'attention_mask', 'action_mask') for key in keys: value = getattr(experience, key) if isinstance(value, torch.Tensor): @@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]: return items -def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: +def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: assert side in ('left', 'right') max_len = max(seq.size(0) for seq in sequences) padded_sequences = [] @@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor def make_experience_batch(items: List[BufferItem]) -> Experience: kwargs = {} to_pad_keys = set(('action_log_probs', 'action_mask')) - keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + keys = ('sequences', 'action_log_probs', 'values', + 'reward', 'advantages', 'attention_mask', 'action_mask') for key in keys: vals = [getattr(item, key) for item in items] if key in to_pad_keys: - batch_data = zero_pad_sequences(vals) + batch_data = _zero_pad_sequences(vals) else: batch_data = torch.stack(vals, dim=0) kwargs[key] = batch_data diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index e5bb029e63d0..496f8ab445fc 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -1,6 +1,7 @@ import torch -from coati.models.generation import generate_with_actor -from coati.models.utils import calc_action_log_probs, compute_reward, normalize +import torch.nn.functional as F +from coati.models.generation import generate +from coati.models.utils import calc_action_log_probs, compute_reward from .base import Experience, ExperienceMaker @@ -17,10 +18,27 @@ def make_experience(self, 
input_ids: torch.Tensor, **generate_kwargs) -> Experie self.initial_model.eval() self.reward_model.eval() - sequences, attention_mask, action_mask = generate_with_actor(self.actor, - input_ids, - return_action_mask=True, - **generate_kwargs) + # generate sequences + sequences = generate(self.actor, input_ids, **generate_kwargs) + + # calculate auxiliary tensors + attention_mask = None + pad_token_id = generate_kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id)\ + .to(dtype=torch.long, device=sequences.device) + + input_len = input_ids.size(1) + eos_token_id = generate_kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + action_mask = action_mask[:, -(sequences.size(1) - input_len):] num_actions = action_mask.size(1) actor_output = self.actor(sequences, attention_mask) diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py index 709bc5ac0948..0a296a863756 100644 --- a/applications/Chat/coati/models/__init__.py +++ b/applications/Chat/coati/models/__init__.py @@ -1,8 +1,8 @@ from .base import Actor, Critic, RewardModel from .lora import LoRAModule, convert_to_lora_module -from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss +from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss __all__ = [ - 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', + 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss', 'LoRAModule', 'convert_to_lora_module' ] diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py index a32fb2e102f9..a3716ca94138 100644 --- a/applications/Chat/coati/models/bloom/bloom_critic.py +++ b/applications/Chat/coati/models/bloom/bloom_critic.py @@ -14,7 +14,6 @@ class BLOOMCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (BloomConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ @@ -22,7 +21,6 @@ class BLOOMCritic(Critic): def __init__(self, pretrained: str = None, config: Optional[BloomConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none', **kwargs) -> None: @@ -32,7 +30,6 @@ def __init__(self, model = BloomModel(config) else: model = BloomModel(BloomConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py index 22cfab441abb..e6ca9b1d4851 100644 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -13,7 +13,6 @@ class BLOOMRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (BloomConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. 
lora_train_bias (str): LoRA bias training mode. """ @@ -21,7 +20,6 @@ class BLOOMRM(RewardModel): def __init__(self, pretrained: str = None, config: Optional[BloomConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: if pretrained is not None: @@ -30,8 +28,7 @@ def __init__(self, model = BloomModel(config) else: model = BloomModel(BloomConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index d96ad78a89ce..de0d63f95f50 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional import torch import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F + +from .base import Actor try: from transformers.generation_logits_process import ( @@ -16,9 +16,9 @@ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def _prepare_logits_processor(top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def sample(model: nn.Module, - input_ids: torch.Tensor, - max_length: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def _sample(model: Actor, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: if input_ids.size(1) >= max_length: return input_ids - logits_processor = prepare_logits_processor(top_k, top_p, temperature) + logits_processor = _prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(input_ids.size(1), max_length): @@ -89,7 +89,8 @@ def sample(model: nn.Module, return input_ids -def generate(model: nn.Module, +@torch.no_grad() +def generate(model: Actor, input_ids: torch.Tensor, max_length: int, num_beams: int = 1, @@ -128,51 +129,19 @@ def generate(model: nn.Module, raise NotImplementedError elif is_sample_gen_mode: # run sample - return sample(model, - input_ids, - max_length, - 
early_stopping=early_stopping, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - top_k=top_k, - top_p=top_p, - temperature=temperature, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs) + return _sample(model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs) elif is_beam_gen_mode: raise NotImplementedError else: raise ValueError("Unsupported generation mode") - - -@torch.no_grad() -def generate_with_actor( - actor_model: nn.Module, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs -) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: - """Generate token sequence with actor model. Refer to `generate` for more details. - """ - # generate sequences - sequences = generate(actor_model, input_ids, **kwargs) - - # calculate auxiliary tensors - attention_mask = None - pad_token_id = kwargs.get('pad_token_id', None) - if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) - if not return_action_mask: - return sequences, attention_mask, None - input_len = input_ids.size(1) - eos_token_id = kwargs.get('eos_token_id', None) - if eos_token_id is None: - action_mask = torch.ones_like(sequences, dtype=torch.bool) - else: - # left padding may be applied, only mask action - action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input - action_mask[:, :input_len] = False - action_mask = action_mask[:, 1:] - return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py index 2e70f5f1fc96..01e1cd10ef57 100644 --- a/applications/Chat/coati/models/gpt/gpt_critic.py +++ b/applications/Chat/coati/models/gpt/gpt_critic.py @@ -14,7 +14,6 @@ class GPTCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (GPT2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the LO-RA decomposition. lora_train_bias (str): LoRA bias training mode. """ @@ -22,7 +21,6 @@ class GPTCritic(Critic): def __init__(self, pretrained: Optional[str] = None, config: Optional[GPT2Config] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none', **kwargs) -> None: @@ -32,7 +30,6 @@ def __init__(self, model = GPT2Model(config) else: model = GPT2Model(GPT2Config()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.n_embd, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py index 054432e1ce86..e52a5a14c1da 100644 --- a/applications/Chat/coati/models/gpt/gpt_rm.py +++ b/applications/Chat/coati/models/gpt/gpt_rm.py @@ -14,7 +14,6 @@ class GPTRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (GPT2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. 
lora_train_bias (str): LoRA bias training mode. """ @@ -22,7 +21,6 @@ class GPTRM(RewardModel): def __init__(self, pretrained: Optional[str] = None, config: Optional[GPT2Config] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: if pretrained is not None: @@ -31,8 +29,6 @@ def __init__(self, model = GPT2Model(config) else: model = GPT2Model(GPT2Config()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.n_embd, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index dd9e5e7bfa1a..a67e5de5def6 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -13,7 +13,6 @@ class LlamaCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (LlamaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ @@ -21,7 +20,6 @@ class LlamaCritic(Critic): def __init__(self, pretrained: Optional[str] = None, config: Optional[LlamaConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none', **kwargs) -> None: @@ -33,9 +31,5 @@ def __init__(self, else: model = LlamaModel(LlamaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py index f936019d62d2..d6b62922686e 100644 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -13,7 +13,6 @@ class LlamaRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (LlamaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. 
""" @@ -21,7 +20,6 @@ class LlamaRM(RewardModel): def __init__(self, pretrained: Optional[str] = None, config: Optional[LlamaConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: @@ -32,8 +30,6 @@ def __init__(self, else: model = LlamaModel(LlamaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index 2a9059e6901e..546f675d7d37 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -98,18 +98,18 @@ def T(w): return F.linear(x, T(self.weight), bias=self.bias) -def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: +def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) return lora_linear -def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: +def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: for name, child in module.named_children(): if isinstance(child, nn.Linear): - setattr(module, name, lora_linear_wrapper(child, lora_rank)) + setattr(module, name, _lora_linear_wrapper(child, lora_rank)) else: - convert_to_lora_recursively(child, lora_rank) + _convert_to_lora_recursively(child, lora_rank) def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: @@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s """ if lora_rank <= 0: return module - convert_to_lora_recursively(module, lora_rank) + _convert_to_lora_recursively(module, lora_rank) lora.mark_only_lora_as_trainable(module, lora_train_bias) return module diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 926c6e2a4e41..05a0b4821797 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -68,31 +68,6 @@ def forward(self, return 0.5 * loss -class PPOPtxActorLoss(nn.Module): - """ - To Do: - - PPO-ptx Actor Loss - """ - - def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None: - super().__init__() - self.pretrain_coef = pretrain_coef - self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps) - self.pretrain_loss_fn = pretrain_loss_fn - - def forward(self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - lm_logits: torch.Tensor, - lm_input_ids: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask) - lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids) - return policy_loss + self.pretrain_coef * lm_loss - - class LogSigLoss(nn.Module): """ Pairwise Loss for Reward Model diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py index fcfebd8a8b03..f66c4173fa52 100644 --- a/applications/Chat/coati/models/opt/opt_critic.py +++ b/applications/Chat/coati/models/opt/opt_critic.py @@ -14,7 +14,6 @@ class OPTCritic(Critic): Args: pretrained (str): Pretrained 
model name or path. config (OPTConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. lora_train_bias (str): LoRA bias training mode. """ @@ -22,7 +21,6 @@ class OPTCritic(Critic): def __init__(self, pretrained: Optional[str] = None, config: Optional[OPTConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none', **kwargs) -> None: @@ -32,7 +30,6 @@ def __init__(self, model = OPTModel(config) else: model = OPTModel(OPTConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.word_embed_proj_dim, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py index 50fc0dee8568..6f75344e6aae 100644 --- a/applications/Chat/coati/models/opt/opt_rm.py +++ b/applications/Chat/coati/models/opt/opt_rm.py @@ -13,7 +13,6 @@ class OPTRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (OPTConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. lora_train_bias (str): LoRA bias training mode. """ @@ -21,7 +20,6 @@ class OPTRM(RewardModel): def __init__(self, pretrained: Optional[str] = None, config: Optional[OPTConfig] = None, - checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: if pretrained is not None: @@ -30,8 +28,6 @@ def __init__(self, model = OPTModel(config) else: model = OPTModel(OPTConfig()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.word_embed_proj_dim, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1)) diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 772bfc32982a..97637d3523b0 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -1,14 +1,12 @@ from typing import Optional, Union -import loralib as lora import torch -import torch.nn as nn import torch.nn.functional as F -def compute_approx_kl(log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def _compute_approx_kl(log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Compute the approximate KL divergence between two distributions. Schulman blog: http://joschu.net/blog/kl-approx.html @@ -19,7 +17,7 @@ def compute_approx_kl(log_probs: torch.Tensor, action_mask: Mask for actions. 
""" - log_ratio = log_probs - log_probs_base + log_ratio = log_probs_base - log_probs approx_kl = (log_ratio.exp() - 1) - log_ratio if action_mask is not None: approx_kl = masked_mean(approx_kl, action_mask, dim=1) @@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float], action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if kl_coef <= 0.0: return r - kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) + kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) reward = r - kl_coef * kl return reward -def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: +def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: log_probs = F.log_softmax(logits, dim=-1) log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) return log_probs_labels.squeeze(-1) @@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num torch.Tensor: Action log probs. """ logits = output['logits'] - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] @@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean - - -def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: - tensor = tensor * mask - mean = masked_mean(tensor, mask, dim=dim) - mean_centered = tensor - mean - var = masked_mean(mean_centered**2, mask, dim=dim) - return mean_centered * var.clamp(min=eps).rsqrt() - - -def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor: - mean = tensor.mean(dim) - mean_centered = tensor - mean - var = (mean_centered**2).mean(dim) - norm = mean_centered * var.clamp(min=eps).rsqrt() - return norm - - -def convert_to_lora(model: nn.Module, - input_size: int, - output_size: int, - lora_rank: int = 16, - lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, - merge_weights: bool = True): - if lora_rank > min(input_size, output_size): - raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}") - - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - module._modules[name] = lora.Linear(input_size, - output_size, - r=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - fan_in_fan_out=fan_in_fan_out, - merge_weights=merge_weights) diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py index cd3517609e7a..d3df8f9ae3e0 100644 --- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py @@ -115,12 +115,12 @@ def on_loop_end(self) -> None: avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' + - f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' + - f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + - f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' - + - f'Sample time (send): {avg_send_time_per_sample:.3f} s, 
{avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n' + 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' + + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' + + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' + + + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n' ) @@ -204,9 +204,9 @@ def on_fit_end(self) -> None: avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' + - f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + - f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' - + - f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n' + 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' + + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' + + + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n' ) diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py index 2f765281178a..7b9df2ee139b 100644 --- a/applications/Chat/coati/ray/detached_replay_buffer.py +++ b/applications/Chat/coati/ray/detached_replay_buffer.py @@ -6,9 +6,9 @@ import ray import torch +from coati.experience_buffer import ExperienceBuffer +from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_maker.base import Experience -from coati.replay_buffer import ReplayBuffer -from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch # from torch.multiprocessing import Queue from ray.util.queue import Queue diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py index ac2d35e9da19..90399781187a 100644 --- a/applications/Chat/coati/ray/detached_trainer_base.py +++ b/applications/Chat/coati/ray/detached_trainer_base.py @@ -4,8 +4,8 @@ import ray import torch +from coati.experience_buffer.utils import BufferItem from coati.experience_maker import Experience -from coati.replay_buffer.utils import BufferItem from torch.utils.data import DataLoader from tqdm import tqdm diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py index 07d9c3e4f396..13314bdafd5f 100644 --- a/applications/Chat/coati/ray/experience_maker_holder.py +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -8,9 +8,9 @@ import ray import torch import torch.nn as nn +from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker from coati.models.base import Actor, Critic, RewardModel -from coati.replay_buffer.utils import BufferItem, make_experience_batch, 
split_experience_batch from coati.trainer.callbacks import Callback from coati.trainer.strategies import Strategy from coati.trainer.strategies.sampler import DistributedSampler @@ -19,13 +19,9 @@ from tqdm import tqdm from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback -from .utils import (get_model_numel, - get_rank, - get_world_size, - is_rank_0, - set_dist_env, - state_dict_to) from .lora_constructor import LoRAConstructor +from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to + @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) class ExperienceMakerHolder: @@ -41,7 +37,7 @@ def __init__( self, detached_trainer_name_list: List[str], strategy_fn: Callable[[], Strategy], - # a function returns (actor, critic, reward_model, initial_model) + # a function returns (actor, critic, reward_model, initial_model) model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], env_info: Dict[str, str] = None, sync_models_from_trainers: bool = False, @@ -205,15 +201,19 @@ def update_experience_maker(self, self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False) else: new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) - state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict) - self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase) + state_dict_increase = self.actor_lora_constructor.reconstruct_increase( + new_actor_state_dict, new_actor_lora_config_dict) + self.actor_lora_constructor.load_state_dict_increase( + self.experience_maker.actor.model, state_dict_increase) if new_critic_state_dict is not None: if not self._update_lora_weights or fully_update: self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) else: new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) - state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict) - self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase) + state_dict_increase = self.critic_lora_constructor.reconstruct_increase( + new_critic_state_dict, new_critic_lora_config_dict) + self.critic_lora_constructor.load_state_dict_increase( + self.experience_maker.critic, state_dict_increase) # the lock must be released after both actor and critic being updated if chunk_end: diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py index 4809617f647b..a98545d4d751 100644 --- a/applications/Chat/coati/ray/lora_constructor.py +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -1,11 +1,11 @@ -from typing import Any, Callable, Dict, List, Optional from collections import OrderedDict from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional import torch import torch.nn as nn -from loralib.layers import LoRALayer from coati.models.lora import LoraLinear +from loralib.layers import LoRALayer @dataclass @@ -23,19 +23,19 @@ class LoRAConstructor: Usage: Step 1 (Sender): filter_state_dict_lora() - + Step 2 (Sender, Optional): extract_lora_config() - + Step 3 (Sender): send state_dict_lora and lora_config_dict - + Step 4 (Receiver): reconstruct_increase() - + Step 5 (Receiver): load_state_dict_increase() - + ''' def 
__init__(self): diff --git a/applications/Chat/coati/replay_buffer/__init__.py b/applications/Chat/coati/replay_buffer/__init__.py deleted file mode 100644 index 1ebf60382913..000000000000 --- a/applications/Chat/coati/replay_buffer/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import ReplayBuffer -from .naive import NaiveReplayBuffer - -__all__ = ['ReplayBuffer', 'NaiveReplayBuffer'] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index b4d168a563d9..0629c9c00cca 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -4,8 +4,8 @@ import torch.nn as nn import tqdm +from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience -from coati.replay_buffer import NaiveReplayBuffer from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -62,7 +62,7 @@ class OnPolicyTrainer(ABC): Args: strategy (Strategy):the strategy to use for training - buffer (NaiveReplayBuffer): the buffer to collect experiences + data_buffer (NaiveExperienceBuffer): the buffer to collect experiences sample_buffer (bool, defaults to False): whether to sample from buffer dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader callbacks (List[Callback], defaults to []): the callbacks to call during training process @@ -70,13 +70,13 @@ class OnPolicyTrainer(ABC): def __init__(self, strategy: Strategy, - buffer: NaiveReplayBuffer, + data_buffer: NaiveExperienceBuffer, sample_buffer: bool, dataloader_pin_memory: bool, callbacks: List[Callback] = []) -> None: super().__init__() self.strategy = strategy - self.buffer = buffer + self.data_buffer = data_buffer self.sample_buffer = sample_buffer self.dataloader_pin_memory = dataloader_pin_memory self.callbacks = callbacks @@ -144,7 +144,7 @@ def _collect_phase(self, collect_step: int): self._on_make_experience_start() experience = self._make_experience(collect_step) self._on_make_experience_end(experience) - self.buffer.append(experience) + self.data_buffer.append(experience) def _update_phase(self, update_step: int): self._on_learn_epoch_start(update_step) @@ -181,8 +181,8 @@ def fit( # HACK(cwher): according to the design of boost API, dataloader should also be boosted, # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. # I only call strategy.setup_dataloader() to setup dataloader. 
- self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory) + self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory) for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()): self._update_phase(update_step) # NOTE: this is for on-policy algorithms - self.buffer.clear() + self.data_buffer.clear() diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index 925455444597..9b44dafa7eaa 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -171,13 +171,13 @@ def on_fit_end(self) -> None: learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) print_rank_0( - f'Performance summary:\n' + - f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' - + - f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' - + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + - f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + - f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' - + - f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' + f'Performance summary:\n' + + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' + + + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' + + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + + f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' + + + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' ) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 4c4a1002e96d..ef625a1c1b3d 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -1,11 +1,11 @@ from typing import Dict, List import torch.nn as nn +from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience, NaiveExperienceMaker from coati.models.base import Actor, Critic, get_base_model from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss from coati.models.utils import calc_action_log_probs -from coati.replay_buffer import NaiveReplayBuffer from torch import Tensor from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler @@ -86,9 +86,9 @@ def __init__(self, assert not offload_inference_models, \ "GeminiPlugin is not compatible with manual model.to('cpu')" - buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) super().__init__( - strategy, buffer, 
+ strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks ) @@ -170,7 +170,7 @@ def _learn(self, update_step: int): # buffer may be empty at first, we should rebuild at each training if self.sample_buffer: - experience = self.buffer.sample() + experience = self.data_buffer.sample() self._on_learn_batch_start() experience.to_device(self.device) metrics = self._training_step(experience) diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index 3d1dfaf784cf..c20b2b16e396 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from coati.replay_buffer import ReplayBuffer +from coati.experience_buffer import ExperienceBuffer from torch.optim import Optimizer from torch.utils.data import DataLoader from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -45,7 +45,7 @@ def setup_distributed(self) -> None: pass @abstractmethod - def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: pass def model_init_context(self): diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 1b59d704eec3..fa55f97ad661 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from transformers.tokenization_utils_base import PreTrainedTokenizerBase import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin @@ -44,7 +43,7 @@ class LowLevelZeroStrategy(DDPStrategy): """ def __init__(self, - stage: int = 3, + stage: int = 2, precision: str = 'fp16', seed: int = 42, placement_policy: str = 'cuda', @@ -214,14 +213,3 @@ def unwrap_model(self, model: nn.Module) -> nn.Module: ddp_model = model.unwrap() assert isinstance(ddp_model, GeminiDDP) return ddp_model.module - - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now') - - def get_model_state_dict_shard(self, model: nn.Module, **config): - assert isinstance(self.plugin, GeminiPlugin) - yield from super().get_model_state_dict_shard(model, **config) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index e1c1bbf19f35..a52b0460daa8 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -7,7 +7,8 @@ import torch import torch.distributed as dist import torch.nn as nn -from coati.replay_buffer import ReplayBuffer +from coati.experience_buffer import ExperienceBuffer +from coati.models import Actor, Critic, RewardModel from torch.utils.data import DataLoader from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -71,13 +72,13 @@ def set_seed(self, seed: int) -> None: np.random.seed(seed) torch.manual_seed(seed) - def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: - return self.plugin.prepare_dataloader(replay_buffer, - 
batch_size=replay_buffer.sample_batch_size, + def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: + return self.plugin.prepare_dataloader(data_buffer, + batch_size=data_buffer.sample_batch_size, shuffle=True, drop_last=True, pin_memory=pin_memory, - collate_fn=replay_buffer.collate_fn) + collate_fn=data_buffer.collate_fn) def setup_sampler(self, dataset) -> DistributedSampler: # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. @@ -92,13 +93,33 @@ def save_pretrained(self, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - if only_rank0 and dist.get_rank() != 0: - return - unwrapped_model = self.unwrap_model(model) - assert isinstance(unwrapped_model, PreTrainedModel) - unwrapped_model.save_pretrained(path) - if tokenizer is not None: - tokenizer.save_pretrained(path) + if not only_rank0 or dist.get_rank() == 0: + unwrapped_model = self.unwrap_model(model) + assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) + pretrained_model = unwrapped_model.model + assert isinstance(pretrained_model, PreTrainedModel) + # HACK: only use hf save_pretrained to save config + pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None) + if tokenizer is not None: + tokenizer.save_pretrained(path) + model_path = os.path.join(path, "pytorch_model.bin") + self.save_model(model, + model_path, + only_rank0=only_rank0) + + def _replace_keys(model_path: str, + replace_fn: Callable): + state_dict = torch.load(model_path, map_location="cpu") + state_dict = { + replace_fn(k): v + for k, v in state_dict.items() + } + torch.save(state_dict, model_path) + + # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin + # HACK: rename keys of pytorch_model.bin + if dist.get_rank() == 0: + _replace_keys(model_path, lambda k: k.replace("model.", "", 1)) def get_model_state_dict_shard(self, model: nn.Module, **config): # TODO: implement sharding on naive strategy diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py index 65e199dbf029..d726fa640fa2 100644 --- a/applications/Chat/coati/trainer/strategies/sampler.py +++ b/applications/Chat/coati/trainer/strategies/sampler.py @@ -27,7 +27,6 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None: assert len(indices) == self.num_samples self.indices = indices - def sample(self, batch_size: int) -> list: sampled_indices = np.random.choice(self.indices, batch_size, replace=False) return [self.dataset[idx] for idx in sampled_indices] diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py index 4d45061bab09..7e2cb9c634f7 100644 --- a/applications/Chat/coati/trainer/utils.py +++ b/applications/Chat/coati/trainer/utils.py @@ -21,9 +21,13 @@ def __init__( self.dataloader = dataloader self.count = 0 - self.dataloader_iter = iter(dataloader) + self.dataloader_iter = None def next(self): + # defer initialization + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + self.count += 1 try: return next(self.dataloader_iter) diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py new file mode 100644 index 000000000000..c2b5f9a859a9 --- /dev/null +++ b/applications/Chat/examples/download_model.py @@ -0,0 +1,84 @@ +import argparse +import dataclasses +import os +import parser +from typing import List + +import tqdm 
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer + + +@dataclasses.dataclass +class HFRepoFiles: + repo_id: str + files: List[str] + + def download(self, dir_path: str): + for file in self.files: + file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) + + def download_all(self): + file_path = snapshot_download(self.repo_id) + + +def test_init(model: str, dir_path: str): + if model == "gpt2": + config = GPT2Config.from_pretrained(dir_path) + actor = GPTActor(config=config) + critic = GPTCritic(config=config) + reward_model = GPTRM(config=config) + tokenizer = GPT2Tokenizer.from_pretrained(dir_path) + elif model == "bloom": + config = BloomConfig.from_pretrained(dir_path) + actor = BLOOMActor(config=config) + critic = BLOOMCritic(config=config) + reward_model = BLOOMRM(config=config) + tokenizer = BloomTokenizerFast.from_pretrained(dir_path) + elif model == "opt": + config = AutoConfig.from_pretrained(dir_path) + actor = OPTActor(config=config) + critic = OPTCritic(config=config) + reward_model = OPTRM(config=config) + tokenizer = AutoTokenizer.from_pretrained(dir_path) + else: + raise NotImplementedError(f"Model {model} not implemented") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, default="test_models") + parser.add_argument("--config-only", default=False, action="store_true") + args = parser.parse_args() + + if os.path.exists(args.model_dir): + print(f"[INFO]: {args.model_dir} already exists") + exit(0) + + repo_list = { + "gpt2": HFRepoFiles( + repo_id="gpt2", + files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"] + ), + "bloom": HFRepoFiles( + repo_id="bigscience/bloom-560m", + files=["config.json", "tokenizer.json", "tokenizer_config.json"] + ), + "opt": HFRepoFiles( + repo_id="facebook/opt-350m", + files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] + ), + } + + os.mkdir(args.model_dir) + for model_name in tqdm.tqdm(repo_list): + dir_path = os.path.join(args.model_dir, model_name) + if args.config_only: + os.mkdir(dir_path) + repo_list[model_name].download(dir_path) + else: + repo_list[model_name].download_all() + test_init(model_name, dir_path) diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py index 95e40fefe7ff..2abb31c09f82 100644 --- a/applications/Chat/examples/generate_prompt_dataset.py +++ b/applications/Chat/examples/generate_prompt_dataset.py @@ -1,7 +1,6 @@ import argparse - -import random import json +import random random.seed(42) @@ -10,8 +9,10 @@ def sample(args): with open(args.dataset_path, mode='r') as f: dataset_list = json.load(f) - sampled_dataset = [{"instruction": sample["instruction"], "id":idx} - for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))] + sampled_dataset = [ + {"instruction": sample["instruction"], "id": idx} + for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) + ] with open(args.save_path, mode='w') as f: json.dump(sampled_dataset, f, indent=4, diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index 4b49e76088bc..e1e57e3cd376 100644 --- 
a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -4,40 +4,50 @@ from coati.models.bloom import BLOOMActor from coati.models.generation import generate from coati.models.gpt import GPTActor +from coati.models.llama import LlamaActor from coati.models.opt import OPTActor -from transformers import AutoTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer def eval(args): # configure model if args.model == 'gpt2': - actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = GPTActor(pretrained=args.pretrain) elif args.model == 'bloom': - actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = BLOOMActor(pretrained=args.pretrain) elif args.model == 'opt': - actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = OPTActor(pretrained=args.pretrain) + elif args.model == 'llama': + actor = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') - state_dict = torch.load(args.model_path) - actor.load_state_dict(state_dict) + actor.to(torch.cuda.current_device()) + if args.model_path is not None: + state_dict = torch.load(args.model_path) + actor.load_state_dict(state_dict) # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer.eos_token = '<\s>' + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') actor.eval() - input = args.input - input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device()) + input_ids = tokenizer.encode(args.input, + return_tensors='pt')\ + .to(torch.cuda.current_device()) outputs = generate(actor, input_ids, max_length=args.max_length, @@ -45,13 +55,14 @@ def eval(args): top_k=50, top_p=0.95, num_return_sequences=1) - output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) - print(output) + output = tokenizer.batch_decode(outputs[0], + skip_special_tokens=True) + print(f"[Output]: {''.join(output)}") if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh deleted file mode 100755 index fe2af471017e..000000000000 --- a/applications/Chat/examples/test_ci.sh +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env bash - -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local 
FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 4 - -set -xue - -if [ -z "$SFT_DATASET" ]; then - echo "Please set \$SFT_DATASET to the path to sft dataset." - exit 1 -fi - -if [ -z "$PROMPT_PATH" ]; then - echo "Please set \$PROMPT_PATH to the path to prompts csv." - exit 1 -fi - -if [ -z "$PRETRAIN_DATASET" ]; then - echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." - exit 1 -fi - -BASE=$(realpath $(dirname $0)) - -export OMP_NUM_THREADS=8 - -# install requirements -pip install -r ${BASE}/requirements.txt - -wandb init -m offline - -# FIXME: This is a hack to skip tests that are not working -# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation -# - llama-*: These tests can be passed locally, skipped for long execution time -SKIPPED_TESTS=( - "gpt2-ddp" - "llama-ddp" - "llama-colossalai_gemini" - "llama-colossalai_zero2" -) - -# These tests are quick and do not have any dependencies -for model in 'gpt2' 'bloom' 'opt' 'llama'; do - for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do - if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then - echo "[Test]: Skipped $model-$strategy" - continue - fi - torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ - --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy $strategy --model $model \ - --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \ - --train_batch_size 2 --lora_rank 4 - done -done - -# train sft -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \ - --model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ - --model 'gpt2' --strategy colossalai_zero2 \ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ - --model 'opt' --strategy colossalai_zero2 --lora_rank 4 \ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ - --model 'gpt2' --strategy ddp --lora_rank 4 \ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -# train rm -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'facebook/opt-350m' --model 'opt' \ - --strategy colossalai_zero2 --loss_fn 'log_sig' \ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ - --test True --lora_rank 0 \ - --save_path ${BASE}/rm_ckpt_opt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'gpt2' --model 'gpt2' \ - --strategy colossalai_zero2 --loss_fn 'log_exp' \ - --dataset 'Dahoas/rm-static' \ - --test True --lora_rank 0 \ - --save_path ${BASE}/rm_ckpt_gpt.pt - -torchrun --standalone --nproc_per_node=2 
${BASE}/train_reward_model.py \ - --pretrain 'gpt2' --model 'gpt2' \ - --strategy ddp --loss_fn 'log_exp' \ - --dataset 'Dahoas/rm-static' \ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt -rm -rf ${BASE}/rm_ckpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'bigscience/bloom-560m' --model 'bloom' \ - --strategy colossalai_zero2 --loss_fn 'log_sig' \ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt -rm -rf ${BASE}/rm_ckpt.pt - -# train rl -torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ - --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 \ - --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ - --pretrain 'facebook/opt-350m' --model opt \ - --rm_pretrain 'facebook/opt-350m' \ - --rm_path ${BASE}/rm_ckpt_opt.pt \ - --save_path ${BASE}/actor_checkpoint_prompts.pt -rm -rf ${BASE}/rm_ckpt_opt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ - --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 \ - --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ - --pretrain 'gpt2' --model gpt2 \ - --rm_pretrain 'gpt2' \ - --rm_path ${BASE}/rm_ckpt_gpt.pt \ - --save_path ${BASE}/actor_checkpoint_prompts.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \ - --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_gemini --num_episodes 1 \ - --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \ - --pretrain 'gpt2' --model gpt2 \ - --rm_pretrain 'gpt2' \ - --rm_path ${BASE}/rm_ckpt_gpt.pt \ - --save_path ${BASE}/actor_checkpoint_prompts.pt -rm -rf ${BASE}/rm_ckpt_gpt.pt - -rm -rf ${BASE}/actor_checkpoint_prompts.pt - -# 3080 doesn't support P2P, skip this test -# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE} diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index 7338a6d51142..d27a70a3fef6 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -1,8 +1,9 @@ import argparse +import warnings import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.dataset import PromptDataset, SupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM @@ -29,6 +30,7 @@ def main(args): raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: + warnings.warn('LoRA weights should be merged with the model weights') state_dict = torch.load(args.rm_path, map_location='cpu') with strategy.model_init_context(): @@ -50,18 +52,18 @@ def main(args): rm_model_name = args.rm_model if rm_model_name == 'gpt2': - reward_model = GPTRM(pretrained=args.rm_pretrain) + reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) elif rm_model_name == 'bloom': - reward_model = BLOOMRM(pretrained=args.rm_pretrain) + reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) elif rm_model_name == 'opt': - reward_model = OPTRM(pretrained=args.rm_pretrain) + reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) elif 
rm_model_name == 'llama': - reward_model = LlamaRM(pretrained=args.rm_pretrain) + reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - reward_model.load_state_dict(state_dict) + reward_model.load_state_dict(state_dict, strict=False) initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) @@ -89,7 +91,7 @@ def main(args): raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - critic.load_state_dict(state_dict) + critic.load_state_dict(state_dict, strict=False) del state_dict if args.strategy != 'colossalai_gemini': @@ -106,23 +108,25 @@ def main(args): # configure tokenizer if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer = GPT2Tokenizer.from_pretrained( + 'gpt2' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + tokenizer = BloomTokenizerFast.from_pretrained( + 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer = AutoTokenizer.from_pretrained( + "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'llama': - tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) tokenizer.eos_token = '<\s>' tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) if dist.is_initialized() and dist.get_world_size() > 1: prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) @@ -144,8 +148,7 @@ def main(args): pretrain_dataloader = DataLoader(pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, - batch_size=args.ptx_batch_size, - collate_fn=data_collator) + batch_size=args.ptx_batch_size) # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. 
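The warning added above ("LoRA weights should be merged with the model weights") refers to folding the low-rank update into the frozen base weight before the checkpoint is reloaded with strict=False. As a general illustration only (standard LoRA math, not coati's own API), merging a single LoRA linear layer looks roughly like this, where scaling is typically lora_alpha / lora_rank:

import torch

def merge_lora_weight(weight: torch.Tensor,
                      lora_A: torch.Tensor,
                      lora_B: torch.Tensor,
                      scaling: float) -> torch.Tensor:
    # weight: (out_features, in_features); lora_B: (out_features, r); lora_A: (r, in_features).
    # The low-rank update B @ A is folded into the base weight, so the A/B
    # matrices no longer need to be stored or loaded separately.
    return weight + (lora_B @ lora_A) * scaling

Once merged this way, the checkpoint can generally be saved and reloaded as a plain state dict without relying on strict=False.
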
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \ @@ -197,6 +200,7 @@ def main(args): default='colossalai_zero2', help='strategy to use') parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument('--rm_path', type=str, default=None) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index fb9802e38542..190460bc20f6 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -36,34 +36,39 @@ def train(args): # configure model with strategy.model_init_context(): if args.model == 'bloom': - model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) elif args.model == 'opt': - model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) elif args.model == 'gpt2': - model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) elif args.model == 'llama': - model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported model "{args.model}"') + model.to(torch.float16).to(torch.cuda.current_device()) + if args.model_path is not None: state_dict = torch.load(args.model_path) model.load_state_dict(state_dict) - model = model.to(torch.float16) - # configure tokenizer if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer = GPT2Tokenizer.from_pretrained( + 'gpt2' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + tokenizer = BloomTokenizerFast.from_pretrained( + 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer = AutoTokenizer.from_pretrained( + "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'llama': - tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) + tokenizer.eos_token = '<\s>' tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -89,8 +94,8 @@ def train(args): data = load_dataset(args.dataset) if args.test: - train_data = data['train'].select(range(100)) - eval_data = data['test'].select(range(10)) + train_data = data['train'].select(range(20)) + eval_data = data['test'].select(range(5)) else: train_data = data['train'] eval_data = data['test'] @@ -177,6 +182,7 @@ def train(args): choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 
'llama'], default='bloom') + parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--need_optim_ckpt', type=bool, default=False) @@ -184,7 +190,7 @@ def train(args): type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], default='Dahoas/rm-static') - parser.add_argument('--subset', type=str, default=None) + parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None) parser.add_argument('--save_path', type=str, default='rm_ckpt') parser.add_argument('--max_epochs', type=int, default=1) parser.add_argument('--batch_size', type=int, default=1) diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh index 80abe62d2a3f..cc1b7be2815f 100755 --- a/applications/Chat/examples/train_rm.sh +++ b/applications/Chat/examples/train_rm.sh @@ -1,13 +1,13 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { local n=${1:-"9999"} echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" @@ -16,9 +16,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { set_n_least_used_CUDA_VISIBLE_DEVICES 2 torchrun --standalone --nproc_per_node=2 train_reward_model.py \ - --pretrain \ - --model 'bloom' \ - --strategy colossalai_zero2 \ - --loss_fn 'log_sig'\ - --save_path \ - --dataset 'Anthropic/hh-rlhf'\ + --model 'bloom' \ + --strategy colossalai_zero2 \ + --loss_fn 'log_sig' \ + --dataset 'Anthropic/hh-rlhf' diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index 4676d47dd331..7585cf3ed0da 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -1,24 +1,22 @@ import argparse import math -import os +import warnings -import loralib as lora import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset -from coati.models import convert_to_lora_module +from coati.dataset import SFTDataset, SupervisedDataset +from coati.models.bloom import BLOOMActor +from coati.models.gpt import GPTActor +from coati.models.llama import LlamaActor +from coati.models.opt import OPTActor from coati.trainer import SFTTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from datasets import load_dataset from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM -from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from transformers.models.opt.configuration_opt import OPTConfig -from transformers.models.opt.modeling_opt import OPTForCausalLM from 
transformers.trainer import get_scheduler from colossalai.logging import get_dist_logger @@ -31,8 +29,6 @@ def train(args): if args.strategy == 'ddp': strategy = DDPStrategy() elif args.strategy == 'colossalai_gemini': - raise NotImplementedError( - 'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.') strategy = GeminiStrategy(placement_policy='cuda') elif args.strategy == 'colossalai_zero2': strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') @@ -42,40 +38,49 @@ def train(args): raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model + if args.lora_rank > 0: + warnings.warn("Gradient checkpoint is disabled when using LoRA") + args.grad_checkpoint = False with strategy.model_init_context(): if args.model == 'bloom': - model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain), - args.lora_rank).half().cuda() + model = BLOOMActor(pretrained=args.pretrain, + lora_rank=args.lora_rank, + checkpoint=args.grad_checkpoint) elif args.model == 'opt': - model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda() + model = OPTActor(pretrained=args.pretrain, + lora_rank=args.lora_rank, + checkpoint=args.grad_checkpoint) elif args.model == 'gpt2': - model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda() + model = GPTActor(pretrained=args.pretrain, + lora_rank=args.lora_rank, + checkpoint=args.grad_checkpoint) elif args.model == 'llama': - model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain), - args.lora_rank).half().cuda() + model = LlamaActor(pretrained=args.pretrain, + lora_rank=args.lora_rank, + checkpoint=args.grad_checkpoint) else: raise ValueError(f'Unsupported model "{args.model}"') - if args.grad_checkpoint: - model.gradient_checkpointing_enable() + + model.to(torch.float16).to(torch.cuda.current_device()) # configure tokenizer if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer = GPT2Tokenizer.from_pretrained( + 'gpt2' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + tokenizer = BloomTokenizerFast.from_pretrained( + 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': tokenizer = AutoTokenizer.from_pretrained( - args.pretrain, - padding_side="right", - use_fast=False, - ) - tokenizer.eos_token = '' + "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) + tokenizer.eos_token = '<\s>' + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -111,7 +116,6 @@ def train(args): max_datasets_size=args.max_datasets_size, max_length=args.max_len) eval_dataset = None - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: train_sampler = DistributedSampler(train_dataset, @@ -135,14 +139,12 @@ def train(args): shuffle=(train_sampler is None), 
sampler=train_sampler, batch_size=args.batch_size, - collate_fn=data_collator, pin_memory=True) if eval_dataset is not None: eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, - collate_fn=data_collator, pin_memory=True) else: eval_dataloader = None @@ -184,6 +186,7 @@ def train(args): choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], default='colossalai_zero2') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') + parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--dataset', type=str, default=None) parser.add_argument('--max_datasets_size', type=int, default=None) diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh index c880f85825a7..1a5cd069011d 100755 --- a/applications/Chat/examples/train_sft.sh +++ b/applications/Chat/examples/train_sft.sh @@ -1,12 +1,29 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + torchrun --standalone --nproc_per_node=4 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2 \ --log_interval 10 \ - --save_path /path/to/Coati-7B \ + --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 4 \ --accumulation_steps 8 \ --lr 2e-5 \ --max_datasets_size 512 \ - --max_epochs 1 \ + --max_epochs 1 diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py index a8485f588705..438a1e3ef1c7 100644 --- a/applications/Chat/inference/benchmark.py +++ b/applications/Chat/inference/benchmark.py @@ -4,8 +4,8 @@ from time import time import torch -from llama_gptq import load_quant -from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM +from coati.quant import llama_load_quant, low_resource_init +from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM def generate_prompt(instruction, input=None): @@ -106,7 +106,10 @@ def evaluate( tokenizer = AutoTokenizer.from_pretrained(args.pretrained) if args.quant == '4bit': - model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + with low_resource_init(): + config = LlamaConfig.from_pretrained(args.pretrained) + model = LlamaForCausalLM(config) + model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) model.cuda() else: model = LlamaForCausalLM.from_pretrained( diff --git a/applications/Chat/inference/llama_gptq/__init__.py b/applications/Chat/inference/llama_gptq/__init__.py deleted file mode 100644 index 51c8d6316290..000000000000 --- a/applications/Chat/inference/llama_gptq/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .loader import load_quant - -__all__ = [ - 'load_quant', -] diff --git a/applications/Chat/inference/llama_gptq/loader.py b/applications/Chat/inference/llama_gptq/loader.py deleted file mode 100644 index a5c6ac7d1589..000000000000 --- a/applications/Chat/inference/llama_gptq/loader.py +++ /dev/null @@ -1,41 +0,0 @@ -import 
torch -import torch.nn as nn -import transformers -from transformers import LlamaConfig, LlamaForCausalLM - -from .model_utils import find_layers -from .quant import make_quant - - -def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int): - config = LlamaConfig.from_pretrained(pretrained) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - transformers.modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant(model, layers, wbits, groupsize) - - print(f'Loading model with {wbits} bits...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - model.seqlen = 2048 - print('Done.') - - return model diff --git a/applications/Chat/inference/llama_gptq/model_utils.py b/applications/Chat/inference/llama_gptq/model_utils.py deleted file mode 100644 index 62db171abb52..000000000000 --- a/applications/Chat/inference/llama_gptq/model_utils.py +++ /dev/null @@ -1,13 +0,0 @@ -# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py - -import torch -import torch.nn as nn - - -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) - return res diff --git a/applications/Chat/inference/llama_gptq/quant.py b/applications/Chat/inference/llama_gptq/quant.py deleted file mode 100644 index f7d5b7ce4bd8..000000000000 --- a/applications/Chat/inference/llama_gptq/quant.py +++ /dev/null @@ -1,283 +0,0 @@ -# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py - -import math - -import numpy as np -import torch -import torch.nn as nn - - -def quantize(x, scale, zero, maxq): - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) - return scale * (q - zero) - - -class Quantizer(nn.Module): - - def __init__(self, shape=1): - super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) - - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): - self.maxq = torch.tensor(2**bits - 1) - self.perchannel = perchannel - self.sym = sym - self.mse = mse - self.norm = norm - self.grid = grid - self.maxshrink = maxshrink - - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) - - shape = x.shape - if self.perchannel: - if weight: - x = x.flatten(1) - else: - if len(shape) == 4: - x = x.permute([1, 0, 2, 3]) - x = x.flatten(1) - if len(shape) == 3: - x = x.reshape((-1, shape[-1])).t() - if len(shape) == 2: - x = x.t() - else: - x = x.flatten().unsqueeze(0) - - tmp = torch.zeros(x.shape[0], device=dev) - xmin = torch.minimum(x.min(1)[0], tmp) - xmax = torch.maximum(x.max(1)[0], tmp) - - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - 
xmin[tmp] = -1 - xmax[tmp] = +1 - - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) - q -= x - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) - - if weight: - shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) - return - if len(shape) == 4: - self.scale = self.scale.reshape((1, -1, 1, 1)) - self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: - self.scale = self.scale.reshape((1, 1, -1)) - self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: - self.scale = self.scale.unsqueeze(0) - self.zero = self.zero.unsqueeze(0) - - def quantize(self, x): - if self.ready(): - return quantize(x, self.scale, self.zero, self.maxq) - return x - - def enabled(self): - return self.maxq > 0 - - def ready(self): - return torch.all(self.scale != 0) - - -try: - import quant_cuda -except: - print('CUDA extension not installed.') - -# Assumes layer is perfectly divisible into 256 * 256 blocks - - -class QuantLinear(nn.Module): - - def __init__(self, bits, groupsize, infeatures, outfeatures): - super().__init__() - if bits not in [2, 3, 4, 8]: - raise NotImplementedError("Only 2,3,4,8 bits are supported.") - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))): - raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. 
: 32,64,128,etc)") - groupsize = groupsize if groupsize != -1 else infeatures - self.groupsize = groupsize - self.register_buffer( - 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), - dtype=torch.int)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) - self.register_buffer('bias', torch.zeros(outfeatures)) - self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) - self._initialized_quant_state = False - - def pack(self, linear, scales, zeros): - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone() - if linear.bias is not None: - self.bias = linear.bias.clone() - - intweight = [] - for idx in range(self.infeatures): - g_idx = idx // self.groupsize - intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, - None]) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - elif self.bits == 3: - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i)) - i += 10 - qweight[row] |= intweight[i] << 30 - row += 1 - qweight[row] |= (intweight[i] >> 2) & 1 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 1) - i += 10 - qweight[row] |= intweight[i] << 31 - row += 1 - qweight[row] |= (intweight[i] >> 1) & 0x3 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 2) - i += 10 - row += 1 - else: - raise NotImplementedError("Only 2,3,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - elif self.bits == 3: - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) - i += 10 - qzeros[:, col] |= zeros[:, i] << 30 - col += 1 - qzeros[:, col] |= (zeros[:, i] >> 2) & 1 - i += 1 - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) - i += 10 - qzeros[:, col] |= zeros[:, i] << 31 - col += 1 - qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 - i += 1 - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) - i += 10 - col += 1 - else: - raise NotImplementedError("Only 2,3,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - intermediate_dtype = torch.float32 - - if not self._initialized_quant_state: - # Do we even have a bias? Check for at least one non-zero element. - if self.bias is not None and bool(torch.any(self.bias != 0)): - # Then make sure it's the right type. 
- self.bias.data = self.bias.data.to(intermediate_dtype) - else: - self.bias = None - - outshape = list(x.shape) - outshape[-1] = self.outfeatures - x = x.reshape(-1, x.shape[-1]) - if self.bias is None: - y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device) - else: - y = self.bias.clone().repeat(x.shape[0], 1) - - output_dtype = x.dtype - x = x.to(intermediate_dtype) - if self.bits == 2: - quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) - elif self.bits == 3: - quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) - elif self.bits == 4: - quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) - elif self.bits == 8: - quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) - else: - raise NotImplementedError("Only 2,3,4,8 bits are supported.") - y = y.to(output_dtype) - return y.reshape(outshape) - - -def make_quant(module, names, bits, groupsize, name=''): - if isinstance(module, QuantLinear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr - if name1 in names: - setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) - for name1, child in module.named_children(): - make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py index 51cdc68125bb..9443d4b99180 100644 --- a/applications/Chat/inference/locustfile.py +++ b/applications/Chat/inference/locustfile.py @@ -5,8 +5,7 @@ samples = [[ dict( instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. 
He is a 5-time MVP, 1' ), dict(instruction='continue this talk', response=''), ], [ diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py index e23f0fceb2fa..9d6b7fabef54 100644 --- a/applications/Chat/inference/server.py +++ b/applications/Chat/inference/server.py @@ -1,19 +1,19 @@ import argparse import os from threading import Lock -from typing import Dict, Generator, List, Optional +from typing import Generator, List, Optional import torch import uvicorn -from fastapi import FastAPI, HTTPException, Request +from coati.quant import llama_load_quant, low_resource_init +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from llama_gptq import load_quant from pydantic import BaseModel, Field from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address from sse_starlette.sse import EventSourceResponse -from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' @@ -56,7 +56,7 @@ class GenerationTaskReq(BaseModel): def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} - #TODO(ver217): streaming generation does not support repetition_penalty now + # TODO(ver217): streaming generation does not support repetition_penalty now model_kwargs = { 'max_generate_tokens': max_new_tokens, 'early_stopping': True, @@ -162,7 +162,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) if args.quant == '4bit': - model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + with low_resource_init(): + config = LlamaConfig.from_pretrained(args.pretrained) + model = LlamaForCausalLM(config) + model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) model.cuda() else: model = LlamaForCausalLM.from_pretrained( diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py index f5737ebe8c09..23028d4959cb 100644 --- a/applications/Chat/inference/tests/test_chat_prompt.py +++ b/applications/Chat/inference/tests/test_chat_prompt.py @@ -10,37 +10,34 @@ ([ Dialogue( instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' ), Dialogue(instruction='continue this talk', response=''), ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. 
Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' ), ([ Dialogue( instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' ), Dialogue(instruction='continue this talk', response=''), ], 200, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' ), ([ Dialogue( instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' + response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' ), Dialogue(instruction='continue this talk', response=''), ], 211, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' ), ([ Dialogue(instruction='Who is the best player in the history of NBA?', response=''), ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' + 'Below is an instruction that describes a task. 
Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' ), ] diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py index 37944be70a3b..e8e7b05ac719 100644 --- a/applications/Chat/inference/utils.py +++ b/applications/Chat/inference/utils.py @@ -1,9 +1,9 @@ +import json import re from threading import Lock from typing import Any, Callable, Generator, List, Optional -import json -import jieba +import jieba import torch import torch.distributed as dist import torch.nn as nn @@ -127,7 +127,7 @@ def _format_dialogue(instruction: str, response: str = ''): class ChatPromptProcessor: SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' - def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]): + def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): self.tokenizer = tokenizer self.context = context self.max_len = max_len @@ -182,6 +182,7 @@ def has_censored_words(self, text: str) -> bool: intersection = set(jieba.cut(text.lower())) & self.censored_words return len(intersection) > 0 + class LockedIterator: def __init__(self, it, lock: Lock) -> None: @@ -195,6 +196,7 @@ def __next__(self): with self.lock: return next(self.it) + def load_json(path: str): with open(path) as f: - return json.load(f) \ No newline at end of file + return json.load(f) diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh new file mode 100755 index 000000000000..3fdb25181342 --- /dev/null +++ b/applications/Chat/tests/test_benchmarks.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -xue + +echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies." + +if [[ $# -ne 0 && "$1" == "verbose" ]]; then + STRATEGIES=( + 'ddp' + 'colossalai_gemini' + 'colossalai_gemini_cpu' + 'colossalai_zero2' + 'colossalai_zero2_cpu' + 'colossalai_zero1' + 'colossalai_zero1_cpu' + ) +else + STRATEGIES=( + 'colossalai_zero2' + ) +fi + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +BENCHMARKS_DIR=$BASE_DIR/benchmarks + +echo "[Test]: testing benchmarks ..." 
+ +for strategy in ${STRATEGIES[@]}; do + torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \ + --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \ + --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \ + --train_batch_size 2 --experience_batch_size 4 +done diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 19338da437ab..3a3bf5b19cb8 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -7,7 +7,7 @@ import torch.distributed as dist from coati.models.gpt import GPTActor from coati.models.utils import calc_action_log_probs -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config from colossalai.nn.optimizer import HybridAdam @@ -17,40 +17,41 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) -def run_test_checkpoint(strategy): - BATCH_SIZE = 2 +def train_step(strategy: Strategy, + actor: GPTActor, + actor_optim: HybridAdam, + batch_size: int = 8): + data = get_data(batch_size) + action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) + actor_output = actor(data["input_ids"], data["attention_mask"]) + action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1)) + loss = action_log_probs.sum() + strategy.backward(loss, actor, actor_optim) + strategy.optimizer_step(actor_optim) - if strategy == 'ddp': + +def run_test_checkpoint(strategy_name: str, + shard: bool): + if strategy_name == "ddp": strategy = DDPStrategy() - elif strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif strategy_name == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy_name == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: - raise ValueError(f'Unsupported strategy "{strategy}"') + raise ValueError(f"Unsupported strategy '{strategy_name}'") with strategy.model_init_context(): actor = GPTActor(config=GPT_CONFIG).cuda() - actor_optim = HybridAdam(actor.parameters()) - actor, actor_optim = strategy.prepare((actor, actor_optim)) - def run_step(): - data = get_data(BATCH_SIZE) - action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool) - actor_output = actor(data['input_ids'], data['attention_mask']) - action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1)) - loss = action_log_probs.sum() - strategy.backward(loss, actor, actor_optim) - strategy.optimizer_step(actor_optim) - - run_step() + train_step(strategy, actor, actor_optim) ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() @@ -59,43 +60,47 @@ def run_step(): dist.broadcast_object_list(rank0_dirname) rank0_dirname = rank0_dirname[0] - model_path = os.path.join(rank0_dirname, 'model.pt') - strategy.save_model(actor, model_path, only_rank0=True) - - 
optim_path = os.path.join(rank0_dirname, f'optim.pt') - strategy.save_optimizer(actor_optim, optim_path, only_rank0=True) - - # FIXME(cwher): Sharded optimizer checkpoint is not supported yet. - # at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62 - # optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt') - # strategy.save_optimizer(actor_optim, optim_path, only_rank0=False) - + model_path = os.path.join( + rank0_dirname, "model" if shard else f"model.pt") + strategy.save_model(actor, model_path, only_rank0=not shard) + optim_path = os.path.join( + rank0_dirname, "optim" if shard else "optim.pt") + strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard) dist.barrier() strategy.load_model(actor, model_path, strict=False) strategy.load_optimizer(actor_optim, optim_path) - dist.barrier() - run_step() + train_step(strategy, actor, actor_optim) -def run_dist(rank, world_size, port, strategy): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) - run_test_checkpoint(strategy) +def run_dist(rank: int, + world_size: int, + port: int, + strategy_name: str, + shard: bool): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + run_test_checkpoint(strategy_name, shard) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) +@pytest.mark.parametrize("world_size", [4]) +@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) +@pytest.mark.parametrize("shard", [False, True]) @rerun_if_address_is_in_use() -def test_checkpoint(world_size, strategy): - spawn(run_dist, world_size, strategy=strategy) +def test_checkpoint(world_size: int, + strategy_name: str, + shard: bool): + spawn(run_dist, + world_size, + strategy_name=strategy_name, + shard=shard) -if __name__ == '__main__': - test_checkpoint(2, 'colossalai_zero2') +if __name__ == "__main__": + test_checkpoint(2, "colossalai_gemini", shard=False) diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py new file mode 100644 index 000000000000..64ea1178cd0d --- /dev/null +++ b/applications/Chat/tests/test_dataset.py @@ -0,0 +1,248 @@ +import json +import os +import tempfile +from typing import Optional + +import pytest +import torch +from coati.dataset.prompt_dataset import PromptDataset +from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset +from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset +from datasets import load_dataset +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +SFT_DATASET = [ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0 + }, + { + "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": 
"", + "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": 1 + }, + { + "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": "", + "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": 2 + }, +] + +PROMPT_DATASET = [ + { + "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": 0 + }, + { + "instruction": "Write a descriptive paragraph about a memorable vacation you went on", + "id": 1 + }, + { + "instruction": "Write a persuasive essay arguing why homework should be banned in schools", + "id": 2 + }, + { + "instruction": "Create a chart comparing the statistics on student debt in the United States.", + "id": 3 + }, +] + + +def make_tokenizer(model: str): + if model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + elif model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") + tokenizer.pad_token = tokenizer.eos_token + elif model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer.pad_token = tokenizer.eos_token + elif model == "llama": + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer.pad_token = tokenizer.unk_token + else: + raise ValueError(f"Unsupported model '{model}'") + return tokenizer + + +def check_content(input_ids_stripped: torch.Tensor, + tokenizer: PreTrainedTokenizer, + model: str): + if model == "opt": + # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. 
+ assert input_ids_stripped[0] == tokenizer.eos_token_id + input_ids_stripped = input_ids_stripped[1:] + elif model == "llama": + assert input_ids_stripped[0] == tokenizer.bos_token_id + input_ids_stripped = input_ids_stripped[1:] + + assert torch.all(input_ids_stripped != tokenizer.pad_token_id) + assert torch.all(input_ids_stripped != tokenizer.bos_token_id) + assert torch.all(input_ids_stripped != tokenizer.eos_token_id) + assert input_ids_stripped != tokenizer.sep_token_id + assert input_ids_stripped != tokenizer.cls_token_id + assert input_ids_stripped != tokenizer.mask_token_id + + +@pytest.mark.cpu +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) +@pytest.mark.parametrize("max_length", [32, 1024]) +@pytest.mark.parametrize("max_datasets_size", [2]) +def test_prompt_dataset(model: str, + max_datasets_size: int, + max_length: int): + with tempfile.TemporaryDirectory() as tmp_dir: + dataset_name = "prompt_dataset.json" + with open(os.path.join(tmp_dir, dataset_name), "w") as f: + json.dump(PROMPT_DATASET, f) + tokenizer = make_tokenizer(model) + assert tokenizer.padding_side in ("left", "right") + prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name), + tokenizer=tokenizer, + max_datasets_size=max_datasets_size, + max_length=max_length) + assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) + for i in range(len(prompt_dataset)): + assert isinstance(prompt_dataset[i], dict) + assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"] + input_ids = prompt_dataset[i]["input_ids"] + attention_mask = prompt_dataset[i]["attention_mask"] + attention_mask = attention_mask.bool() + assert input_ids.shape == attention_mask.shape == torch.Size([max_length]) + assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id) + check_content(input_ids.masked_select(attention_mask), tokenizer, model) + + +@pytest.mark.cpu +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) +@pytest.mark.parametrize(["dataset_path", "subset"], [ + ("Anthropic/hh-rlhf", "harmless-base"), + ("Dahoas/rm-static", None) +]) +@pytest.mark.parametrize("max_datasets_size", [32]) +@pytest.mark.parametrize("max_length", [32, 1024]) +def test_reward_dataset(model: str, + dataset_path: str, + subset: Optional[str], + max_datasets_size: int, + max_length: int): + data = load_dataset(dataset_path, data_dir=subset) + assert max_datasets_size <= len(data["train"]) \ + and max_datasets_size <= len(data["test"]) + train_data = data["train"].select(range(max_datasets_size)) + test_data = data["test"].select(range(max_datasets_size)) + tokenizer = make_tokenizer(model) + assert tokenizer.padding_side in ("left", "right") + + if dataset_path == "Anthropic/hh-rlhf": + train_dataset = HhRlhfDataset(train_data, tokenizer, max_length) + test_dataset = HhRlhfDataset(test_data, tokenizer, max_length) + elif dataset_path == "Dahoas/rm-static": + train_dataset = RmStaticDataset(train_data, tokenizer, max_length) + test_dataset = RmStaticDataset(test_data, tokenizer, max_length) + else: + raise ValueError(f'Unsupported dataset "{dataset_path}"') + + assert len(train_dataset) == len(test_dataset) == max_datasets_size + for i in range(max_datasets_size): + chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] + assert chosen_ids.shape == c_mask.shape == \ + reject_ids.shape == r_mask.shape == torch.Size([max_length]) + c_mask = c_mask.to(torch.bool) + r_mask = r_mask.to(torch.bool) + if chosen_ids.masked_select(c_mask)[-1] == 
tokenizer.eos_token_id: + check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) + assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) + else: + check_content(chosen_ids.masked_select(c_mask), tokenizer, model) + assert torch.all(c_mask) + if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: + check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) + assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) + else: + check_content(reject_ids.masked_select(r_mask), tokenizer, model) + assert torch.all(r_mask) + + chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] + assert chosen_ids.shape == c_mask.shape == \ + reject_ids.shape == r_mask.shape == torch.Size([max_length]) + c_mask = c_mask.to(torch.bool) + r_mask = r_mask.to(torch.bool) + if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: + check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) + assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) + else: + check_content(chosen_ids.masked_select(c_mask), tokenizer, model) + assert torch.all(c_mask) + if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: + check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) + assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) + else: + check_content(reject_ids.masked_select(r_mask), tokenizer, model) + assert torch.all(r_mask) + + +@pytest.mark.cpu +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) +@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) +@pytest.mark.parametrize("max_dataset_size", [2]) +@pytest.mark.parametrize("max_length", [32, 1024]) +def test_sft_dataset(model: str, + dataset_path: Optional[str], + max_dataset_size: int, + max_length: int): + tokenizer = make_tokenizer(model) + if dataset_path == "yizhongw/self_instruct": + data = load_dataset(dataset_path, "super_natural_instructions") + train_data = data["train"].select(range(max_dataset_size)) + sft_dataset = SFTDataset(train_data, tokenizer, max_length) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + dataset_name = "sft_dataset.json" + with open(os.path.join(tmp_dir, dataset_name), "w") as f: + json.dump(SFT_DATASET, f) + sft_dataset = SupervisedDataset(tokenizer=tokenizer, + data_path=os.path.join(tmp_dir, dataset_name), + max_datasets_size=max_dataset_size, + max_length=max_length) + assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) + + for i in range(max_dataset_size): + assert isinstance(sft_dataset[i], dict) + assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] + input_ids = sft_dataset[i]["input_ids"] + labels = sft_dataset[i]["labels"] + attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) + assert input_ids.shape == labels.shape == \ + attention_mask.shape == torch.Size([max_length]) + if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: + check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) + assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) + else: + check_content(input_ids.masked_select(attention_mask), tokenizer, model) + assert torch.all(attention_mask) + ignore_mask = labels == IGNORE_INDEX + check_content(input_ids.masked_select(ignore_mask), tokenizer, model) + + +if __name__ == 
"__main__": + test_sft_dataset(model="bloom", + dataset_path="yizhongw/self_instruct", + max_dataset_size=2, + max_length=256) + + test_reward_dataset(model="gpt2", + dataset_path="Anthropic/hh-rlhf", + subset="harmless-base", + max_datasets_size=8, + max_length=256) + + test_prompt_dataset(model="opt", + max_datasets_size=2, + max_length=128) diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_experience.py similarity index 82% rename from applications/Chat/tests/test_data.py rename to applications/Chat/tests/test_experience.py index db641a6218b1..071e50b90e8e 100644 --- a/applications/Chat/tests/test_data.py +++ b/applications/Chat/tests/test_experience.py @@ -4,11 +4,12 @@ import pytest import torch import torch.distributed as dist +from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import NaiveExperienceMaker from coati.models.base import RewardModel from coati.models.gpt import GPTActor, GPTCritic -from coati.replay_buffer import NaiveReplayBuffer from coati.trainer.strategies import DDPStrategy, GeminiStrategy +from coati.trainer.strategies.colossalai import LowLevelZeroStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -32,13 +33,15 @@ def gather_and_equal(tensor: torch.Tensor) -> bool: return True -def run_test_data(strategy): +def make_and_consume_experience(strategy): EXPERIENCE_BATCH_SIZE = 4 SAMPLE_BATCH_SIZE = 2 if strategy == 'ddp': strategy = DDPStrategy() - elif strategy == 'colossalai': + elif strategy == 'colossalai-zero2': + strategy = LowLevelZeroStrategy() + elif strategy == 'colossalai-gemini': strategy = GeminiStrategy(placement_policy='cuda') else: raise ValueError(f'Unsupported strategy "{strategy}"') @@ -50,7 +53,7 @@ def run_test_data(strategy): reward_model = RewardModel(deepcopy(critic.model)).cuda() experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) - replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) + data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) # experience of all ranks should be the same for _ in range(2): @@ -69,12 +72,12 @@ def run_test_data(strategy): assert gather_and_equal(experience.advantages) assert gather_and_equal(experience.action_mask) assert gather_and_equal(experience.attention_mask) - replay_buffer.append(experience) + data_buffer.append(experience) - # replay buffer's data should be the same - buffer_size = torch.tensor([len(replay_buffer)], device='cuda') + # data buffer's data should be the same + buffer_size = torch.tensor([len(data_buffer)], device='cuda') assert gather_and_equal(buffer_size) - for item in replay_buffer.items: + for item in data_buffer.items: assert gather_and_equal(item.sequences) assert gather_and_equal(item.action_log_probs) assert gather_and_equal(item.values) @@ -84,7 +87,7 @@ def run_test_data(strategy): assert gather_and_equal(item.attention_mask) # dataloader of each rank should have the same size and different batch - dataloader = strategy.setup_dataloader(replay_buffer) + dataloader = strategy.setup_dataloader(data_buffer) dataloader_size = torch.tensor([len(dataloader)], device='cuda') assert gather_and_equal(dataloader_size) for experience in dataloader: @@ -102,17 +105,16 @@ def run_dist(rank, world_size, port, strategy): os.environ['WORLD_SIZE'] = str(world_size) os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = str(port) - run_test_data(strategy) + 
make_and_consume_experience(strategy) -@pytest.mark.skip @pytest.mark.dist @pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) +@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini']) @rerun_if_address_is_in_use() -def test_data(world_size, strategy): +def test_experience(world_size, strategy): spawn(run_dist, world_size, strategy=strategy) if __name__ == '__main__': - test_data(2, 'colossalai') + test_experience(2, 'colossalai') diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh new file mode 100755 index 000000000000..849db06e58ab --- /dev/null +++ b/applications/Chat/tests/test_inference.sh @@ -0,0 +1,11 @@ +set -xue + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples + +echo "[Test]: testing inference ..." + +# HACK: skip llama due to oom +for model in 'gpt2' 'bloom' 'opt'; do + python $EXAMPLES_DIR/inference.py --model $model +done diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py new file mode 100644 index 000000000000..bd6b3e8a5ad1 --- /dev/null +++ b/applications/Chat/tests/test_models.py @@ -0,0 +1,235 @@ +import copy +from typing import Any, Callable, Dict, Tuple + +import pytest +import torch +import torch.nn as nn +from coati.models.base import Actor, Critic, RewardModel, get_base_model +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.generation import generate +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.lora import LoraLinear, convert_to_lora_module +from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean + + +@pytest.mark.gpu +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seq_len", [32]) +@pytest.mark.parametrize("actor_maker", [ + lambda: BLOOMActor(), + lambda: GPTActor(), + # HACK: skip llama due to long execution time + # lambda: LlamaActor(), + lambda: OPTActor() +]) +@pytest.mark.parametrize("generate_kwargs", [{ + "max_length": 64, + "use_cache": True, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, +}]) +def test_generation(actor_maker: Callable[[], Actor], + batch_size: int, + seq_len: int, + generate_kwargs: Dict[str, Any] + ): + actor = actor_maker() + input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() + sequences = generate(actor.cuda(), input_ids, **generate_kwargs) + assert sequences.shape == (batch_size, generate_kwargs["max_length"]) + + +@pytest.mark.cpu +def test_utils(): + fn_input = { + "tensor": torch.ones((10, )), + "mask": torch.randint(0, 2, (10, )) + } + fn_output = masked_mean(dim=0, **fn_input) + assert fn_output.dim() == 0 + assert torch.allclose(fn_output, torch.tensor(1.0)) + + batch_size = 4 + num_labels = 10 + fn_input = { + "r": torch.ones((batch_size, )), + "kl_coef": 1.0, + "log_probs": torch.randn((batch_size, num_labels)), + "log_probs_base": torch.randn((batch_size, num_labels)), + "action_mask": torch.randint(0, 2, (batch_size, num_labels)) + } + fn_output = compute_reward(**fn_input) + assert fn_output.shape == (batch_size, ) + + batch_size = 4 + seq_len = 32 + num_labels = 10 + num_actions = 2 + fn_input = { + "output": { + "logits": torch.randn((batch_size, seq_len, num_labels)) 
+ }, + "sequences": torch.randint(0, num_labels, (batch_size, seq_len)), + "num_actions": num_actions, + } + fn_output = calc_action_log_probs(**fn_input) + assert fn_output.shape == (batch_size, num_actions) + + +@pytest.mark.cpu +@pytest.mark.parametrize("lora_rank", [4]) +@pytest.mark.parametrize("num_dim", [32]) +@pytest.mark.parametrize("num_layers", [4]) +def test_lora(lora_rank: int, + num_dim: int, + num_layers: int): + model = nn.ModuleList( + [nn.Linear(num_dim, num_dim) + for _ in range(num_layers)] + ) + lora_model = convert_to_lora_module(model, lora_rank) + assert isinstance(lora_model, nn.ModuleList) + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert lora_model[i].lora_A.shape == (lora_rank, num_dim) + assert lora_model[i].lora_B.shape == (num_dim, lora_rank) + + old_model = copy.deepcopy(lora_model) + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert torch.allclose(old_model[i].weight, lora_model[i].weight) + assert torch.allclose(old_model[i].bias, lora_model[i].bias) + assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, + lora_model[i].lora_B @ lora_model[i].lora_A) + optimizer = torch.optim.Adam(lora_model.parameters()) + x = torch.randn(8, num_dim) + for i in range(num_layers): + x = lora_model[i](x) + loss = x.sum() + loss.backward() + optimizer.step() + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert torch.allclose(old_model[i].weight, lora_model[i].weight) + assert torch.allclose(old_model[i].bias, lora_model[i].bias) + assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, + lora_model[i].lora_B @ lora_model[i].lora_A) + + +@pytest.mark.cpu +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize("models_maker", [ + lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), + lambda: (GPTActor(), GPTCritic(), GPTRM()), + # HACK: skip llama due to long execution time + # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), + lambda: (OPTActor(), OPTCritic(), OPTRM()), +]) +@torch.no_grad() +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], + batch_size: int, + seq_len: int): + + actor_input = { + "input_ids": torch.randint(0, 100, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + } + critic_input = { + "sequences": torch.randint(0, 100, (batch_size, seq_len)), + "action_mask": torch.randint(0, 2, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + } + rm_input = { + "sequences": torch.randint(0, 100, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + } + + actor, critic, rm = models_maker() + assert isinstance(actor, Actor) + base_actor_model = get_base_model(actor) + assert isinstance(critic, Critic) + base_critic_model = get_base_model(critic) + assert isinstance(rm, RewardModel) + base_rm_model = get_base_model(rm) + + actor_output = actor(**actor_input) + critic_output = critic(**critic_input) + rm_output = rm(**rm_input) + + assert actor_output.logits.shape[:2] == (batch_size, seq_len) + assert critic_output.shape == (batch_size, ) + assert rm_output.shape == (batch_size, ) + + +@pytest.mark.cpu +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize("num_labels", [100]) +def test_loss(batch_size: int, + seq_len: int, + num_labels: int): + loss = GPTLMLoss() + loss_input = { + "logits": 
torch.randn(batch_size, seq_len, num_labels), + "labels": torch.randint(0, num_labels, (batch_size, seq_len)) + } + loss_output = loss(**loss_input) + + loss = PolicyLoss() + loss_input = { + "log_probs": torch.randn(batch_size, ), + "old_log_probs": torch.randn(batch_size, ), + "advantages": torch.randn(batch_size, ) + } + loss_output = loss(**loss_input) + + loss = ValueLoss() + loss_input = { + "values": torch.randn(batch_size, ), + "old_values": torch.randn(batch_size, ), + "reward": torch.randn(batch_size, ) + } + loss_output = loss(**loss_input) + + loss = LogSigLoss() + loss_input = { + "chosen_reward": torch.randn(batch_size, ), + "reject_reward": torch.randn(batch_size, ), + } + loss_output = loss(**loss_input) + + loss = LogExpLoss() + loss_input = { + "chosen_reward": torch.randn(batch_size, ), + "reject_reward": torch.randn(batch_size, ), + } + loss_output = loss(**loss_input) + + +if __name__ == "__main__": + generate_kwargs = dict(max_length=40, + use_cache=True, + do_sample=True, + temperature=1.0, + top_k=50) + test_generation(lambda: LlamaActor(), + batch_size=4, + seq_len=32, + generate_kwargs=generate_kwargs) + + test_utils() + + test_lora(lora_rank=2, num_dim=8, num_layers=2) + + test_models(models_maker=lambda: (BLOOMActor(), + BLOOMCritic(), + BLOOMRM()), + batch_size=8, + seq_len=128) + + test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh new file mode 100755 index 000000000000..c5127c188612 --- /dev/null +++ b/applications/Chat/tests/test_train.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +set -xu + +if [ -z "$SFT_DATASET" ]; then + echo "Please set \$SFT_DATASET to the path to sft dataset." + exit 1 +fi + +if [ -z "$PROMPT_PATH" ]; then + echo "Please set \$PROMPT_PATH to the path to prompts csv." + exit 1 +fi + +if [ -z "$PRETRAIN_DATASET" ]; then + echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." + exit 1 +fi + +NUM_RETRY=3 +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples +MODELS_DIR=$BASE_DIR/examples/models_config +MODELS=('gpt2' 'bloom' 'opt' 'llama') +STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2') + +export OMP_NUM_THREADS=8 + +# install requirements +pip install -r $EXAMPLES_DIR/requirements.txt + +python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only + +get_pretrain() { + local model=$1 + if [[ $model == "gpt2" ]]; then + echo "gpt2" + elif [[ $model == "bloom" ]]; then + echo "bigscience/bloom-560m" + elif [[ $model == "opt" ]]; then + echo "facebook/opt-350m" + else + echo "Unknown model $model" + exit 1 + fi +} + +random_choice() { + local arr=("$@") + local len=${#arr[@]} + local idx=$((RANDOM % len)) + echo ${arr[$idx]} +} + +echo "[Test]: testing sft ..." 
+ +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +GRAD_CKPTS=('' '--grad_checkpoint') +for lora_rank in '0' '4'; do + for model in ${MODELS[@]}; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + pretrain=$(get_pretrain $model) + pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + pretrain_model="--pretrain $pretrain" + fi + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \ + $pretrain_model --tokenizer $MODELS_DIR/$model \ + --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \ + --dataset $SFT_DATASET --max_datasets_size 8 \ + --max_epochs 1 --batch_size 1 --accumulation_steps 1 \ + --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$strategy-$lora_rank" + exit 1 + fi + done + done +done + +echo "[Test]: testing reward model ..." + +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +LOSS_FNS=('log_sig' 'log_exp') +DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static') +for lora_rank in '0' '4'; do + for model in ${MODELS[@]}; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + pretrain=$(get_pretrain $model) + pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + pretrain_model="--pretrain $pretrain" + fi + loss_fn=$(random_choice "${LOSS_FNS[@]}") + dataset=$(random_choice "${DATASETS[@]}") + subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi) + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \ + $pretrain_model --tokenizer $MODELS_DIR/$model \ + --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \ + --dataset $dataset --subset $subset --test True --batch_size 1 \ + --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank" + exit 1 + fi + done + done +done + +echo "[Test]: testing RLHF ..." 
+ +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +for model in ${MODELS[@]}; do + for lora_rank in '0' '4'; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + rm_pretrain=$(get_pretrain $model) + rm_pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + rm_pretrain_model="--rm_pretrain $rm_pretrain" + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \ + --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ + --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \ + --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \ + --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ + --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ + $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ + --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank" + exit 1 + fi + done + rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} + rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt + done +done +rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index cee547b33b0c..ec3dc7fc143f 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import GeneralCheckpointIO -from colossalai.interface import ModelWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -153,18 +153,20 @@ def execute_pipeline(self, # return loss or outputs if needed pass - def no_sync(self, model: nn.Module) -> contextmanager: + def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: """Context manager to disable gradient synchronization across DP process groups. + Support torch DDP and Low Level ZeRO-1 for now. Args: - model (nn.Module): The model to be disabled gradient synchronization. + model (nn.Module): The model to be disabled gradient synchronization, for DDP + optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1 Returns: contextmanager: Context to disable gradient synchronization. """ assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' 
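The widened no_sync signature is meant to keep gradient accumulation plugin-agnostic: TorchDDP only needs the model handle, while the ZeRO-1 path added below only needs the optimizer handle. A minimal accumulation sketch under that assumption; booster, model, optimizer, criterion and dataloader are taken from an earlier booster.boost(...) call, and accum_steps is an illustrative constant rather than part of the API:

    # assumes: booster, model, optimizer, criterion, dataloader come from booster.boost(...)
    accum_steps = 4    # illustrative accumulation window
    for step, (inputs, labels) in enumerate(dataloader):
        loss = criterion(model(inputs), labels)
        if (step + 1) % accum_steps != 0 and booster.plugin.support_no_sync():
            # skip gradient synchronization on pure accumulation steps;
            # DDP uses the model handle, ZeRO-1 uses the optimizer handle
            with booster.no_sync(model=model, optimizer=optimizer):
                booster.backward(loss, optimizer)
        else:
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
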
- return self.plugin.no_sync(model) + assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + return self.plugin.no_sync(model, optimizer) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True): """Load model from checkpoint. diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 7b6e17337d36..0f5ba6e9a6da 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -408,5 +408,5 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 3ec0d34092a4..616b218b2070 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,5 +1,8 @@ +import logging +import os import warnings from functools import partial +from pathlib import Path from typing import Callable, Iterator, List, Optional, Tuple, Union import torch @@ -10,10 +13,16 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO +from colossalai.checkpoint_io.utils import ( + get_optimizer_base_filenames, + get_shard_filename, + save_param_groups, + save_state_dict, +) from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device -from colossalai.zero import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO @@ -32,21 +41,104 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - """ - Save optimizer to checkpoint but only on master process. + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): + """Save optimizer to checkpoint but only on master process. + + Args: + optimizer (OptimizerWrapper): Optimizer to save state_dict + checkpoint (str): Path to save checkpoint + gather_dtensor (bool): Whether to gather_dtensor, not used """ - # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now - warnings.warn( - 'LowLevelZeroPlugin does not support save full optimizer checkpoint now. 
Save it on every process.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): - warnings.warn( - 'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - super().load_optimizer(optimizer, checkpoint) + # the `state_dict` in LowLevelZeroOptimizer has communication + # if only the master rank collect state_dict and save, + # the communication on each rank would not match + state_dict = optimizer.state_dict() + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def save_sharded_optimizer(self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = False, + prefix: str = None, + size_per_shard: int = 1024): + """ + Save sharded Zero-optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file that store state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # state_dict only provide only 'param_groups' + state_dict = optimizer.optim.state_dict() + # state shard would be handled by the low-level zero optimizer + sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) + + # Save shards of optimizer states. + total_size = 0 + for idx, shard_pair in enumerate(sharded_state): + shard, current_size = shard_pair + shard_file = get_shard_filename(states_name, idx) + total_size = total_size + current_size + for param_id in shard.keys(): + index_file.append_weight_map(str(param_id), shard_file) + + checkpoint_file_path = os.path.join(checkpoint, shard_file) + if self.coordinator.is_master(): + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + + # Wrap up index file. + index_file.append_meta_data("total_size", total_size) + if self.coordinator.is_master(): + index_file.write_index_file(save_index_file) + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): + """Load sharded optimizer with the given path to index file. 
+ + Args: + optimizer (OptimizerWrapper): Optimizer to load state_dict + index_file_path (str): Path to the index file + prefix (str): Not used. + """ + super().load_sharded_optimizer(optimizer, index_file_path, prefix) + current_rank_state_dict = optimizer.optim.state_dict()['state'] + for param_idx, state in current_rank_state_dict.items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != 'step': + padding_size = (self.coordinator.world_size - + v.numel() % self.coordinator.world_size) % self.coordinator.world_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // self.coordinator.world_size) + current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach() class LowLevelZeroModel(ModelWrapper): @@ -74,36 +166,6 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -class LowLevelZeroOptimizer(OptimizerWrapper): - - def __init__(self, - module: nn.Module, - optimizer: Optimizer, - zero_optim_config: dict, - optim_kwargs: dict, - verbose: bool = False) -> None: - optimizer = zero_optim_wrapper(module, - optimizer, - optim_config=zero_optim_config, - **optim_kwargs, - verbose=verbose) - super().__init__(optimizer) - - def backward(self, loss: Tensor, *args, **kwargs): - self.optim.backward(loss) - - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: - warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm') - - def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') - - class LowLevelZeroPlugin(DPPluginBase): """ Plugin for low level zero. 
@@ -179,8 +241,11 @@ def __init__( norm_type=norm_type) self.verbose = verbose + # set class name with stage, for better error message + setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") + def support_no_sync(self) -> bool: - return False + return self.stage == 1 def control_precision(self) -> bool: return True @@ -208,8 +273,11 @@ def configure( if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, - self.verbose) + optimizer = zero_optim_wrapper(model.unwrap(), + optimizer, + optim_config=self.zero_optim_config, + **self.optim_kwargs, + verbose=self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler @@ -219,5 +287,6 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return LowLevelZeroCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(optimizer, LowLevelZeroOptimizer) + return optimizer.optim.no_sync() diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index aa78f6827003..fb21e57f41f7 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -61,7 +61,7 @@ def get_checkpoint_io(self) -> CheckpointIO: pass @abstractmethod - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: """ Context manager to disable gradient synchronization. """ diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 71b435155503..f3f779c88e42 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -168,6 +168,6 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' 
return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index abfffa9b099e..fb7b5baadd0c 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -177,7 +177,7 @@ def __init__( def support_no_sync(self) -> bool: False - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") def control_precision(self) -> bool: diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py index d1b88b229fb8..2a6a111e4d72 100644 --- a/colossalai/cli/launcher/hostinfo.py +++ b/colossalai/cli/launcher/hostinfo.py @@ -46,11 +46,8 @@ def is_host_localhost(hostname: str, port: str = None) -> None: localhost = socket.gethostname() localaddrs = socket.getaddrinfo(localhost, port) targetaddrs = socket.getaddrinfo(hostname, port) - for (family, socktype, proto, canonname, sockaddr) in localaddrs: - for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs: - if rsockaddr[0] == sockaddr[0]: - return True - return False + + return localaddrs == targetaddrs def __str__(self): return f'hostname: {self.hostname}, port: {self.port}' diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py deleted file mode 100644 index fe030866894f..000000000000 --- a/colossalai/utils/checkpoint_io/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .io import load, merge, redist, save -from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py deleted file mode 100644 index 140192c05f12..000000000000 --- a/colossalai/utils/checkpoint_io/backend.py +++ /dev/null @@ -1,74 +0,0 @@ -import shutil -import tempfile -from abc import ABC, abstractmethod -from typing import Dict, List, Type - -from .reader import CheckpointReader, DiskCheckpointReader -from .writer import CheckpointWriter, DiskCheckpointWriter - -_backends: Dict[str, Type['CheckpointIOBackend']] = {} - - -def register(name: str): - assert name not in _backends, f'"{name}" is registered' - - def wrapper(cls): - _backends[name] = cls - return cls - - return wrapper - - -def get_backend(name: str) -> 'CheckpointIOBackend': - assert name in _backends, f'Unsupported backend "{name}"' - return _backends[name]() - - -class CheckpointIOBackend(ABC): - - def __init__(self) -> None: - super().__init__() - self.temps: List[str] = [] - - @abstractmethod - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - pass - - @abstractmethod - def get_reader(self, base_name: str) -> CheckpointReader: - pass - - @abstractmethod - def get_temp(self, base_name: str) -> str: - pass - - @abstractmethod - def clean_temp(self) -> None: - pass - - -@register('disk') -class CheckpointDiskIO(CheckpointIOBackend): - - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) - - def get_reader(self, base_name: str) -> CheckpointReader: - return DiskCheckpointReader(base_name) - - def get_temp(self, base_name: str) -> str: - temp_dir_name = tempfile.mkdtemp(dir=base_name) - 
self.temps.append(temp_dir_name) - return temp_dir_name - - def clean_temp(self) -> None: - for temp_dir_name in self.temps: - shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py deleted file mode 100644 index 2199484741bf..000000000000 --- a/colossalai/utils/checkpoint_io/constant.py +++ /dev/null @@ -1,9 +0,0 @@ -import re - -GLOBAL_META_FILE_NAME = 'global_meta.bin' -MODEL_CKPT_FILE_NAME = 'model.bin' -OPTIM_CKPT_FILE_NAME = 'optim.bin' -META_CKPT_FILE_NAME = 'meta.bin' -OTHER_CKPT_FILE_NAME = 'other.bin' - -CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py deleted file mode 100644 index 529ceb86829b..000000000000 --- a/colossalai/utils/checkpoint_io/convertor.py +++ /dev/null @@ -1,227 +0,0 @@ -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional - -from torch import Tensor - -from .distributed import merge_param, unmerge_param -from .meta import ParamDistMeta, RedistMeta -from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) - - -class CheckpointConvertor(ABC): - - @abstractmethod - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - pass - - @abstractmethod - def complete(self) -> None: - pass - - -class ModelCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int]) -> None: - super().__init__() - self.param_count = param_count - self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) - - @abstractmethod - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - for k, tensor in state_dict.items(): - self.buffer[k][rank] = tensor - converted_keys = set() - for k, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[k]: - tensors = [] - dist_metas = [] - for rank, tensor in rank_dict.items(): - tensors.append(tensor) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][k]) - self.convert_tensors(k, tensors, dist_metas) - converted_keys.add(k) - for k in converted_keys: - del self.buffer[k] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class ModelCheckpointMerger(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: - super().__init__(param_count) - self.sharder = ModelCheckpointSharder(max_shard_size) - self.save_fn = save_fn - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - shard = self.sharder.append(key, tensor) - run_if_not_none(self.save_fn, shard) - - def complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class ModelCheckpointRedistor(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - redist_meta: RedistMeta) -> None: - super().__init__(param_count) - self.save_fns = save_fns - self.redist_meta = redist_meta - nprocs = 
len(save_fns) - self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - if len(dist_metas) == 0: - # already global - tensor = tensors[0] - else: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[key][tp_rank][dp_rank]: - shard = self.sharders[rank].append(key, t) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) - - -class OptimizerCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], - paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__() - self.param_count = param_count - self.param_to_os = param_to_os - self.paired_os = paired_os - self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) - self.os_to_param = {v: k for k, v in param_to_os.items()} - - @abstractmethod - def setup(self, param_groups: dict) -> None: - pass - - @abstractmethod - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - self.setup(state_dict['param_groups']) - for idx, state in state_dict['state'].items(): - self.buffer[idx][rank] = state - converted_indices = set() - for idx, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[self.os_to_param[idx]]: - states = [] - dist_metas = [] - for rank, state in rank_dict.items(): - states.append(state) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) - self.convert_states(idx, states, dist_metas) - converted_indices.add(idx) - for idx in converted_indices: - del self.buffer[idx] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fn = save_fn - self.sharder = None - - def setup(self, param_groups: dict) -> None: - if self.sharder is None: - self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(states) - new_state = {} - for state_key, state_tensor in states[0].items(): - if self.paired_os[idx][state_key]: - new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) - else: - new_state[state_key] = state_tensor - shard = self.sharder.append(idx, new_state) - run_if_not_none(self.save_fn, shard) - - def 
complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], - redist_meta: RedistMeta) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fns = save_fns - self.redist_meta = redist_meta - self.sharders: List[OptimizerCheckpointSharder] = [] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def setup(self, param_groups: dict) -> None: - if len(self.sharders) == 0: - nprocs = len(self.save_fns) - for _ in range(nprocs): - self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - need_merge: bool = True - if len(dist_metas) == 0: - need_merge = False - else: - assert len(dist_metas) == len(states) - new_states = [{} for _ in range(len(self.save_fns))] - for state_key, state_tensor in states[0].items(): - if self.paired_os[idx][state_key]: - if need_merge: - tensor = merge_param([state[state_key] for state in states], dist_metas) - else: - tensor = state_tensor - for tp_rank, tensor_list in enumerate( - unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: - new_states[rank][state_key] = t - else: - for new_state in new_states: - new_state[state_key] = state_tensor - for rank, new_state in enumerate(new_states): - shard = self.sharders[rank].append(idx, new_state) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py deleted file mode 100644 index bf720437c41a..000000000000 --- a/colossalai/utils/checkpoint_io/distributed.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -from numpy import prod -from torch import Tensor -from typing import List, Optional, Tuple -from collections import defaultdict -from .meta import ParamDistMeta, ParamRedistMeta - - -def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' - if not dist_metas[0].used_zero: - # tensors are replicate - return tensors[0] - numel = dist_metas[0].zero_numel - orig_shape = dist_metas[0].zero_orig_shape - tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] - assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' 
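For reference, a minimal self-contained sketch of the reassembly that unflatten_zero_param performs (assuming two dp ranks; the example tensor and values are arbitrary, not taken from the codebase):

import torch

# original parameter that ZeRO flattened and split across dp ranks
orig = torch.arange(6, dtype=torch.float32).reshape(2, 3)
flat = orig.flatten()
shards = [flat[:4], flat[4:]]              # per-dp-rank shards, already sorted by dp_rank
merged = torch.cat(shards).reshape(2, 3)   # concatenate, then restore zero_orig_shape
assert torch.equal(merged, orig)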
- return torch.cat(tensors).reshape(orig_shape) - - -def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' - for t in tensors[1:]: - assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' - if not dist_metas[0].used_tp: - # tensors are replicate - return tensors[0] - total_parts = prod(dist_meta.tp_num_parts) - assert dist_meta.tp_world_size == total_parts, \ - f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' - shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) - for dim, num_parts in shard_info: - buffer = [] - for start in range(0, len(tensors), num_parts): - buffer.append(torch.cat(tensors[start:start + num_parts], dim)) - tensors = buffer - assert len(tensors) == 1 - return tensors[0] - - -def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) > 0 - # check world size - for dist_meta in dist_metas[1:]: - assert dist_meta.dp_world_size == dist_metas[ - 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' - assert dist_meta.tp_world_size == dist_metas[ - 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' - - -def deduplicate_params(tensors: List[Tensor], - dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: - unique_dist_meta = [] - unique_idx = [] - for i, dist_meta in enumerate(dist_metas): - if dist_meta not in unique_dist_meta: - unique_dist_meta.append(dist_meta) - unique_idx.append(i) - return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] - - -def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - # validate parallel info - validate_parallel_info(dist_metas) - tensors, dist_metas = deduplicate_params(tensors, dist_metas) - unflattened_tensors = [] - # group zero params by tp rank - tensor_dict = defaultdict(list) - dist_meta_dict = defaultdict(list) - for t, dist_meta in zip(tensors, dist_metas): - tensor_dict[dist_meta.tp_rank].append(t) - dist_meta_dict[dist_meta.tp_rank].append(dist_meta) - assert len(tensor_dict - ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' - for tp_rank in tensor_dict.keys(): - unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) - return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) - - -def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_tp: - assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' - return [tensor] - total_parts = prod(redist_meta.tp_num_parts) - assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' 
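For reference, a minimal self-contained sketch of the dimension-wise chunking that split_tp_param performs (assuming a single shard dim and two tp ranks; the example tensor and values are arbitrary, not taken from the codebase):

import torch

tensor = torch.arange(8, dtype=torch.float32).reshape(4, 2)
tp_shard_dims, tp_num_parts = [0], [2]     # split dim 0 into 2 parts (tp_world_size == 2)
shards = [tensor]
for dim, num_parts in zip(tp_shard_dims, tp_num_parts):
    shards = [chunk.contiguous() for t in shards for chunk in t.chunk(num_parts, dim)]
assert len(shards) == 2 and all(s.shape == (2, 2) for s in shards)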
- shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) - tensors = [tensor] - for dim, num_parts in shard_info: - buffer = [] - for t in tensors: - assert t.size(dim) % num_parts == 0, \ - f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' - chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] - buffer.extend(chunks) - tensors = buffer - assert len(tensors) == redist_meta.tp_world_size - return tensors - - -def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_zero: - return [tensor] * redist_meta.dp_world_size - tensors: List[Optional[Tensor]] = [ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) - ] - offsets = redist_meta.zero_offsets + [tensor.numel()] - for i, offset in enumerate(offsets[:-1]): - end = offsets[i + 1] - tensors.append(tensor.view(-1)[offset:end]) - if len(tensors) < redist_meta.dp_world_size: - tensors.extend([ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) - for _ in range(redist_meta.dp_world_size - len(tensors)) - ]) - assert len(tensors) == redist_meta.dp_world_size - return tensors - - -def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: - tensors = split_tp_param(tensor, redist_meta) - tensors = [flatten_zero_param(t, redist_meta) for t in tensors] - return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py deleted file mode 100644 index f00212cdf859..000000000000 --- a/colossalai/utils/checkpoint_io/io.py +++ /dev/null @@ -1,170 +0,0 @@ -import warnings -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple - -import torch.distributed as dist -from torch.nn import Module -from torch.optim import Optimizer - -from .backend import get_backend -from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, - OptimizerCheckpointRedistor) -from .meta import ParamDistMeta, RedistMeta -from .utils import build_checkpoints, optimizer_load_state_dict - - -def save(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk', - **kwargs: Any) -> None: - io_backend = get_backend(backend) - if dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - if world_size == 1: - # global doesn't need dist_meta - dist_meta = None - else: - assert dist_meta is not None - max_shard_size = int(max_shard_size_gb * 1024**3) - model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, - param_to_os, dist_meta) - writer = io_backend.get_writer(path, overwrite, rank, world_size) - writer.save_others(kwargs) - for model_checkpoint in model_checkpoints: - writer.save_model(model_checkpoint) - for optimizer_checkpoint in optimizer_checkpoints: - writer.save_optimizer(optimizer_checkpoint) - writer.save_meta(meta_checkpoint) - - -def merge(path: str, - output_path: str, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - reader = io_backend.get_reader(path) - if 
len(reader.meta_list) == 1: - # already global - warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') - return False - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - writer = io_backend.get_writer(output_path, overwrite=overwrite) - writer.save_others(reader.load_others()) - max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), - dist_meta_list) - _convert_shards( - OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), - reader.load_optimizers(), dist_meta_list) - meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def redist(path: str, - output_path: str, - redist_meta: RedistMeta, - dist_metas: List[Dict[str, ParamDistMeta]], - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - nprocs = len(dist_metas) - reader = io_backend.get_reader(path) - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - do_redist: bool = False - if len(dist_meta_list) == nprocs: - for a, b in zip(dist_metas, dist_meta_list): - if a != b: - do_redist = True - break - else: - do_redist = True - if not do_redist: - warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') - return False - - writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] - writers[0].save_others(reader.load_others()) - max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards( - ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), - reader.load_models(), dist_meta_list) - _convert_shards( - OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, - param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) - for writer, dist_meta in zip(writers, dist_metas): - meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], - dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for shard_dict in shard_generator: - convertor.append(shard_dict, dist_meta_list) - convertor.complete() - - -def load(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - redist_meta: Optional[RedistMeta] = None, - dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, - max_shard_size_gb: float = 0.0, - backend: str = 'disk') -> dict: - is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 - rank: int = dist.get_rank() if dist.is_initialized() else 0 - is_main_process: bool = rank == 0 - # validate args - if redist_meta is None or dist_metas is None: - assert is_global - io_backend = get_backend(backend) - read_path: str = path - if is_main_process: - # pre-process checkpoints - temp_path = io_backend.get_temp(path) - if is_global: - wrote = merge(path, temp_path, 
max_shard_size_gb, backend=backend) - else: - wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) - if wrote: - read_path = temp_path - if not is_global: - bcast_list = [read_path] if is_main_process else [None] - dist.broadcast_object_list(bcast_list) - read_path = bcast_list[0] - reader = io_backend.get_reader(read_path) - # load model - for shard in reader.load_model(rank): - model.load_state_dict(shard, strict=False) - if optimizer is not None: - for shard in reader.load_optimizer(rank): - # optimizer.load_state_dict(shard) - optimizer_load_state_dict(optimizer, shard) - others_dict = reader.load_others() - if not is_global: - dist.barrier() - # clean up temp - if is_main_process: - io_backend.clean_temp() - return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py deleted file mode 100644 index 994f08b4b5e4..000000000000 --- a/colossalai/utils/checkpoint_io/meta.py +++ /dev/null @@ -1,81 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Set, Dict - - -@dataclass -class ParamDistMeta: - # parallel info - dp_rank: int - dp_world_size: int - tp_rank: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_numel: Optional[int] = None - zero_orig_shape: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_numel is not None and self.zero_orig_shape is not None - - @property - def parallel_meta(self) -> tuple: - return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size - - @property - def tp_meta(self) -> tuple: - return self.tp_shard_dims, self.tp_num_parts - - @property - def zero_meta(self) -> tuple: - return self.zero_numel, self.zero_orig_shape - - @staticmethod - def from_dict(d: dict) -> 'ParamDistMeta': - return ParamDistMeta(**d) - - -@dataclass -class ParamRedistMeta: - # parallel info - dp_world_size: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_start_dp_rank: Optional[int] = None - zero_offsets: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_start_dp_rank is not None and self.zero_offsets is not None - - -@dataclass -class RankRedistMeta: - dp_rank: int - tp_rank: int - pp_rank: int - - -@dataclass -class PipelineRedistMeta: - params: Set[str] - - -@dataclass -class RedistMeta: - rank_meta: Dict[str, Dict[int, RankRedistMeta]] - pipeline_meta: List[PipelineRedistMeta] - param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py deleted file mode 100644 index 3158c6481263..000000000000 --- a/colossalai/utils/checkpoint_io/reader.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -from abc import ABC, abstractmethod -from collections import Counter -from typing import Dict, Generator, List, Optional, Tuple - -import torch - -from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME -from .meta import ParamDistMeta -from .utils import is_duplicated_list - - -class CheckpointReader(ABC): - - def __init__(self, base_name: str) -> None: - super().__init__() - self.base_name 
= base_name - self.meta_list = [] - - @abstractmethod - def read(self, name: str) -> dict: - pass - - @abstractmethod - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - pass - - @abstractmethod - def load_model(self, rank: int) -> Generator[dict, None, None]: - pass - - @abstractmethod - def load_models(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - pass - - @abstractmethod - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_others(self) -> dict: - pass - - -class DiskCheckpointReader(CheckpointReader): - - def __init__(self, base_name: str) -> None: - super().__init__(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - global_meta = self.read(GLOBAL_META_FILE_NAME) - for meta_file_name in global_meta['meta']: - meta = self.read(meta_file_name) - if meta.get('dist_meta', None) is None: - # only global checkpoint can have empty dist_meta - assert len(global_meta['meta']) == 1 - self.meta_list.append(meta) - - def read(self, name: str) -> dict: - return torch.load(os.path.join(self.base_name, name)) - - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', - None), meta.get('paired_os', None)) - for meta in self.meta_list] - dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) - # reduce param_count - param_count = Counter(p for params in params_list for p in params) - # validate param_to_os - assert is_duplicated_list(param_to_os_list) - assert is_duplicated_list(paired_os_list) - return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] - - def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: - meta = self.meta_list[rank] - checkpoint_names = meta.get(shard_type, []) - for name in checkpoint_names: - yield self.read(name) - - def load_model(self, rank: int) -> Generator[dict, None, None]: - return self._load_shard('model', rank) - - def load_models(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - model_checkpoint_names = meta.get('model', []) - if indices[i] < len(model_checkpoint_names): - shards[i] = self.read(model_checkpoint_names[indices[i]]) - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - param_groups = None - for shard in self._load_shard('optimizer', rank): - if param_groups is None: - param_groups = shard['param_groups'] - else: - shard['param_groups'] = param_groups - yield shard - - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - param_groups = [] - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - optimizer_checkpoint_names = meta.get('optimizer', []) - if indices[i] < len(optimizer_checkpoint_names): - shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) - if indices[i] == 0: - param_groups.append(shards[i]['param_groups']) - else: - shards[i]['param_groups'] = param_groups[i] - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_others(self) -> dict: - return 
self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py deleted file mode 100644 index 135385f57379..000000000000 --- a/colossalai/utils/checkpoint_io/utils.py +++ /dev/null @@ -1,223 +0,0 @@ -import warnings -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple - -from torch import Tensor -from torch.nn import Module -from torch.nn.parameter import Parameter -from torch.optim import Optimizer - -from .meta import ParamDistMeta - - -def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: - if arg is not None: - return fn(arg) - - -def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: - # ensure all params in optimizer are in model state dict - params_set = set(id(p) for p in model.parameters()) - for group in optimizer.param_groups: - for p in group['params']: - assert id(p) in params_set - param_mappings = {} - start_index = 0 - - def get_group_mapping(group): - nonlocal start_index - param_mappings.update( - {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) - start_index += len(group['params']) - - for g in optimizer.param_groups: - get_group_mapping(g) - return {k: param_mappings[id(p)] for k, p in model.named_parameters()} - - -def compute_optimizer_state_size(state: Dict[str, Any]) -> int: - size = 0 - for v in state.values(): - if isinstance(v, Tensor): - size += v.numel() * v.element_size() - return size - - -class ModelCheckpointSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, Tensor] = {} - self.buffer_size: int = 0 - - def append(self, key: str, tensor: Tensor) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {} - self.buffer_size = 0 - self.buffer[key] = tensor - self.buffer_size += tensor.numel() * tensor.element_size() - return retval - - def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: - shards = [] - for key, tensor in state_dict.items(): - shard = self.append(key, tensor) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer) > 0 else None - - -class OptimizerCheckpointSharder: - - def __init__(self, max_shard_size: int, param_groups: dict) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} - self.buffer_size: int = 0 - self.returned_first: bool = False - - def append(self, key: int, state: dict) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {'state': {}} - self.buffer_size = 0 - self.buffer['state'][key] = state - self.buffer_size += compute_optimizer_state_size(state) - return retval - - def extend(self, state_dict: Dict[str, dict]) -> List[dict]: - shards = [] - for key, state in state_dict['state'].items(): - shard = self.append(key, state) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer['state']) > 0 else None - - -def shard_checkpoint(max_shard_size: int, - model_state_dict: Dict[str, Tensor], - optimizer_state_dict: Optional[dict] = None, - param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: - has_optimizer: bool 
= False - if optimizer_state_dict is not None: - assert param_to_os is not None - os_to_param = {v: k for k, v in param_to_os.items()} - for os_key in optimizer_state_dict['state'].keys(): - assert os_key in os_to_param - assert os_to_param[os_key] in model_state_dict - has_optimizer = True - model_sharder = ModelCheckpointSharder(max_shard_size) - model_shards = model_sharder.extend(model_state_dict) - run_if_not_none(model_shards.append, model_sharder.complete()) - if not has_optimizer: - return model_shards, [] - optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) - optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) - run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) - return model_shards, optimizer_shards - - -def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: - os_to_param = {v: k for k, v in param_to_os.items()} - paired_os = {} - for idx, state in optimizer_state_dict['state'].items(): - paired_os[idx] = {} - p = model_state_dict[os_to_param[idx]] - for k, v in state.items(): - if isinstance(v, Tensor) and v.shape == p.shape: - paired_os[idx][k] = True - else: - paired_os[idx][k] = False - return paired_os - - -def build_checkpoints(max_size: int, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: - save_global = dist_meta is None - model_state_dict = model.state_dict() - optimizer_state_dict = optimizer.state_dict() if optimizer else None - meta = {'dist_meta': dist_meta} - if optimizer: - param_to_os = param_to_os or get_param_to_os(model, optimizer) - paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) - meta['param_to_os'] = param_to_os - meta['paired_os'] = paired_os - if not save_global and eliminate_replica: - # filter dp replicated params - model_state_dict = { - k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - if optimizer: - optimizer_state_dict['state'] = { - param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] - for k in model_state_dict.keys() - if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - meta['params'] = list(model_state_dict.keys()) - if len(model_state_dict) == 0: - warnings.warn('model state dict is empty, checkpoint is not saved') - return [], [], meta - model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, - param_to_os) - return model_checkpoints, optimizer_checkpoints, meta - - -def is_duplicated_list(list_: List[Any]) -> bool: - if len(list_) == 0: - return True - elem = list_[0] - for x in list_[1:]: - if x != elem: - return False - return True - - -def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: - for k, v in src_state.items(): - if k in dest_state: - old_v = dest_state[k] - if isinstance(old_v, Tensor): - old_v.copy_(v) - else: - dest_state[k] = v - - -def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: - assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] - state_dict = deepcopy(state_dict) - groups = optimizer.param_groups - saved_groups = state_dict['param_groups'] - idx_to_p: Dict[str, Parameter] = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] 
for g in saved_groups - )), chain.from_iterable((g['params'] for g in groups))) - } - missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) - unexpected_keys = [] - error_msgs = [] - for idx, state in state_dict['state'].items(): - if idx in idx_to_p: - old_state = optimizer.state[idx_to_p[idx]] - copy_optimizer_state(state, old_state) - else: - unexpected_keys.append(idx) - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, - "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py deleted file mode 100644 index 4552accde470..000000000000 --- a/colossalai/utils/checkpoint_io/writer.py +++ /dev/null @@ -1,98 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional -from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME -import torch -import os - - -class CheckpointWriter(ABC): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__() - self.base_name = base_name - self.overwrite = overwrite - self.rank = rank - self.world_size = world_size - self.is_distributed = world_size > 1 - self.is_main_process = rank == 0 - - @abstractmethod - def write(self, name: str, state_dict: dict) -> None: - pass - - @abstractmethod - def save_model(self, model_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_meta(self, meta_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_others(self, kwargs: dict) -> None: - pass - - -class DiskCheckpointWriter(CheckpointWriter): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__(base_name, overwrite, rank, world_size) - if not os.path.exists(base_name): - os.makedirs(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - self.model_checkpoint_names = [] - self.optimizer_checkpoint_names = [] - self.is_meta_saved: bool = False - self._save_global_meta() - - def write(self, name: str, state_dict: dict) -> None: - path = os.path.join(self.base_name, name) - if os.path.exists(path) and not self.overwrite: - raise RuntimeError(f'Save error: Checkpoint "{path}" exists. 
(overwrite = False)') - torch.save(state_dict, path) - - def _save_global_meta(self) -> None: - if self.is_main_process: - global_meta = {'meta': []} - if self.is_distributed: - for i in range(self.world_size): - global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) - else: - global_meta['meta'].append(META_CKPT_FILE_NAME) - self.write(GLOBAL_META_FILE_NAME, global_meta) - - def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: - checkpoint_name = base_name - if self.is_distributed: - checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') - if shard_idx is not None: - checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') - return checkpoint_name - - def save_model(self, model_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save model after saving meta' - name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) - self.write(name, model_checkpoint) - self.model_checkpoint_names.append(name) - - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' - name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) - self.write(name, optimizer_checkpoint) - self.optimizer_checkpoint_names.append(name) - - def save_meta(self, meta_checkpoint: dict) -> None: - if len(self.model_checkpoint_names) > 0: - meta_checkpoint['model'] = self.model_checkpoint_names - if len(self.optimizer_checkpoint_names) > 0: - meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names - self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) - self.is_meta_saved = True - - def save_others(self, kwargs: dict) -> None: - if self.is_main_process: - self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 218f7603bc54..4205a9891534 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -3,8 +3,9 @@ import torch import torch.distributed as dist -from torch import inf +from torch import Tensor, inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import ProcessGroup from colossalai.tensor import ColoParameter from colossalai.utils import is_model_parallel_parameter @@ -194,26 +195,21 @@ def calculate_global_norm_from_list(norm_list): return math.sqrt(total_norm) -def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): +def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int: """Clips gradient norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. + added functionality to handle model parallel parameters. 
+ + Args: + gradients (Tensor): The gradients to compute norm + dp_group (ProcessGroup): The process group of ZeRO Data Parallelism + tp_group (ProcessGroup): The process group of Tensor Parallelism + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + Returns: - Total norm of the parameters (viewed as a single vector). + int: The total norm of given gradients """ - if mp_group is None: - mp_rank = 0 - else: - mp_rank = dist.get_rank(mp_group) - norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) @@ -221,29 +217,21 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) # Take max across all GPUs. - if mp_group is not None: + if tp_group is not None: dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - - for g, p in zip(gradients, params): - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - tp_param_flag = False - if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()): - tp_param_flag = True - if tp_param_flag or mp_rank == 0: - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + for g in gradients: + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) - if mp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group) + if tp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -253,7 +241,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): return total_norm -def sync_param(flat_tensor, tensor_list): +def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. 
When a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index ec322a78bf81..98f1b78d0049 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,3 +1,8 @@ +from typing import Dict + +import torch +from torch import Tensor +from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup from .base_store import BaseStore @@ -7,35 +12,102 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - self._params = dict() - self._num_elements_in_bucket = dict() + + # init and reset + self.current_group_id = 0 + # mapping between gradient slices and parameters + self.grad_to_param_mapping = dict() + + self._param_list = [] + self._padding_size = [] self.reset() - def num_elements_in_bucket(self, reduce_rank: int = None): - return self._num_elements_in_bucket[reduce_rank] + def num_elements_in_bucket(self) -> int: + """Return the total number of elements in the bucket + + Returns: + int: the total number of elements in the bucket + """ + + return self._num_elements_in_bucket + + def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): + """Add a param to the bucket and record the padding size of the param for gradient padding + + Args: + group_id (int): The index of a parameter group + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._param_list.append(param) + self._padding_size.append(padding_size) + self._num_elements_in_bucket += (param.numel() + padding_size) + self.current_group_id = group_id + + def build_grad_in_bucket(self): + """Organize the parameters' gradients (pad and split), following the parameters' splitting method + + Data structure of self._grad_in_bucket: + { + rank0: [grad0_rank0, grad1_rank0, ...] + rank1: [grad0_rank1, grad1_rank1, ...] + } + """ + + for param, padding_size in zip(self._param_list, self._padding_size): + with torch.no_grad(): + grad = param.grad.detach().flatten() + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + grad_list = grad.split(grad.numel() // self._world_size) + for rank in range(self._world_size): + grad_current_rank = grad_list[rank].detach() + self.grad_to_param_mapping[id(grad_current_rank)] = id(param) + self._grad_in_bucket[rank].append(grad_current_rank) + param.grad = None + + def get_grad(self) -> Dict: + """Return the dictionary of gradient slices, of which the keys are ranks + + Returns: + Dict: The dictionary of gradient slices + """ + + return self._grad_in_bucket + + def get_flatten_grad(self) -> Tensor: + """Return the flattened gradient slices in the bucket; the data organization of the flattened tensor is: + [grad0_rank0, grad1_rank0, ..., grad0_rank1, grad1_rank1, ...] 
+ + Returns: + Tensor: the flattened gradients slices in the bucket + """ + + flat_grad = [] + for grad_list in self._grad_in_bucket.values(): + flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad = _flatten_dense_tensors(flat_grad) + return flat_grad + + def get_param_id_of_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to + + Args: + grad (Tensor): the gradient slice - def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): - self._num_elements_in_bucket[reduce_rank] += num_elements + Returns: + int: the id of a parameter which the gradient slice belongs to + """ - def add_param(self, tensor, reduce_rank: int = None): - self._params[reduce_rank].append(tensor) + return self.grad_to_param_mapping[id(grad)] def reset(self): - keys = [None] + list(range(self._world_size)) - self._params = {rank: [] for rank in keys} - self._num_elements_in_bucket = {rank: 0 for rank in keys} - - def reset_by_rank(self, reduce_rank=None): - self._params[reduce_rank] = [] - self._num_elements_in_bucket[reduce_rank] = 0 - - def get_grad(self, reduce_rank: int = None): - param_list = self.get_param(reduce_rank) - for param in param_list: - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' - return [param.grad for param in param_list] - - def get_param(self, reduce_rank: int = None): - return self._params[reduce_rank] + self.grad_to_param_mapping = dict() + self._num_elements_in_bucket = 0 + self._param_list = [] + self._padding_size = [] + self._grad_in_bucket = dict() + for rank in range(self._world_size): + self._grad_in_bucket[rank] = [] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 942d7186e55f..0b86ec8ca89e 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -1,88 +1,92 @@ from typing import List from torch import Tensor +from torch._utils import _flatten_dense_tensors from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args): + def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) - # bookkeeping data structures - self._averaged_gradients = dict() - - # for backward reduction hooks - self._grad_acc_objs = [] - - def append_accumulate_grad_object(self, obj): """ - Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not - be attached successfully. - - :param obj: An object of :class:`AccumulateGrad` class - :type obj: :class:`AccumulateGrad` + self._grads_of_params mapping the paramater and its gradient slices + data structure: + { + group_id:{ + param_id: [grad_rank0, grad_rank1, ...] 
+ } + } """ + self._grads_of_params = dict() + # for zero2, it's `param_id: [grad_local_rank]` + self._working_index = 0 if partition_grad else self._local_rank - self._grad_acc_objs.append(obj) + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + """Return list of gradient slices of a specific parameter - def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: - """ - Return average gradients of a parameter group + Args: + group_id (int): The index of a parameter group + param_id (int): The id of a parameter - :param group_id: The index of parameter group - :type group_id: int - - :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. - :rtype: List[torch.Tensor] + Returns: + List: the list of gradient slices of a parameter. """ - if group_id not in self._averaged_gradients: - self._averaged_gradients[group_id] = [] - - return self._averaged_gradients[group_id] - def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: - """ - Append an average gradient to the list of averaged gradients of a parameter group + if group_id in self._grads_of_params: + if param_id in self._grads_of_params[group_id]: + return self._grads_of_params[group_id][param_id] + # the param has no grad, for instance, in layer drop + return [] - :param group_id: The index of a parameter group - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor: torch.Tensor + def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int): + """Append a gradient slice to the parameter's gradient slice list + Args: + grad (Tensor): The gradient slice to append to list + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - if group_id in self._averaged_gradients: - self._averaged_gradients[group_id].append(tensor) + if group_id not in self._grads_of_params: + self._grads_of_params[group_id] = dict() + if param_id not in self._grads_of_params[group_id]: + self._grads_of_params[group_id][param_id] = [grad] else: - self._averaged_gradients[group_id] = [tensor] + self._grads_of_params[group_id][param_id].append(grad) - def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None: + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): + """For old gradient accumulation, not in use now. 
+ Add a gradient slice on an existing slice of the parameter's gradient + + Args: + grad (Tensor): The split gradient to append to list + grad_idx (int): The index of the existing slice + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - Add an average gradient to the list of averaged gradients of a parameter group - :param group_id: The index of a parameter group - :param tensor_idx: The index of a tensor in the list of averaged gradients - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor_idx: int - :type tensor: torch.Tensor + self._grads_of_params[group_id][param_id][grad_idx].add_(grad) - """ - self._averaged_gradients[group_id][tensor_idx].add_(tensor) + def get_working_grads_by_group_id(self, group_id: int) -> List: + """Return list of working gradient slices in the group - def reset_average_gradients_by_group(self, group_id: int) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list + Args: + group_id (int): The index of a parameter group - :param group_id: The index of a parameter group - :type group_id: int + Returns: + List: the list working gradient slices in the group """ - self._averaged_gradients[group_id] = [] + grad_list = [] + for param_grads in self._grads_of_params[group_id].values(): + grad_list.append(param_grads[self._working_index]) - def reset_all_average_gradients(self) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list - """ - self._averaged_gradients = dict() + return grad_list + + def reset_grads_by_group_id(self, group_id: int): + self._grads_of_params[group_id] = dict() + + def reset_all_gradients(self): + self._grads_of_params = dict() diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index 1f3ba7cbc3bc..63f7c5506069 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -1,5 +1,3 @@ -from typing import List - from torch import Tensor from torch.distributed import ProcessGroup @@ -10,88 +8,43 @@ class ParameterStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - # param partitioning data structures - self._param_to_rank = dict() - self._rank_group_id_to_param_list = dict() - self._rank_group_id_to_flat_param = dict() - # param reduction data structures - self._is_param_reduced = dict() - self._reduced_param = [] + # record the padding size of each param + self._padding_map = dict() - def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: - """ - Set the mapping between parameter to rank, each parameter should be owned by a rank. 
+ # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - :param rank: The rank of which the process is responsible for updating the parameter - :type rank: int - """ + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param - self._param_to_rank[tensor] = rank - - def get_param_rank(self, tensor: Tensor) -> int: + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter """ - Gives the rank which the parameter belongs to - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - """ - return self._param_to_rank[tensor] + self._padding_map[id(param)] = padding_size - def belongs_to_current_rank(self, tensor) -> bool: - """ - Check whether a parameter is supposed to be updated by the process of the current rank + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor + Args: + param (Tensor): The parameter - :return: True if the parameter should be updated by the current rank. Otherwise false. - :rtype: bool + Returns: + int: the padding size of the parameter """ - tensor_rank = self._param_to_rank[tensor] - return tensor_rank == self._local_rank - - def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: - if rank not in self._rank_group_id_to_param_list: - self._rank_group_id_to_param_list[rank] = dict() - - if group_id not in self._rank_group_id_to_param_list[rank]: - self._rank_group_id_to_param_list[rank][group_id] = [] - - self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list) + return self._padding_map[id(param)] - def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]: - return self._rank_group_id_to_param_list[rank][group_id] + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter - def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None: - if rank not in self._rank_group_id_to_flat_param: - self._rank_group_id_to_flat_param[rank] = dict() - - self._rank_group_id_to_flat_param[rank][group_id] = tensor - - def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor: - return self._rank_group_id_to_flat_param[rank][group_id] - - def is_param_reduced(self, tensor): - return self._is_param_reduced[tensor] - - def set_param_reduction_state(self, tensor, state): - self._is_param_reduced[tensor] = state - - def get_param_reduction_states(self): - return self._is_param_reduced - - def reset_previous_reduced_params(self): - self._reduced_param = [] - - def add_previous_reduced_param(self, tensor): - self._reduced_param.append(tensor) + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ - def clear_grads_of_previous_reduced_params(self): - if len(self._reduced_param) > 0: - for param in self._reduced_param: - param.grad = None - self.reset_previous_reduced_params() + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index ee03c0f0ae15..2b3f50ed4fd4 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ 
b/colossalai/zero/low_level/low_level_optim.py @@ -1,9 +1,12 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy +from contextlib import contextmanager from functools import partial -from typing import Optional +from typing import Dict, Iterator, Optional, Tuple import torch import torch.distributed as dist +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import ( @@ -11,11 +14,9 @@ FP16MixedPrecisionMixin, MixedPrecisionMixin, ) -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor import ColoParameter, ProcessGroup +# from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -23,12 +24,10 @@ compute_norm, flatten, has_inf_or_nan, - reduce_tensor_dp_group, release_param_grad, - split_by_dtype, - sync_param, + sync_tensor, ) -from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +from .bookkeeping import BucketStore, GradientStore, ParameterStore class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): @@ -50,13 +49,13 @@ def __init__(self, def check_local_overflow(self) -> bool: for group_id in range(self.num_working_param_groups): - for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id): + for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id): if avg_grad is not None and has_inf_or_nan(avg_grad): return True return False -class LowLevelZeroOptimizer(ColossalaiOptimizer): +class LowLevelZeroOptimizer(OptimizerWrapper): """Optimizer used for ZeRO-1 and ZeRO-2. """ @@ -77,14 +76,13 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): - # TODO: add support for - # 1. fp16 master weights - # 2. contiguous gradients - # 3. cpu offload - # 4. support when some parameters requires_grad = False - # 5. support layer drop + # TODO: + # 1. 
state_dict for checkpoint IO + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -95,34 +93,19 @@ def __init__( self._cpu_offload = cpu_offload - colo_pg = self._search_colo_process_group() - if isinstance(colo_pg, ProcessGroup): - self._local_rank = colo_pg.dp_local_rank() - self._world_size = colo_pg.dp_world_size() - self._dp_global_ranks = colo_pg.get_ranks_in_dp() - self._dp_torch_group = colo_pg.dp_process_group() - self._mp_torch_group = None - if colo_pg.tp_world_size() > 1: - self._mp_torch_group = colo_pg.tp_process_group() - elif colo_pg is None: - dp_parallel_mode = ParallelMode.DATA - mp_parallel_mode = ParallelMode.MODEL - - self._dp_parallel_mode = dp_parallel_mode - self._mp_parallel_mode = mp_parallel_mode - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - self._world_size = gpc.get_world_size(dp_parallel_mode) - self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode) - self._dp_torch_group = gpc.get_group(dp_parallel_mode) - self._mp_torch_group = None - if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: - self._mp_torch_group = gpc.get_group(mp_parallel_mode) - else: - raise NotImplementedError + # grad accumulation + self.require_grad_sync = True + + # if process_group is none, will use the default one + self.dp_pg = dp_process_group + self._local_rank = dist.get_rank(group=self.dp_pg) + self._world_size = dist.get_world_size(group=self.dp_pg) + + self.tp_pg = tp_process_group # working and master params for mixed precision training self._working_param_groups = dict() - self._master_flat_param_groups_of_current_rank = dict() + self._master_param_groups_of_current_rank = dict() # communication params self._overlap_communication = overlap_communication @@ -144,9 +127,9 @@ def __init__( # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self._dp_torch_group) - self._grad_store = GradientStore(self._dp_torch_group) - self._bucket_store = BucketStore(self._dp_torch_group) + self._param_store = ParameterStore(self.dp_pg) + self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) + self._bucket_store = BucketStore(self.dp_pg) # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -160,55 +143,17 @@ def __init__( # add the working params to working_param_groups for bookkeeping self._working_param_groups[group_id] = group_params - # assign parameters to ranks - # the params in the list are sorted - params_per_rank = self._partition_param_list(group_params) + master_param_current_rank = self._create_master_param_current_rank(group_params) - # store the mapping between param to rank - # each param should belong to only one rank - for rank, params in enumerate(params_per_rank): - self._param_store.add_param_list_by_rank_group(rank, group_id, params) - for param in params: - self._param_store.set_param_to_rank(param, rank) - - # move to cpu to make room to create the flat tensor - # move_tensor(params, device='cpu') - for param in group_params: - param.data = param.data.cpu() - - # flatten the reordered tensors - for rank in range(self._world_size): - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - with torch.no_grad(): - flat_tensor = flatten(tensor_list) - flat_tensor = flat_tensor.data.cuda() - 
self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor) - - # sync parameters - for rank in range(self._world_size): - flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id) - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - - # create a copy of fp32 master weights of the parameters for which this rank is responsible - working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id) - master_flat_current_rank = working_flat_current_rank.float() - device = 'cpu' if self._cpu_offload else get_current_device() - master_flat_current_rank = master_flat_current_rank.to(device) - master_flat_current_rank.requires_grad = True - self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank + self._master_param_groups_of_current_rank[group_id] = master_param_current_rank # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors # managed by this data parallel rank - param_group['params'] = [master_flat_current_rank] + param_group['params'] = master_param_current_rank - # set reduction state - for param in self._working_param_groups[group_id]: - self._param_store.set_param_reduction_state(param, False) - - # initialize communication stream for - # communication-computation overlapping + # intialize communication stream for + # communication-compuation overlapping if self._overlap_communication: self._comm_stream = torch.cuda.Stream() @@ -249,45 +194,36 @@ def _sanity_checks(self): assert param.dtype == self._dtype, \ f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" - def _search_colo_process_group(self): - colo_flag = False - colo_pg = None - for param_group in self.optim.param_groups: - group_params = param_group['params'] - for param in group_params: - if isinstance(param, ColoParameter): - colo_flag = True - if colo_pg is None: - colo_pg = param.get_process_group() - else: - assert colo_pg == param.get_process_group(), "All parameters should be in a same process group" - elif colo_flag: - raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.") - return colo_pg - - def _partition_param_list(self, param_list): - params_per_rank = [[] for _ in range(self._world_size)] - numel_per_rank = [0 for _ in range(self._world_size)] - - # partition the parameters in a greedy fashion - sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for param in sorted_params: - # allocate this parameter to the rank with - # the smallest numel for load balancing purpose - rank_to_go = numel_per_rank.index(min(numel_per_rank)) - params_per_rank[rank_to_go].append(param) - numel_per_rank[rank_to_go] += param.numel() - - if self._verbose: - self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) - return params_per_rank + def _create_master_param_current_rank(self, param_list): + # split each param evenly by world size + params_current_rank = [] + device = 'cpu' if self._cpu_offload else get_current_device() + + for param in param_list: + padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size + self._param_store.record_param_padding_size(param, padding_size) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param 
= param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // self._world_size) + + splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + params_current_rank.append(splited_param_current_rank) + self._param_store.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank ########################### # Backward Reduction Hook # ########################### - def _grad_handler(self, param, grad, reduce_rank): - self._add_to_reduction_bucket(param, reduce_rank) + def _grad_handler(self, param, group_id, grad): + # if run with no_sync context, would not sync grad when backward + if self.require_grad_sync: + self._add_to_bucket(param, group_id) return grad def _attach_reduction_hook(self): @@ -297,149 +233,96 @@ def _attach_reduction_hook(self): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - # determines the reduction destination rank - # this is only valid for stage 2 - # dst_rank = None means using all-reduce - # else using reduce - if self._partition_grads: - reduce_rank = self._param_store.get_param_rank(param) - else: - reduce_rank = None - - param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) - - def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduce_global_rank = None - if reduce_rank is not None: - reduce_global_rank = self._dp_global_ranks[reduce_rank] - reduced_flat = reduce_tensor_dp_group(tensor=flat, - dtype=self._communication_dtype, - dst_local_rank=reduce_rank, - dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) - - def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) - - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_grads(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_by_dtype(grads) - - for tensor_list in grad_buckets_by_dtype: - self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, - bucket_size=bucket_size, - reduce_rank=reduce_rank) + param.register_hook(partial(self._grad_handler, param, group_id)) ####################### # Reduction Functions # ####################### - def _run_reduction(self, reduce_rank=None): - # reduce grads - self._reduce_grads(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) - - # use communication stream if overlapping - # communication with computation - if self._overlap_communication: - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) - - for param in params_in_bucket: - # the 
is_param_reduced flag should be False showing that - # this param is not reduced before calling self._reduce_grads_by_rank - is_param_reduced = self._param_store.is_param_reduced(param) + def _run_reduction(self): + if self._bucket_store.num_elements_in_bucket() > 0: + self._bucket_store.build_grad_in_bucket() + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size + if self._overlap_communication: + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + group_id = self._bucket_store.current_group_id + + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) + + if not self._partition_grads: + dist.all_reduce(flat_grads, group=self.dp_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + + for rank, grad_list in grad_in_bucket.items(): + sync_tensor(flat_grads_per_rank[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ - 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - # update the flag - self._param_store.set_param_reduction_state(param, True) + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) - # if partition grads = True - # we do not keep the gradient after reduction - if self._partition_grads and not self._param_store.belongs_to_current_rank(param): - if self._overlap_communication: - # we need to keep this gradient for now as reduction may - # be completed yet since it is using a different cuda stream - self._param_store.add_previous_reduced_param(param) - else: - param.grad = None + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - self._bucket_store.reset_by_rank(reduce_rank) + self._bucket_store.reset() - def _add_to_reduction_bucket(self, param, reduce_rank=None): + def _add_to_bucket(self, param, group_id): param_size = param.numel() # check if the bucket is full # if full, will reduce the grads already in the bucket + # or got a grad of param from another group # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._run_reduction(reduce_rank) - - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \ + group_id != self._bucket_store.current_group_id: + 
self._run_reduction() - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) + padding_size = self._param_store.get_param_padding_size(param) + self._bucket_store.add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False, sync_grad=True): + def backward(self, loss, retain_graph=False): + assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) + loss.backward(retain_graph=retain_graph) - # finish gradient reduction - if not self._partition_grads: - self._reduce_grad_stage1() - else: - # TODO: support async comm in reduce - self._reduce_grad_stage2() + if not self.require_grad_sync: + return + + self._reduce_grad(self._partition_grads) # clear reduced grads if self._overlap_communication: torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - # gradient synchronization - if sync_grad: - self._sync_grad() + self.zero_grad() def zero_grad(self, set_to_none=True): """ @@ -467,68 +350,80 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' + if not self.require_grad_sync: + return + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): - self._grad_store.reset_all_average_gradients() + self._grad_store.reset_all_gradients() if self._verbose: self._logger.info(f'Found overflow. Skip step') self.zero_grad() return - # copy the grad of working param to master param - single_grad_partition_groups = [] + # record all grads for unscale and clip + grad_partition_groups = [] norm_groups = [] + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + + grad_index = 0 if self._partition_grads else self._local_rank + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + for splited_param in master_params: + working_param = self._param_store.master_to_working_param[id(splited_param)] + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. 
the droped layer + # else the splited grad should be attached to the splited param + grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) + # compute norm - norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id), - params=self._param_store.get_params_by_rank_group(group_id=group_id, - rank=self._local_rank), - dp_group=self._dp_torch_group, - mp_group=self._mp_torch_group) + working_grads = self._grad_store.get_working_grads_by_group_id(group_id) + norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg) norm_groups.append(norm_group) - # create flat gradient for the flat fp32 master params - working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) - flat_working_avg_grads = flatten(working_avg_grads) - - dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype - flat_master_avg_grads = flat_working_avg_grads.to(dtype) + self._grad_store.reset_grads_by_group_id(group_id) - param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape - assert param_shape == flat_master_avg_grads.shape, \ - f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}' - - single_grad_partition_groups.append(flat_master_avg_grads) - device = self._master_flat_param_groups_of_current_rank[group_id].device - self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device) - self._grad_store.reset_average_gradients_by_group(group_id) + # update the params in the optimizer + self.optim.param_groups[group_id]['params'] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + self._unscale_and_clip_grads(grad_partition_groups, global_norm) # update the parameters self.optim.step() - # release the master grad - release_param_grad(self._master_flat_param_groups_of_current_rank.values()) - # update working partition updated by the current rank - for group_id in range(len(self._working_param_groups)): - working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id) - master_param = self._master_flat_param_groups_of_current_rank[group_id] - working_param.data.copy_(master_param) + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) - # broadcast the updated model weights - handles = [] + # update working partition updated by the current rank + dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): - for index in range(self._world_size): - rank = self._dp_global_ranks[index] - working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id) - handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True) - handles.append(handle) + master_working_param = self.optim.param_groups[group_id]['params'] + for idx, splited_param in enumerate(master_working_param): + working_param = real_working_params[group_id][idx] + all_splited_param = [ + 
torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) + working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) - for handle in handles: - handle.wait() + self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] ############################# # Mixed Precision Utilities # @@ -553,49 +448,139 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): # Gradient Synchronization # ############################ - def _sync_grad(self): - # update param already reduced flag - reduction_states = self._param_store.get_param_reduction_states() - for tensor, _ in reduction_states.items(): - reduction_states[tensor] = False - - # accumulate gradient + # this method is used to sync gradient manually + def sync_grad(self): for group_id in range(self.num_param_groups): - param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id) - - avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id) - - param_idx = 0 + param_group = self._working_param_groups[group_id] for param in param_group: - if param.grad is not None: - if len(avg_gradients_group) == param_idx: - self._grad_store.append_average_gradient_by_group(group_id, param.grad) - else: - self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad) - param_idx += 1 - - # the gradients needed are stored in the avg_gradients buffer - # thus, can clear this - self.zero_grad() + if param.requires_grad and param.grad is not None: + self._add_to_bucket(param, group_id) - def _reduce_grad_stage1(self): - # if not overlapping communication (no reduction hook is attached) - # we need to manually reduce these gradients - if not self._overlap_communication: - for group_id in range(len(self._working_param_groups)): - param_group = self._working_param_groups[group_id] - for param in param_group: - if param.grad is not None: - self._add_to_reduction_bucket(param) - - # we need to reduce the gradients - # left in the communication bucket self._run_reduction() - def _reduce_grad_stage2(self): - # when partition_grads is True, reduction hooks - # are attached in the __init__ function, so we - # only need to reduce the gradients - # left in the communication bucket - for reduce_rank in range(self._world_size): - self._run_reduction(reduce_rank) + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 + # we need to manually reduce these gradients + if not partition_grad and not self._overlap_communication: + self.sync_grad() + else: + self._run_reduction() + + # this context comes from pytorch DDP + @contextmanager + def no_sync(self): + old_require_grad_sync = self.require_grad_sync + self.require_grad_sync = False + try: + yield + finally: + self.require_grad_sync = old_require_grad_sync + + ############## + # State Dict # + ############## + + def _pack_state(self, state: Dict) -> Dict: + # comes from pytorch optimizer.state_dict() + param_mappings = {} + start_index = 0 + + def pack_group(group): + nonlocal start_index + packed = {k: v for k, v in group.items() if k != 'params'} + param_mappings.update( + {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) + packed['params'] = [param_mappings[id(p)] for p in group['params']] + start_index += len(packed['params']) + return 
packed + + param_groups = [pack_group(g) for g in self.optim.param_groups] + # Remap state to use order indices as keys + packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()} + + return {'state': packed_state, 'param_groups': param_groups} + + def state_dict(self) -> Dict: + """Return a state_dict same with DDP + + Returns: + Dict: the pytorch form state_dict + """ + zero_state = dict() + for param, state in self.optim.state.items(): + zero_state[param] = copy.deepcopy(state) + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != 'step': + working_param = self._param_store.master_to_working_param[id(param)] + gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)] + dist.all_gather(gather_tensor, v, group=self.dp_pg) + param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param) + zero_state[param][k] = param_state + + states_dict = self._pack_state(zero_state) + + return states_dict + + def load_state_dict(self, state_dict: Dict): + """Load state dict, requires the state_dict be the pytorch form + + Args: + state_dict (dict): A pytorch form state_dict + """ + zero_state_dict = copy.deepcopy(state_dict) + for param_idx, state in zero_state_dict['state'].items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != 'step': + padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // self._world_size) + zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach() + + self.optim.load_state_dict(zero_state_dict) + zero_state_dict = dict() + + def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + Only include the 'state' in state_dict. + + Args: + max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024. 
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        ret_block = dict()
+        ret_block_size = 0
+
+        local_states = self.optim.state_dict()['state']
+        for param_idx, states in local_states.items():
+            current_block_size = 0
+            current_block = copy.deepcopy(states)
+
+            # find the working param of current param_id
+            for group_id, pg in self._master_param_groups_of_current_rank.items():
+                if (group_id + 1) * len(pg) < param_idx:
+                    continue
+                master_param = pg[param_idx - (group_id) * len(pg)]
+                working_param = self._param_store.master_to_working_param[id(master_param)]
+
+            for k, v in states.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v, group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    current_block_size += state_tensor.numel()
+                    current_block[k] = state_tensor
+
+            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                yield ret_block, ret_block_size
+                ret_block = dict()
+                ret_block_size = 0
+
+            ret_block[param_idx] = current_block
+            ret_block_size += current_block_size
+
+        yield ret_block, ret_block_size
diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md
new file mode 100644
index 000000000000..aa92159d8022
--- /dev/null
+++ b/colossalai/zero/low_level/readme.md
@@ -0,0 +1,54 @@
+# Low Level ZeRO
+>Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it as ZeRO below.
+
+## Design:
+### Notation
+`p32` denotes the fp32 param copy kept by the optimizer
+`p` denotes the model param
+`g` denotes the gradient
+
+### INIT
+In low level zero (stage 1 and 2), `p32` is sharded. Unlike the previous implementation, each `p32` is now split evenly by world size. Thus rank 0 holds the param list `[p-00, p-10]`, rank 1 holds `[p-01, p-11]`, and so on.
+(figure: each `p32` split evenly across ranks)
+
+In the implementation, `p` is first padded so that it can be divided by world size if needed. It is then viewed as shape `[world_size, -1]`, and each rank keeps its own slice as `p32` by cloning.
+
+### BWD
+To batch and overlap communication, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped as `[world_size, -1]`, and the slices belonging to the same rank are grouped together.
+The data structure looks like this:
+```
+{
+0: [g-00, g-10],
+1: [g-01, g-11],
+2: [g-02, g-12]
+}
+```
+After that, the gradients are flattened by rank, and the data structure looks like this:
+```
+# g-0 means flatten([g-00, g-10])
+{
+0: [g-0],
+1: [g-1],
+2: [g-2]
+}
+```
+For zero1, we iterate over the dictionary and do `all_reduce`. For zero2, we can simply do `reduce-scatter`.
+
+### Optim
+Since each rank holds its own `p32` and the corresponding `g`, running `optim.step()` is straightforward.
+
+However, we have to handle layer drop, where a layer is defined but never used in the forward pass, for instance:
+```
+class MlpModel(nn.Module):
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.drop_linear = nn.Linear(256, 256)
+        self.linear2 = nn.Linear(256, 512)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+```
+The solution is to build a mapping between `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad is not None` as a real working param, and select the counterpart `p32` and `g`.
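As an aside (not part of this patch), the INIT and BWD steps described in the readme above can be illustrated with a minimal single-process sketch using plain PyTorch tensors; `world_size`, `my_rank`, and the helper name `split_param_evenly` are illustrative assumptions, not the actual `LowLevelZeroOptimizer` API.

```python
import torch
import torch.nn.functional as F

world_size = 4    # illustrative data-parallel size, not taken from the patch

def split_param_evenly(p: torch.Tensor, world_size: int):
    """Pad a flattened param so it divides evenly, then split it into one shard per rank."""
    flat = p.detach().view(-1)
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = F.pad(flat, [0, padding])
    # each rank keeps one shard as its fp32 master copy (p32)
    return [shard.clone().float() for shard in flat.split(flat.numel() // world_size)]

p = torch.randn(3, 5)                          # 15 elements -> padded to 16
p32_shards = split_param_evenly(p, world_size)
assert all(shard.numel() == 4 for shard in p32_shards)

# Gradient side: each rank flattens its bucket, then
#   ZeRO-1: all_reduce the whole flat grad and keep only your own slice,
#   ZeRO-2: reduce_scatter so each rank directly receives its slice.
# Both are simulated here by averaging the per-rank flat grads and slicing.
flat_grads = [torch.randn(16) for _ in range(world_size)]     # one flat grad per rank
averaged = torch.stack(flat_grads).sum(dim=0) / world_size
my_rank = 1
my_grad_slice = averaged.split(averaged.numel() // world_size)[my_rank]
assert my_grad_slice.numel() == p32_shards[my_rank].numel()
```

In the real optimizer the averaging is done collectively (`dist.all_reduce` for ZeRO-1, `dist.reduce_scatter` for ZeRO-2), but the resulting per-rank slice has exactly this shape, which is what lets each rank run `optim.step()` on its own `p32` shard.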
diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt index 59d027fcf60f..0d9ce55a8079 100644 --- a/examples/images/diffusion/requirements.txt +++ b/examples/images/diffusion/requirements.txt @@ -12,7 +12,7 @@ einops==0.3.0 transformers webdataset==0.2.5 open-clip-torch==2.7.0 -gradio==3.11 +gradio==3.34.0 lightning==1.9.0 datasets colossalai diff --git a/pytest.ini b/pytest.ini index 01e5cd217c5d..e99fe3f086c6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,3 +4,4 @@ markers = gpu: tests which requires a single GPU dist: tests which are run in a multi-GPU or multi-machine environment experiment: tests for experimental features +addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index eedd8c59a3a8..79f98a4c95d0 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,14 +11,9 @@ from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] +_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch'] -# These models will get stuck -_STUCK_MODELS = [ - 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', - 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads' -] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """ passed_models = [] failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS skipped_models = [] for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index c51b54c82f57..a94e8d42c78e 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool): optimizer_ckpt_path = f"{tempdir}/optimizer" # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here booster.save_model(model, model_ckpt_path, shard=shard) - if not shard: - # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint - booster.save_optimizer(optimizer, optimizer_ckpt_path) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() new_model = resnet18() @@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool): booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) - if not shard: - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + + 
booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) def run_dist(rank, world_size, port): @@ -62,3 +61,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_low_level_zero_checkpointIO(): spawn(run_dist, 2) + + +if __name__ == "__main__": + test_low_level_zero_checkpointIO() diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py deleted file mode 100644 index 622d9deb601d..000000000000 --- a/tests/test_lazy/test_distribute.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Optional - -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.common import print_rank_0 - -try: - from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -except: - pass -from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed - -from tests.kit.model_zoo import model_zoo - - -def find_shard_dim(shape: torch.Size) -> Optional[int]: - for dim, size in enumerate(shape): - if size % 2 == 0: - return dim - - -def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: - shard_dim = find_shard_dim(original_tensor.shape) - dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} - target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - return target_sharding_spec - - -def _get_current_name(prefix: str, name: str) -> str: - return f'{prefix}.{name}'.lstrip('.') - - -def generate_sharding_spec_dict(model: nn.Module) -> dict: - sharding_spec_dict = {} - - @torch.no_grad() - def generate_recursively(module: nn.Module, prefix: str = ''): - # recursively initialize the module - for name, mod in module.named_children(): - generate_recursively(mod, prefix=_get_current_name(prefix, name)) - - # initialize tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - if isinstance(param, LazyTensor): - sharding_spec = make_sharding_spec(param) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec - - for name, buf in module.named_buffers(recurse=False): - if isinstance(buf, LazyTensor): - sharding_spec = make_sharding_spec(buf) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec - - generate_recursively(model) - - return sharding_spec_dict - - -@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -def run_dist_lazy_init(subset, seed: int = 42): - sub_model_zoo = model_zoo.get_sub_registry(subset) - device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) - _MyTensor._pre_op_fn = lambda *args: set_seed(seed) - LazyTensor._pre_op_fn = lambda *args: set_seed(seed) - - for name, entry in sub_model_zoo.items(): - # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): - continue - print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry - ctx = LazyInitContext(tensor_cls=_MyTensor) - with ctx: - model = model_fn() - ctx = LazyInitContext() - with ctx: - 
deferred_model = model_fn() - sharding_spec_dict = generate_sharding_spec_dict(deferred_model) - ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) - - -def run_dist(rank, world_size, port) -> None: - colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port) - run_dist_lazy_init() - - -@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_dist_lazy_init(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_dist_lazy_init() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py deleted file mode 100644 index 6d89fb90c574..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.utils import build_checkpoints -from torch.optim import Adam - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def test_global_model(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - global_state_dict = model_checkpoints[0] - assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) - for k, v in orig_state_dict.items(): - assert torch.equal(v, global_state_dict[k]) - - -def test_global_model_shard(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) - assert len(model_checkpoints) == 2 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) - assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 - for k, v in orig_state_dict.items(): - for state_dict in model_checkpoints: - if k in state_dict: - assert torch.equal(v, state_dict[k]) - - -def test_global_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) - assert len(optimizer_checkpoints) == 1 - assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} - for state in meta['paired_os'].values(): - for k, is_paired in state.items(): - if k == 'step': - assert not is_paired - else: - assert is_paired - orig_state_dict = optimizer.state_dict() - state_dict = optimizer_checkpoints[0] - for k, orig_state in orig_state_dict['state'].items(): - state = state_dict['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v2 == v2 - assert orig_state_dict['param_groups'] == state_dict['param_groups'] - - -def test_global_optimizer_shard(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = 
build_checkpoints(80, model, optimizer) - assert len(optimizer_checkpoints) == 2 - assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] - orig_state_dict = optimizer.state_dict() - assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( - optimizer_checkpoints[1]['state'].keys()) - assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 - for k, orig_state in orig_state_dict['state'].items(): - state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ - 'state'] else optimizer_checkpoints[1]['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - - assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] - - -def test_dist_model_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] - assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] - dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - - -if __name__ == '__main__': - test_global_model() - test_global_model_shard() - test_global_optimizer() - test_global_optimizer_shard() - test_dist_model_optimizer() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py deleted file mode 100644 index 2949c9f0752d..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ /dev/null @@ -1,186 +0,0 @@ -from copy import deepcopy -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.nn import Module -from torch.optim import Adam, Optimizer - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_groups: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): 
- - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): - with torch.no_grad(): - for p in model.parameters(): - p.fill_(scalar) - for state in optimizer.state.values(): - for v in state.values(): - if isinstance(v, Tensor): - v.fill_(scalar) - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) -def test_save_global_load_global(max_shard_size_gb: float): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) - new_model, new_optimizer = prepare_model_optim() - load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) - check_model_state_dict(model.state_dict(), new_model.state_dict()) - check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def launch_dist(fn, world_size: int): - spawn(run_dist, world_size, test_fn=fn) - - -def save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - reset_model_optim(model, optimizer) - world_size = dist.get_world_size() - rank = dist.get_rank() - save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank]) - - -def load_and_check_dist(dir_name: str): - world_size = dist.get_world_size() - model, optimizer = prepare_model_optim(shard=True) - reset_model_optim(model, optimizer) - model_state_dict = deepcopy(model.state_dict()) - optimizer_state_dict = deepcopy(optimizer.state_dict()) - reset_model_optim(model, optimizer, 1) - load(dir_name, model, optimizer, get_redist_meta(world_size), 
get_dist_metas(world_size)) - check_model_state_dict(model_state_dict, model.state_dict()) - check_optim_state_dict(optimizer_state_dict, optimizer.state_dict()) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_global_load_dist(): - model, optimizer = prepare_model_optim() - reset_model_optim(model, optimizer) - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist_load_dist(): - with TemporaryDirectory() as dir_name: - # save tp + dp - fn = partial(save_dist, dir_name, False) - launch_dist(fn, 2) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - with TemporaryDirectory() as dir_name: - # save tp + zero - fn = partial(save_dist, dir_name, True) - launch_dist(fn, 4) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - launch_dist(fn, 4) - - -if __name__ == '__main__': - test_save_global_load_global(80 / 1024**3) - test_save_global_load_global(0) - test_save_global_load_dist() - test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py deleted file mode 100644 index 07d4597f8391..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import merge, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_merge_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - dp_world_size = dist.get_world_size() // 2 - if not zero: - dist_metas = { - 
'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - } - else: - dist_metas = { - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - } - save(dir_name, model, optimizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_merge_tp_dp(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 5 - global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 - meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) - assert meta['dist_meta'] is None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 20 - optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 - - -if __name__ == '__main__': - test_merge_global() - test_merge_tp_dp(False) - test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py deleted file mode 100644 index 5da2ae4fe1f8..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge_param.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param - - -def test_unflatten_zero_param_even() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_unflatten_zero_param_uneven() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_row() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert 
torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_col() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d_reverse() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_hybrid() -> None: - dist_metas = [ - ParamDistMeta(i % 2, - 2, - i // 2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_numel=4, - zero_orig_shape=[2, 2]) for i in range(12) - ] - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_dummy() -> None: - dist_metas = [ParamDistMeta(0, 1, 0, 1)] - orig_tensor = torch.rand(4, 6) - merged_tensor = merge_param([orig_tensor], dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -if __name__ == '__main__': - test_unflatten_zero_param_even() - test_unflatten_zero_param_uneven() - test_gather_tp_param_1d_row() - test_gather_tp_param_1d_col() - test_gather_tp_param_2d() - test_gather_tp_param_2d_reverse() - test_merge_param_hybrid() - test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py deleted file mode 100644 index fdc849a5ecc0..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import ( - ParamDistMeta, - ParamRedistMeta, - PipelineRedistMeta, - RankRedistMeta, - RedistMeta, -) - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: 
bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -def check_checkpoint_shape(dir_name: str): - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - for meta_name in global_meta['meta']: - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta['dist_meta'] is not None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 10 - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 - - -def test_global_to_dist(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - check_checkpoint_shape(output_dir) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank]) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_dist_to_dist(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with 
TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - if not zero: - assert len(os.listdir(output_dir)) == 0 - else: - check_checkpoint_shape(output_dir) - - -if __name__ == '__main__': - test_global_to_dist() - test_dist_to_dist(False) - test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py deleted file mode 100644 index 2abdd95a6481..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ /dev/null @@ -1,149 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import ( - GLOBAL_META_FILE_NAME, - META_CKPT_FILE_NAME, - MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME, -) -from colossalai.utils.checkpoint_io.io import save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_groups: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_overwrite(): - model = DummyModel() - with TemporaryDirectory() as dir_name: - with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: - pass - with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. 
\(overwrite = False\)'): - save(dir_name, model) - - -def test_save_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - assert len(os.listdir(dir_name)) == 5 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 1 - assert len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - check_model_state_dict(model.state_dict(), model_state_dict) - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) - other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) - assert len(other_state_dict) == 0 - - -def test_save_global_shard(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - assert len(os.listdir(dir_name)) == 7 - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 - model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] - assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 - check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) - optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] - assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 - assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] - check_optim_state_dict( - optimizer.state_dict(), { - 'state': { - **optimizer_state_dicts[0]['state'], - **optimizer_state_dicts[1]['state'] - }, - 'param_groups': optimizer_state_dicts[0]['param_groups'] - }) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def run_save_dist(dir_name): - model, optimizer = prepare_model_optim() - dist_metas = { - 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), - 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) - } - save(dir_name, model, optimizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist(): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name) - world_size = 2 - spawn(run_dist, world_size, test_fn=fn) - assert len(os.listdir(dir_name)) == 8 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 2 - for rank, meta_name in enumerate(global_meta['meta']): - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta.get('dist_meta', None) is not None - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict - - -if __name__ == '__main__': - test_overwrite() - test_save_global() - 
test_save_global_shard() - test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py deleted file mode 100644 index 8b83caa12359..000000000000 --- a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -from colossalai.utils.checkpoint_io.meta import ParamRedistMeta -from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param - - -def test_flatten_zero_param_even() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_flatten_zero_param_uneven() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 - flat_tensors = flat_tensors[1:-1] - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 - unmerged_tensors = unmerged_tensors[1:-1] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_split_tp_param_1d_row() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_1d_col() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d() -> None: - redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, 
redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d_reverse() -> None: - redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_unmerge_param_hybrid() -> None: - redist_meta = ParamRedistMeta(2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_start_dp_rank=0, - zero_offsets=[0, 1]) - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 - for tp_rank in range(6): - for dp_rank in range(2): - assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) - - -def test_unmerge_param_dummy() -> None: - redist_meta = ParamRedistMeta(1, 1) - orig_tensor = torch.rand(4, 6) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 - assert torch.equal(orig_tensor, unmerged_tensors[0][0]) - - -if __name__ == '__main__': - test_flatten_zero_param_even() - test_flatten_zero_param_uneven() - test_split_tp_param_1d_row() - test_split_tp_param_1d_col() - test_split_tp_param_2d() - test_split_tp_param_2d_reverse() - test_unmerge_param_hybrid() - test_unmerge_param_dummy() diff --git a/tests/test_zero/test_legacy/common.py b/tests/test_zero/test_legacy/common.py deleted file mode 100644 index 2c3d122c79af..000000000000 --- a/tests/test_zero/test_legacy/common.py +++ /dev/null @@ -1,140 +0,0 @@ -from functools import partial - -import torch -import torch.distributed as dist - -from colossalai.logging import get_dist_logger -from colossalai.utils import checkpoint -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 - -LOGGER = get_dist_logger('zero_test') - -MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None))) - -_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, - fp32_reduce_scatter=False, - tensor_placement_policy='cuda', - gradient_predivide_factor=1.0, - shard_strategy=TensorShardStrategy(), - reuse_fp16_shard=False) - -_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale=2**32) - -ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), - zero=dict( - model_config=_ZERO_MODEL_CONFIG, - optimizer_config=_ZERO_OPTIMIZER_CONFIG, - ), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - -CONFIG = dict(fp16=dict(mode=None,), - 
zero=dict(level=3, - verbose=False, - offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False), - offload_param_config=dict(device='cpu', - pin_memory=True, - buffer_count=5, - buffer_size=1e8, - max_in_cpu=1e9)), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - - -def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - if isinstance(model, ShardedModelV2): - model.backward(loss) - else: - loss.backward() - - -def checkpoint_wrapper(module, enable=True): - if enable: - module.forward = partial(checkpoint, module.forward) - return module - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_grads(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_grad = zero_p.grad.clone().to(p.device) - grad = p.grad.float() - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose) - - -def check_params(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - # assert p.dtype == zero_p.dtype - assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}" - - -def check_grads_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - # zero_grad = zero_p.grad.clone().to(p.device) - if zero_p.colo_attr.is_replicated: - zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device) - chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - grad = chunks[rank].float() - if zero_grad.size(0) > grad.size(0): - zero_grad = zero_grad[:grad.size(0)] - else: - zero_grad = zero_p.colo_attr.grad_payload - grad = p.grad.to(zero_grad.dtype) - - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' - - -def check_params_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank] - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - assert p.dtype == zero_p.dtype - assert allclose(p, zero_p, loose=loose) - - -def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - if zero_p.colo_attr.param_is_sharded: - zero_p = zero_p.colo_attr.data_payload.to(p.device).float() - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank].float() - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - else: - zero_p = zero_p.colo_attr.data_payload.to(p.device) - - assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) - assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git 
a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py deleted file mode 100644 index e90158e0a43b..000000000000 --- a/tests/test_zero/test_legacy/test_found_inf.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest -import torch -from common import CONFIG -from test_sharded_optim_v2 import _run_step - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py deleted file mode 100644 index 0e956f7cc617..000000000000 --- a/tests/test_zero/test_legacy/test_gemini_manager.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState - - -@pytest.mark.dist -@clear_cache_before_run() -def test_gemini_manager(): - # reset the manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, 
device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py deleted file mode 100644 index 84493827193e..000000000000 --- a/tests/test_zero/test_legacy/test_init_context.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_used -from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_zero_init") - - for name, 
get_components_func in non_distributed_component_funcs._registry.items(): - # because the ZeroInitContext automatically turns parameters to fp16 - # and the beit model use tensor.erfinv_() function to initialize weights - # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model - if name == 'beit': - continue - model_builder, _, _, _, _ = get_components_func() - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - continue - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = model_builder(checkpoint=True) - - for param in model.parameters(): - assert hasattr(param, 'colo_attr') - assert param.colo_attr.sharded_data_tensor.dtype == torch.half - assert param.colo_attr.sharded_data_tensor.is_sharded - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - - cuda_mem_use, _ = colo_model_mem_usage(model) - model_data_cuda_mem_MB = cuda_mem_use / 1e6 - logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0]) - sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6 - logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0]) - logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0]) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_init_context(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_init_context(1) diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py deleted file mode 100644 index b91371b98922..000000000000 --- a/tests/test_zero/test_legacy/test_param_op.py +++ /dev/null @@ -1,82 +0,0 @@ -import copy - -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr -from tests.components_to_test.registry import non_distributed_component_funcs - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def run_model(model, inputs, label, criterion, use_param_hook=False): - if use_param_hook: - - class HooKWrapper: - - def __init__(self) -> None: - self.hook_triggered_times = 0 - - def wrapper_func(self): - - def hook(param, grad) -> torch.Tensor or None: - self.hook_triggered_times += 1 - return grad - - return hook - - hookwrapper = HooKWrapper() - param_list = [p for p in model.parameters()] - hook_mgr = BaseParamHookMgr(param_list) - hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) - - model.zero_grad(set_to_none=True) - - with torch.cuda.amp.autocast(): - if criterion: - y = model(inputs) - loss = criterion(y, label) - else: - loss = model(inputs, label) - loss = loss.float() - loss.backward() - - if use_param_hook: - hook_mgr.remove_hooks() - return hookwrapper.hook_triggered_times - - -@clear_cache_before_run() -def test_base_param_hook(): - test_models 
= ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] - # test_models = ['bert'] - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - torch.manual_seed(0) - model = model_builder(checkpoint=True).cuda() - model.train() - - for i, (inputs, label) in enumerate(train_dataloader): - if i > 0: - break - model_copy = copy.deepcopy(model) - - run_model(model, inputs.cuda(), label.cuda(), criterion, False) - ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) - - # Make sure param hook has only be fired once in case of parameter sharing - assert ret2 == len(list(model.parameters())) - - for p, p_copy in zip(model.parameters(), model_copy.parameters()): - assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" - - -if __name__ == '__main__': - test_base_param_hook() diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py deleted file mode 100644 index 93d624aa2bbd..000000000000 --- a/tests/test_zero/test_legacy/test_shard_model_v2.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG, check_grads_padding, run_fwd_bwd -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("enable_autocast", [True]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - model = DDP(model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_model_v2(world_size): - spawn(run_dist, 
world_size) - - -if __name__ == '__main__': - test_shard_model_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py deleted file mode 100644 index 4ba43edceb5d..000000000000 --- a/tests/test_zero/test_legacy/test_shard_param.py +++ /dev/null @@ -1,91 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -from common import CONFIG, allclose - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_param import ShardedTensor -from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_shard_tensor_with_strategy(shard_strategy_class, world_size): - t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) - assert list(t.origin_shape) == [world_size * 2, 3] - assert list(t.shape) == [world_size * 2, 3] - - shard_strategy = shard_strategy_class() - - # test shard strategy - shard_strategy.shard([t]) - assert list(t.shape) == [6], f"{list(t.shape)} vs 6" - shard_strategy.gather([t]) - assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}" - - -def _run_shard_tensor(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_shard_tensor_with_strategy(world_size=world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_tensor(world_size): - spawn(_run_shard_tensor, world_size) - - -def _run_shard_param_v2(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - param = torch.nn.Parameter(torch.randn(2, 3)) - param_ref = deepcopy(param) - sparam = ShardedParamV2(param=param) - - allclose(sparam.data_payload, param_ref.data) - - # Test get memory usage - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" - - sparam.set_data_none() - assert (param.data.numel() == 0) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - # 4 is size of dummy tensor of param.data - assert cpu_mem_use == 2 * 3 * 4 * 2 - - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - sparam.set_data_none() - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - # append a grad to torch param - param.data = sparam.data_payload - param.grad = torch.randn(2, 3) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}" - assert cuda_mem_use == 0 - - # reuse torch grad for sparam - sparam.saved_grad = StatefulTensor(param.grad) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_param_v2(world_size): - spawn(_run_shard_param_v2, world_size) - - -if __name__ == '__main__': - # test_shard_tensor(2) - test_shard_param_v2(2) diff --git 
a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py deleted file mode 100644 index 1ca144662722..000000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed - - -def init_zero(model_builder, placement_policy): - device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu') - shard_strategy = TensorShardStrategy() - with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True): - model = model_builder() - model = ShardedModelV2( - model, - shard_strategy, - tensor_placement_policy=placement_policy, - reuse_fp16_shard=True, - ) - optim = HybridAdam(model.parameters(), lr=1e-3) - optim = ShardedOptimizerV2(model, optim, initial_scale=32) - return model, optim - - -def run_step(model, optim, criterion, data, label): - optim.zero_grad() - logits = model(data) - loss = criterion(logits, label) - optim.backward(loss) - optim.step() - - -def check_state_dict_eq(state_dict, other): - for p, state in state_dict['state'].items(): - other_state = other['state'][p] - for k, v in state.items(): - if isinstance(v, torch.Tensor): - assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}' - else: - assert v == other_state[k] - - -@parameterize('placement_policy', ['cuda', 'cpu']) -def run_nested_model(placement_policy): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - set_seed(42) - model, optim = init_zero(model_builder, placement_policy) - set_seed(42) - model_copy, optim_copy = init_zero(model_builder, placement_policy) - - model.train() - model_copy.train() - pg = ProcessGroup() - set_seed(pg.dp_local_rank()) - data_iter = iter(train_dataloader) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model, optim, criterion, data, label) - optim_copy.load_state_dict(optim.state_dict()) - check_state_dict_eq(optim.state_dict(), optim_copy.state_dict()) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model_copy, optim_copy, criterion, data, label) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_nested_model() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_state_dist(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_state_dist(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py deleted file mode 100644 index c6f77995ebcd..000000000000 --- 
a/tests/test_zero/test_legacy/test_sharded_optim_v2.py +++ /dev/null @@ -1,110 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from common import CONFIG, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.amp import convert_to_apex_amp -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): - model.train() - optimizer.zero_grad() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - - loss = loss.float() - if isinstance(model, ShardedModelV2): - optimizer.backward(loss) - else: - loss.backward() - optimizer.step() - - -@parameterize("cpu_offload", [True, False]) -@parameterize("use_cpuadam", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - - if use_cpuadam and cpu_offload is False: - return - if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam): - return - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'auto', - reuse_fp16_shard=use_cpuadam, - ) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - if dist.get_world_size() > 1: - apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, False) - _run_step(zero_model, sharded_optim, data, label, criterion, False) - 
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_sharded_optim_v2() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_v2(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py deleted file mode 100644 index 0223f18c29d6..000000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from torchvision.models import resnet50 - -import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy - - -def run_dist(rank, world_size, port): - # this test only runs on resnet18 - # as this model has sync batch normalization - # need to configure cudnn deterministic so that - # randomness of convolution layers will be disabled - zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) - colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False), - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - - with ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): - model = resnet50() - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = torch.nn.CrossEntropyLoss() - - engine, *args = colossalai.initialize(model, optimizer, criterion) - - # train for dummy iterations - engine.train() - for _ in range(2): - data = torch.rand(4, 3, 128, 128).cuda().half() - label = torch.randint(0, 10, size=(4,)).cuda() - engine.zero_grad() - out = engine(data) - loss = engine.criterion(out, label) - engine.backward(loss) - engine.step() - - # test - # need to make sure the batch norm stats are synchronized - # so that given the same input, the model will produce the same - # output on different ranks - engine.eval() - data = torch.rand(4, 3, 128, 128).cuda().half() - dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA)) - - # predict - out = engine(data) - - # test if results are equal - tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)] - tensor_list.insert(rank, out) - dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA)) - - assert torch.all(tensor_list[0] == tensor_list[1]), \ - 'expected the output from different ranks to be the same, but got different values' - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_sharded_optim_with_sync_bn(): - """ - This test is to make sure that buffers are synchronized between ranks - when using ZeRO. An example of module buffer is the running stats of - BatchNormalization layer, i.e. 
mean and var. - - If the buffers are not synchronized, the model will produce different - output even though the input and parameters are the same. This is not - wanted if we are doing predictions. - - """ - spawn(run_dist, 2) - - -if __name__ == '__main__': - test_sharded_optim_with_sync_bn() diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py deleted file mode 100644 index 5f76fff3e5c3..000000000000 --- a/tests/test_zero/test_legacy/test_state_dict.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_zero_state_dict(shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - zero_state_dict = zero_model.state_dict() - for key, val in model.state_dict().items(): - assert torch.equal(val, zero_state_dict[key].to(val.device)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_zero_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_zero_state_dict(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_state_dict(2) diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py deleted file mode 100644 index 238bc3fe1a98..000000000000 --- a/tests/test_zero/test_legacy/test_tensor_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.gemini.tensor_utils import ( - colo_model_data_move_to_cpu, - colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, - colo_model_tensor_clone, - colo_tensor_mem_usage, -) - - -def _run_colo_tensor_mem_usage(): - for i in range(1): - if i == 1: - t1 = StatefulTensor(torch.randn(2, 2)) - t2 = StatefulTensor(torch.randn(4, 4)) - c1, g1 = colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - else: - t1 = 
torch.randn(2, 2) - t2 = torch.randn(4, 4) - c1, g1 = colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - - -def _run_colo_model_data_tensor_move_inline(): - for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: - colo_model_data_tensor_move_inline(t, get_current_device()) - assert t.device == get_current_device() - - -def _run_colo_model_data_tensor_move(): - for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), - (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: - cpu_t, cuda_t = t - colo_model_data_tensor_move(cpu_t, cuda_t) - assert cuda_t.device == get_current_device() - - -def _run_colo_model_data_move_to_cpu(): - for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: - colo_model_data_move_to_cpu(t) - assert t.device == torch.device("cpu") - - -def _run_colo_model_tensor_clone(): - for t in [ - StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), - torch.randn(4, 4).cuda(torch.cuda.current_device()) - ]: - if issubclass(type(t), StatefulTensor): - assert t.payload.device == get_current_device() - else: - assert t.device == get_current_device() - p = colo_model_tensor_clone(t, get_current_device()) - assert p.device == get_current_device() - for i in range(2): - for j in range(2): - if issubclass(type(t), StatefulTensor): - assert t.payload.device == p.device - assert t.payload[i][j] == p[i][j] - else: - assert t.device == p.device - assert t[i][j] == p[i][j] - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - _run_colo_tensor_mem_usage() - _run_colo_model_data_tensor_move_inline() - _run_colo_model_data_tensor_move() - _run_colo_model_data_move_to_cpu() - _run_colo_model_tensor_clone() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_zero_tensor_utils(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_tensor_utils(world_size=2) diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py deleted file mode 100644 index 826a543db861..000000000000 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def run_dist(rank, world_size, port, parallel_config, bf16): - is_mp_config = parallel_config == MP_PARALLEL_CONFIG - is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG - if bf16: - parallel_config['zero']['model_config']['bf16'] = True - colossalai.launch(config=parallel_config, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - - test_models = ['repeated_computed_layers', 'resnet18', 'bert'] - for model_name in 
test_models:
-        get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(target_device=torch.cuda.current_device(),
-                             shard_strategy=gpc.config.zero.model_config.shard_strategy,
-                             shard_param=True,
-                             bf16=bf16):
-            colo_model = model_builder(checkpoint=True)
-
-        colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
-        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
-                                                               optimizer=colo_optimizer,
-                                                               criterion=criterion,
-                                                               train_dataloader=train_dataloader)
-        dtype = torch.bfloat16 if bf16 else torch.float16
-        torch_model = model_builder(checkpoint=True).to(dtype)
-        col_model_deepcopy(engine.model, torch_model)
-        torch_model = torch_model.cuda().float()
-
-        engine.train()
-        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
-
-        if dist.get_world_size() > 1:
-            torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])
-
-        i = 0
-        for data, label in train_dataloader:
-            if i > 4:
-                break
-
-            data, label = data.cuda(), label.cuda()
-
-            engine.zero_grad()
-            torch_optimizer.zero_grad()
-
-            if criterion:
-                output = engine(data)
-                loss = engine.criterion(output, label)
-
-                torch_output = torch_model(data)
-                torch_loss = engine.criterion(torch_output, label)
-            else:
-                loss = engine(data, label)
-                torch_loss = torch_model(data, label)
-
-            engine.backward(loss)
-            engine.step()
-
-            torch_loss.backward()
-
-            for param in torch_model.parameters():
-                if param.grad is not None:
-                    assert not has_inf_or_nan(param.grad)
-
-            torch_optimizer.step()
-            i += 1
-
-        if is_mp_config:
-            check_params(torch_model, colo_model, loose=True)
-        elif is_zero_config:
-            check_sharded_model_params(torch_model, colo_model, loose=True)
-
-
-# FIXME: enable this test in next PR
-@pytest.mark.skip
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_if_address_is_in_use()
-def test_mp_engine(world_size):
-    spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("bf16", [True, False])
-@rerun_if_address_is_in_use()
-def test_zero_engine(world_size, bf16):
-    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)
-
-
-if __name__ == '__main__':
-    test_zero_engine(world_size=4)
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index c264a8077d2a..a1d14f1d5a9d 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -9,6 +9,7 @@
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
@@ -50,26 +51,27 @@ def exam_zero_1_2_grad_acc():
     input_data1 = torch.randn(32, 128).cuda()
     input_data2 = torch.randn(32, 128).cuda()
 
-    def fwd_bwd_func(number, cur_data):
+    def fwd_bwd_func(number, cur_data, check_flag):
         # zero-dp forward
         zero1_output = zero1_model(cur_data)
         zero2_output = zero2_model(cur_data)
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
-        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
-
-        for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-            if z2p.grad is not None:
-                # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-                assert torch.equal(z1p.grad, z2p.grad)
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())
 
-        zero1_optimizer._sync_grad()
-        zero2_optimizer._sync_grad()
+        if check_flag:
+            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
+                if z2p.grad is not None:
+                    # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
+                    assert torch.equal(z1p.grad, z2p.grad)
 
-    fwd_bwd_func(0, input_data1)
-    fwd_bwd_func(1, input_data2)
+    fwd_bwd_func(0, input_data1, True)
+    fwd_bwd_func(1, input_data2, False)
 
     # step
     zero1_optimizer.step()
@@ -111,26 +113,24 @@ def exam_zero_1_grad_acc():
     input_data2 = torch.randn(32, 128).cuda()
 
     def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
-        # torch-ddp forward
-        torch_output = torch_model(cur_data)
-        assert torch.equal(zero_output, torch_output)
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())
 
-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
-        # torch-ddp backward
-        torch_output.sum().backward()
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
+            torch_output = torch_model(cur_data)
+            assert torch.equal(zero_output, torch_output)
+            torch_output.sum().backward()
 
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)
 
-        zero_optimizer._sync_grad()
-
     fwd_bwd_func(0, input_data1, True)
     fwd_bwd_func(1, input_data2, False)
@@ -148,7 +148,8 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
 
     exam_zero_1_grad_acc()
-    exam_zero_1_2_grad_acc()
+    # gradient accumulation is not compatible with ZeRO-2
+    # exam_zero_1_2_grad_acc()
 
 
 @pytest.mark.dist
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 8e2206fe6c8d..5a0609bff192 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
@@ -16,8 +17,9 @@ class MlpModel(nn.Module):
 
     def __init__(self):
         super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(128, 256)
-        self.linear2 = nn.Linear(256, 512)
+        self.linear1 = nn.Linear(123, 253)
+        self.linear_drop = nn.Linear(253, 253)
+        self.linear2 = nn.Linear(253, 512)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
     assert_close(a, b, rtol=rtol, atol=atol)
 
 
+def split_ddp_grad(grad, world_size):
+    with torch.no_grad():
+        grad = grad.clone().detach().flatten()
+        padding_size = (world_size - grad.numel() % world_size) % world_size
+        if padding_size > 0:
+            grad = torch.nn.functional.pad(grad, [0, padding_size])
+        splited_grad = grad.split(grad.numel() // world_size)
+    return splited_grad
+
+
 def exam_zero_1_2():
     """
     In this test, we want to test whether zero stage 1 and 2
@@ -72,23 +84,21 @@
                                                initial_scale=128)
     # create data
     seed_all(2001 + local_rank)
-    input_data = torch.randn(32, 128).cuda()
+    input_data = torch.randn(32, 123).cuda()
 
     zero1_output = zero1_model(input_data)
     zero2_output = zero2_model(input_data)
     assert torch.equal(zero1_output, zero2_output)
 
     # zero-dp backward
-    zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
-    zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
+    zero1_optimizer.backward(zero1_output.mean().float())
+    zero2_optimizer.backward(zero2_output.mean().float())
 
-    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-        if z2p.grad is not None:
-            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-            assert torch.equal(z1p.grad, z2p.grad)
-
-    zero1_optimizer._sync_grad()
-    zero2_optimizer._sync_grad()
+    # check grad
+    z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
+    z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
+    for z1g, z2g in zip(z1g_list, z2g_list):
+        assert torch.equal(z1g, z2g)
 
     # step
     zero1_optimizer.step()
@@ -100,7 +110,7 @@
 
 
 @parameterize('dtype', [torch.float16, torch.bfloat16])
-def exam_zero_1_torch_ddp(dtype: torch.dtype):
+def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters
@@ -116,7 +126,7 @@
     torch_model = MlpModel().cuda()
     zero_model = copy.deepcopy(torch_model).to(dtype)
 
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
+    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
 
     # create optimizer
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@@ -133,7 +143,7 @@
     seed_all(1453 + local_rank)
     # create
-    input_data = torch.rand(32, 128).cuda()
+    input_data = torch.rand(32, 123).cuda()
 
     # zero-dp forward
     zero_output = zero_model(input_data.to(dtype))
@@ -143,17 +153,20 @@
     loose_close(zero_output, torch_output, dtype=dtype)
 
     # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
+    zero_optimizer.backward(zero_output.mean().float())
 
     # torch-ddp backward
     torch_output.mean().backward()
 
     # check grad
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p.grad, z1p.grad, dtype=dtype)
+        if p.grad is not None:
+            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
+            torch_grad_list = split_ddp_grad(p.grad, world_size)
+            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
+                loose_close(zero_grad, torch_grad, dtype=dtype)
 
     # zero-dp step
-    zero_optimizer._sync_grad()
     zero_optimizer.step()
 
     # torch ddp step
@@ -161,14 +174,13 @@
 
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        # print(n, torch.max(torch.abs(p.data - z1p.data)))
         loose_close(p.data, z1p.data, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
 
-    exam_zero_1_torch_ddp()
+    exam_zero_1_torch_ddp(world_size=world_size)
     exam_zero_1_2()
diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py
new file mode 100644
index 000000000000..23356fe718a6
--- /dev/null
+++ b/tests/test_zero/test_low_level/test_zero_ckpt.py
@@ -0,0 +1,121 @@
+import copy
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.zero import LowLevelZeroOptimizer
+
+
+class MlpModel(nn.Module):
+
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(12, 24)
+        self.linear2 = nn.Linear(24, 12)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
+
+
+def exam_zero_1_torch_ddp_ckpt():
+    """
+    We examine the state_dict of zero and DDP.
+    Moreover, we examine the zero's loading checkpoint of a torch ckpt.
+    """
+    local_rank = torch.distributed.get_rank()
+    seed_all(1453)
+
+    # create models
+    torch_model = MlpModel().cuda()
+    zero_model = copy.deepcopy(torch_model)
+
+    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
+
+    # create optimizer
+    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
+
+    # we only test stage 1 here
+    # the state dicts of stage 1 and stage 2 are the same
+    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
+                                           overlap_communication=True,
+                                           initial_scale=1,
+                                           reduce_bucket_size=262144)
+
+    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
+
+    seed_all(1453 + local_rank)
+    # create
+    input_data = torch.rand(4, 12).cuda()
+
+    # forward
+    zero_output = zero_model(input_data)
+    torch_output = torch_model(input_data)
+
+    # backward
+    zero_optimizer.backward(zero_output.mean().float())
+    torch_output.mean().backward()
+
+    # step
+    zero_optimizer.step()
+    torch_optimizer.step()
+
+    torch_state_dict = torch_optimizer.state_dict()
+    zero_state_dict = zero_optimizer.state_dict()
+
+    # examine the original state dict
+    for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
+        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
+            loose_close(t_v, z_v)
+
+    # empty the optimzer state
+    zero_optimizer.optim.state = []
+
+    # zero load a torch checkpoint
+    zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
+    zero_state_dict = zero_optimizer.state_dict()
+
+    # examine the loaded state dict
+    for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
+        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
+            loose_close(t_v, z_v)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+
+    exam_zero_1_torch_ddp_ckpt()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_zero_ckpt():
+    spawn(run_dist, 2)
+
+
+if __name__ == '__main__':
+    test_zero_ckpt()
diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py
index aeeaff5b5cb9..368ef976ef6e 100644
--- a/tests/test_zero/test_low_level/test_zero_init.py
+++ b/tests/test_zero/test_low_level/test_zero_init.py
@@ -33,10 +33,9 @@ def exam_zero_init():
     assert optimizer1._local_rank == optimizer2._local_rank
     assert optimizer1._world_size == optimizer2._world_size
-    assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
 
-    mp_group1 = optimizer1._mp_torch_group
-    mp_group2 = optimizer2._mp_torch_group
+    mp_group1 = optimizer1.tp_pg
+    mp_group2 = optimizer2.tp_pg
     assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
     assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py
index f0804f4bb5ba..238de3334c80 100644
--- a/tests/test_zero/test_low_level/test_zero_tp.py
+++ b/tests/test_zero/test_low_level/test_zero_tp.py
@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
                                            initial_scale=2,
                                            clip_grad_norm=1.0,
                                            overlap_communication=overlap_flag,
-                                           partition_grad=partition_flag)
+                                           partition_grad=partition_flag,
+                                           dp_process_group=tp_pg.dp_process_group(),
+                                           tp_process_group=tp_pg.tp_process_group())
 
     dp_local_rank = tp_pg.dp_local_rank()
     set_seed(255 + dp_local_rank)
diff --git a/version.txt b/version.txt
index 0d91a54c7d43..9e11b32fcaa9 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.0
+0.3.1
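
Note on the gradient-accumulation change exercised by the tests above: the removed `sync_grad=False` / `_sync_grad()` calls are replaced by the optimizer's `no_sync()` context manager, which skips the cross-rank gradient reduction on every accumulation step except the last. The sketch below is illustrative only and is not part of this patch. The `conditional_context`, `no_sync()`, `backward()` and `step()` calls mirror the tests, while the model, the SGD settings, and the two-micro-batch loop are invented for the example; it also assumes the process group has already been initialized via `colossalai.launch`, as in the `run_dist` helpers.

import torch
import torch.nn as nn

from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer

# assumes colossalai.launch(...) has already set up distributed state,
# e.g. inside a run_dist worker like the ones in the tests above
model = nn.Linear(128, 128).cuda()
optimizer = LowLevelZeroOptimizer(torch.optim.SGD(model.parameters(), lr=1),
                                  overlap_communication=True,
                                  initial_scale=1)    # ZeRO-1; the tests note grad accumulation is not compatible with ZeRO-2

micro_batches = [torch.randn(32, 128).cuda() for _ in range(2)]
for step, batch in enumerate(micro_batches):
    accumulating = step < len(micro_batches) - 1
    # skip gradient reduction on all but the last micro-batch
    with conditional_context(optimizer.no_sync(), accumulating):
        loss = model(batch).sum()
        optimizer.backward(loss)
optimizer.step()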